commit cb7c48bfa75b51376dc6d0fa227e33489c7b0196 Author: huangzhenpc Date: Mon Dec 29 22:52:27 2025 +0800 first commit: one-api base code + SAAS plan document diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..b1ceee9 --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +PORT=3000 +DEBUG=false +HTTPS_PROXY=http://localhost:7890 \ No newline at end of file diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..250e797 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +custom: ['https://iamazing.cn/page/reward'] \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..dd68849 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,26 @@ +--- +name: 报告问题 +about: 使用简练详细的语言描述你遇到的问题 +title: '' +labels: bug +assignees: '' + +--- + +**例行检查** + +[//]: # (方框内删除已有的空格,填 x 号) ++ [ ] 我已确认目前没有类似 issue ++ [ ] 我已确认我已升级到最新版本 ++ [ ] 我已完整查看过项目 README,尤其是常见问题部分 ++ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 ++ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** + +**问题描述** + +**复现步骤** + +**预期结果** + +**相关截图** +如果没有的话,请删除此节。 \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..83a0f3f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: 项目群聊 + url: https://openai.justsong.cn/ + about: QQ 群:828520184,自动审核,备注 One API + - name: 赞赏支持 + url: https://iamazing.cn/page/reward + about: 请作者喝杯咖啡,以激励作者持续开发 diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..049d89c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,21 @@ +--- +name: 功能请求 +about: 使用简练详细的语言描述希望加入的新功能 +title: '' +labels: enhancement +assignees: '' + +--- + +**例行检查** + +[//]: # (方框内删除已有的空格,填 x 号) ++ [ ] 我已确认目前没有类似 
issue ++ [ ] 我已确认我已升级到最新版本 ++ [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求 ++ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 ++ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** + +**功能描述** + +**应用场景** diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3f85486 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,45 @@ +name: CI + +# This setup assumes that you run the unit tests with code coverage in the same +# workflow that will also print the coverage report as comment to the pull request. +# Therefore, you need to trigger this workflow when a pull request is (re)opened or +# when new code is pushed to the branch of the pull request. In addition, you also +# need to trigger this workflow when new code is pushed to the main branch because +# we need to upload the code coverage results as artifact for the main branch as +# well since it will be the baseline code coverage. +# +# We do not want to trigger the workflow for pushes to *any* branch because this +# would trigger our jobs twice on pull requests (once from "push" event and once +# from "pull_request->synchronize") +on: + push: + branches: + - 'main' + +jobs: + unit_tests: + name: "Unit tests" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ^1.22 + + # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a + # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") + # in the next step as well as the next job. + - name: Test + run: go test -cover -coverprofile=coverage.txt ./... 
+ - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + + commit_lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: wagoid/commitlint-github-action@v6 diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml new file mode 100644 index 0000000..9da1906 --- /dev/null +++ b/.github/workflows/docker-image.yml @@ -0,0 +1,68 @@ +name: Publish Docker image + +on: + push: + tags: + - 'v*.*.*' + workflow_dispatch: + inputs: + name: + description: 'reason' + required: false +jobs: + push_to_registries: + name: Push Docker image to multiple registries + runs-on: ubuntu-latest + permissions: + packages: write + contents: read + steps: + - name: Check out the repo + uses: actions/checkout@v3 + + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + + - name: Save version info + run: | + git describe --tags > VERSION + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Log in to the Container registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@v4 + with: + images: | + ${{ contains(github.ref, 'alpha') && 'justsong/one-api-alpha' || 'justsong/one-api' }} + ${{ contains(github.ref, 'alpha') && format('ghcr.io/{0}-alpha', github.repository) || format('ghcr.io/{0}', github.repository) }} + + - name: Build and push Docker images + uses: docker/build-push-action@v3 + with: + context: . 
+ platforms: ${{ contains(github.ref, 'alpha') && 'linux/amd64' || 'linux/amd64' }} + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml new file mode 100644 index 0000000..08639b1 --- /dev/null +++ b/.github/workflows/linux-release.yml @@ -0,0 +1,66 @@ +name: Linux Release +permissions: + contents: write + +on: + push: + tags: + - 'v*.*.*' + - '!*-alpha*' + - '!*-preview*' + workflow_dispatch: + inputs: + name: + description: 'reason' + required: false +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Build Frontend + env: + CI: "" + run: | + cd web + git describe --tags > VERSION + REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh + cd .. 
+ - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '>=1.18.0' + - name: Build Backend (amd64) + run: | + go mod download + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api + + - name: Build Backend (arm64) + run: | + sudo apt-get update + sudo apt-get install gcc-aarch64-linux-gnu + CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64 + + - name: Release + uses: softprops/action-gh-release@v1 + if: startsWith(github.ref, 'refs/tags/') + with: + files: | + one-api + one-api-arm64 + draft: true + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml new file mode 100644 index 0000000..74d2aa0 --- /dev/null +++ b/.github/workflows/macos-release.yml @@ -0,0 +1,57 @@ +name: macOS Release +permissions: + contents: write + +on: + push: + tags: + - 'v*.*.*' + - '!*-alpha*' + - '!*-preview*' + workflow_dispatch: + inputs: + name: + description: 'reason' + required: false +jobs: + release: + runs-on: macos-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Build Frontend + env: + CI: "" + run: | + cd web + git describe --tags > VERSION + REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh + cd .. 
+ - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '>=1.18.0' + - name: Build Backend + run: | + go mod download + go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos + - name: Release + uses: softprops/action-gh-release@v1 + if: startsWith(github.ref, 'refs/tags/') + with: + files: one-api-macos + draft: true + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml new file mode 100644 index 0000000..aed7bb1 --- /dev/null +++ b/.github/workflows/windows-release.yml @@ -0,0 +1,60 @@ +name: Windows Release +permissions: + contents: write + +on: + push: + tags: + - 'v*.*.*' + - '!*-alpha*' + - '!*-preview*' + workflow_dispatch: + inputs: + name: + description: 'reason' + required: false +jobs: + release: + runs-on: windows-latest + defaults: + run: + shell: bash + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Build Frontend + env: + CI: "" + run: | + cd web/default + npm install + REACT_APP_VERSION=$(git describe --tags) npm run build + cd ../.. 
+ - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '>=1.18.0' + - name: Build Backend + run: | + go mod download + go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe + - name: Release + uses: softprops/action-gh-release@v1 + if: startsWith(github.ref, 'refs/tags/') + with: + files: one-api.exe + draft: true + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e1e018e --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +.idea +.vscode +upload +*.exe +*.db +build +*.db-journal +logs +data +/web/node_modules +cmd.md +.env +/one-api +temp +.DS_Store \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..346d9c5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,47 @@ +FROM --platform=$BUILDPLATFORM node:16 AS builder + +WORKDIR /web +COPY ./VERSION . +COPY ./web . + +RUN npm install --prefix /web/default & \ + npm install --prefix /web/berry & \ + npm install --prefix /web/air & \ + wait + +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ./VERSION) npm run build --prefix /web/default & \ + DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ./VERSION) npm run build --prefix /web/berry & \ + DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ./VERSION) npm run build --prefix /web/air & \ + wait + +FROM golang:alpine AS builder2 + +RUN apk add --no-cache \ + gcc \ + musl-dev \ + sqlite-dev \ + build-base + +ENV GO111MODULE=on \ + CGO_ENABLED=1 \ + GOOS=linux + +WORKDIR /build + +ADD go.mod go.sum ./ +RUN go mod download + +COPY . . 
+COPY --from=builder /web/build ./web/build + +RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -linkmode external -extldflags '-static'" -o one-api + +FROM alpine:latest + +RUN apk add --no-cache ca-certificates tzdata + +COPY --from=builder2 /build/one-api / + +EXPOSE 3000 +WORKDIR /data +ENTRYPOINT ["/one-api"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..4bf5d1b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 JustSong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.en.md b/README.en.md new file mode 100644 index 0000000..6186424 --- /dev/null +++ b/README.en.md @@ -0,0 +1,329 @@ +

+ 中文 | English | 日本語 +

+ +

+ one-api logo +

+ +
+ +# One API + +_✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ + +
+ +

+ + license + + + release + + + docker pull + + + release + + + GoReportCard + +

+ +

+ Deployment Tutorial + · + Usage + · + Feedback + · + Screenshots + · + Live Demo + · + FAQ + · + Related Projects + · + Donate +

+ +> **Warning**: This README is translated by ChatGPT. Please feel free to submit a PR if you find any translation errors. + +> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. + +## Features +1. Support for multiple large models: + + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + + [x] [Anthropic Claude Series Models](https://anthropic.com) + + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) +2. Supports access to multiple channels through **load balancing**. +3. Supports **stream mode** that enables typewriter-like effect through stream transmission. +4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. +5. Supports **token management** that allows setting token expiration time and usage count. +6. Supports **voucher management** that enables batch generation and export of vouchers. Vouchers can be used for account balance replenishment. +7. Supports **channel management** that allows bulk creation of channels. +8. Supports **user grouping** and **channel grouping** for setting different rates for different groups. +9. Supports channel **model list configuration**. +10. Supports **quota details checking**. +11. Supports **user invite rewards**. +12. Allows display of balance in USD. +13. Supports announcement publishing, recharge link setting, and initial balance setting for new users. +14. Offers rich **customization** options: + 1. Supports customization of system name, logo, and footer. + 2. 
Supports customization of homepage and about page using HTML & Markdown code, or embedding a standalone webpage through iframe. +15. Supports management API access through system access tokens. +16. Supports Cloudflare Turnstile user verification. +17. Supports user management and multiple user login/registration methods: + + Email login/registration and password reset via email. + + [GitHub OAuth](https://github.com/settings/applications/new). + + WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). +18. Immediate support and encapsulation of other major model APIs as they become available. + +## Deployment +### Docker Deployment + +Deployment command: +`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` + +Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` + +The first `3000` in `-p 3000:3000` is the port of the host, which can be modified as needed. + +Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory. 
+ +Nginx reference configuration: +``` +server{ + server_name openai.justsong.cn; # Modify your domain name accordingly + + location / { + client_max_body_size 64m; + proxy_http_version 1.1; + proxy_pass http://localhost:3000; # Modify your port accordingly + proxy_set_header Host $host; + proxy_set_header X-Forwarded-For $remote_addr; + proxy_cache_bypass $http_upgrade; + proxy_set_header Accept-Encoding gzip; + } +} +``` + +Next, configure HTTPS with Let's Encrypt certbot: +```bash +# Install certbot on Ubuntu: +sudo snap install --classic certbot +sudo ln -s /snap/bin/certbot /usr/bin/certbot +# Generate certificates & modify Nginx configuration +sudo certbot --nginx +# Follow the prompts +# Restart Nginx +sudo service nginx restart +``` + +The initial account username is `root` and password is `123456`. + +### Manual Deployment +1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: + ```shell + git clone https://github.com/songquanpeng/one-api.git + + # Build the frontend + cd one-api/web/default + npm install + npm run build + + # Build the backend + cd ../.. + go mod download + go build -ldflags "-s -w" -o one-api + ``` +2. Run: + ```shell + chmod u+x one-api + ./one-api --port 3000 --log-dir ./logs + ``` +3. Access [http://localhost:3000/](http://localhost:3000/) and log in. The initial account username is `root` and password is `123456`. + +For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website). + +### Multi-machine Deployment +1. Set the same `SESSION_SECRET` for all servers. +2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database. +3. Set the `NODE_TYPE` for all non-master nodes to `slave`. +4. Set `SYNC_FREQUENCY` for servers to periodically sync configurations from the database. +5. 
Non-master nodes can optionally set `FRONTEND_BASE_URL` to redirect page requests to the master server. +6. Install Redis separately on non-master nodes, and configure `REDIS_CONN_STRING` so that the database can be accessed with zero latency when the cache has not expired. +7. If the main server also has high latency accessing the database, Redis must be enabled and `SYNC_FREQUENCY` must be set to periodically sync configurations from the database. + +Please refer to the [environment variables](#environment-variables) section for details on using environment variables. + +### Deployment on Control Panels (e.g., Baota) +Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions. + +If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions. + +### Deployment on Third-Party Platforms +
+Deploy on Sealos +
+ +> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users. + +> Click the button below to deploy with one click.👇 + +[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) + + +
+
+ +
+Deployment on Zeabur +
+ +> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + +1. First, fork the code. +2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. +3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). +4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. +5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. +6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `:@tcp(:)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment. +7. Select Redeploy. +8. In the Domains tab, select a suitable domain name prefix, such as "my-one-api". The final domain name will be "my-one-api.zeabur.app". You can also CNAME your own domain name. +9. Wait for the deployment to complete, and click on the generated domain name to access One API. + +
+
+ +## Configuration +The system is ready to use out of the box. + +You can configure it by setting environment variables or command line parameters. + +After the system starts, log in as the `root` user to further configure the system. + +## Usage +Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page. + +You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). + +In places where the OpenAI API is used, remember to set the API Base to your One API deployment address, for example: `https://openai.justsong.cn`. The API Key should be the token generated in One API. + +Note that the specific API Base format depends on the client you are using. + +```mermaid +graph LR + A(User) + A --->|Request| B(One API) + B -->|Relay Request| C(OpenAI) + B -->|Relay Request| D(Azure) + B -->|Relay Request| E(Other downstream channels) +``` + +To specify which channel to use for the current request, you can add the channel ID after the token, for example: `Authorization: Bearer ONE_API_KEY-CHANNEL_ID`. +Note that the token needs to be created by an administrator to specify the channel ID. + +If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels. + +### Environment Variables +1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory. + + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` +2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts. + + Example: `SESSION_SECRET=random_string` +3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. + + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` +4. 
`LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. + + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` +5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. + + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` +6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'. +7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. + + Example: `SYNC_FREQUENCY=60` +8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. + + Example: `NODE_TYPE=slave` +9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. + + Example: `CHANNEL_UPDATE_FREQUENCY=1440` +10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. + + Example: `CHANNEL_TEST_FREQUENCY=1440` +11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. + + Example: `POLLING_INTERVAL=5` +12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'. + +Example: ` BATCH_UPDATE_ENABLED=true` + +If you encounter an issue with too many database connections, you can try enabling this option. +13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'. + +Example: ` BATCH_UPDATE_INTERVAL=5` +14. 
Request frequency limit: + + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180. + + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests within three minutes per IP, default to 60. +15. Encoder cache settings: + +`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the token encodings of some common models online, such as `gpt-3.5-turbo`. In some unstable network environments or offline situations, this may cause startup problems. This directory can be configured to cache the data and can be migrated to an offline environment. + +`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as `TIKTOKEN_CACHE_DIR`, but with lower priority. +16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout set. +17. `RELAY_PROXY`: After setting up, use this proxy to request APIs. +18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds. +19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this proxy to request content uploaded by users, such as images. +20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'. +21. `GEMINI_SAFETY_SETTING`: Gemini's safety setting, default to `BLOCK_NONE`. +22. `GEMINI_VERSION`: The Gemini API version used by One API, which defaults to 'v1'. +23. `THEME`: The system's theme setting, default to 'default'; for the specific optional values refer to [here](./web/README.md). +24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'. +25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'. +26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'. +27. 
`INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time. +28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time. + +### Command Line Parameters +1. `--port `: Specifies the port number on which the server listens. Defaults to `3000`. + + Example: `--port 3000` +2. `--log-dir `: Specifies the log directory. If not set, the logs will not be saved. + + Example: `--log-dir ./logs` +3. `--version`: Prints the system version number and exits. +4. `--help`: Displays the command usage help and parameter descriptions. + +## Screenshots +![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png) +![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png) + +## FAQ +1. What is quota? How is it calculated? Does One API have quota calculation issues? + + Quota = Group multiplier * Model multiplier * (number of prompt tokens + number of completion tokens * completion multiplier) + + The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. + + If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. +2. Why does it prompt "insufficient quota" even though my account balance is sufficient? + + Please check if your token quota is sufficient. It is separate from the account balance. + + The token quota is used to set the maximum usage and can be freely set by the user. +3. It says "No available channels" when trying to use a channel. What should I do? + + Please check the user and channel group settings. 
+ + Also check the channel model settings. +4. Channel testing reports an error: "invalid character '<' looking for beginning of value" + + This error occurs when the returned value is not valid JSON but an HTML page. + + Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. +5. ChatGPT Next Web reports an error: "Failed to fetch" + + Do not set `BASE_URL` during deployment. + + Double-check that your interface address and API Key are correct. + +## Related Projects +* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller. +* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization. +* * [CherryStudio](https://github.com/CherryHQ/cherry-studio): A cross-platform AI client that integrates multiple service providers and supports local knowledge base management. + +## Note +This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. + +This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page. + +The same applies to derivative projects based on this project. + +If you do not wish to include attribution, prior authorization must be obtained. + +According to the MIT license, users should bear the risk and responsibility of using this project, and the developer of this open-source project is not responsible for this. diff --git a/README.ja.md b/README.ja.md new file mode 100644 index 0000000..3808046 --- /dev/null +++ b/README.ja.md @@ -0,0 +1,301 @@ +

+ 中文 | English | 日本語 +

+ +

+ one-api logo +

+ +
+ +# One API + +_✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_ + +
+ +

+ + license + + + release + + + docker pull + + + release + + + GoReportCard + +

+ +

+ デプロイチュートリアル + · + 使用方法 + · + フィードバック + · + スクリーンショット + · + ライブデモ + · + FAQ + · + 関連プロジェクト + · + 寄付 +

+ +> **警告**: この README は ChatGPT によって翻訳されています。翻訳ミスを発見した場合は遠慮なく PR を投稿してください。 + +> **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。 + +## 特徴 +1. 複数の大型モデルをサポート: + + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) + + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) + + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) +2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。 +3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。 +4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。 +5. トークンの有効期限や使用回数を設定できる**トークン管理**に対応しています。 +6. **バウチャー管理**に対応しており、バウチャーの一括生成やエクスポートが可能です。バウチャーは口座残高の補充に利用できます。 +7. **チャンネル管理**に対応し、チャンネルの一括作成が可能。 +8. グループごとに異なるレートを設定するための**ユーザーグループ**と**チャンネルグループ**をサポートしています。 +9. チャンネル**モデルリスト設定**に対応。 +10. **クォータ詳細チェック**をサポート。 +11. **ユーザー招待報酬**をサポートします。 +12. 米ドルでの残高表示が可能。 +13. 新規ユーザー向けのお知らせ公開、リチャージリンク設定、初期残高設定に対応。 +14. 豊富な**カスタマイズ**オプションを提供します: + 1. システム名、ロゴ、フッターのカスタマイズが可能。 + 2. HTML と Markdown コードを使用したホームページとアバウトページのカスタマイズ、または iframe を介したスタンドアロンウェブページの埋め込みをサポートしています。 +15. システム・アクセストークンによる管理 API アクセスをサポートする。 +16. Cloudflare Turnstile によるユーザー認証に対応。 +17. ユーザー管理と複数のユーザーログイン/登録方法をサポート: + + 電子メールによるログイン/登録とパスワードリセット。 + + [GitHub OAuth](https://github.com/settings/applications/new)。 + + WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 +18. 
他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。 + +## デプロイメント +### Docker デプロイメント + +デプロイコマンド: +`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`。 + +コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。 + +`-p 3000:3000` の最初の `3000` はホストのポートで、必要に応じて変更できます。 + +データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。 + +Nginxリファレンス設定: +``` +server{ + server_name openai.justsong.cn; # ドメイン名は適宜変更 + + location / { + client_max_body_size 64m; + proxy_http_version 1.1; + proxy_pass http://localhost:3000; # それに応じてポートを変更 + proxy_set_header Host $host; + proxy_set_header X-Forwarded-For $remote_addr; + proxy_cache_bypass $http_upgrade; + proxy_set_header Accept-Encoding gzip; + proxy_read_timeout 300s; # GPT-4 はより長いタイムアウトが必要 + } +} +``` + +次に、Let's Encrypt certbot を使って HTTPS を設定します: +```bash +# Ubuntu に certbot をインストール: +sudo snap install --classic certbot +sudo ln -s /snap/bin/certbot /usr/bin/certbot +# 証明書の生成と Nginx 設定の変更 +sudo certbot --nginx +# プロンプトに従う +# Nginx を再起動 +sudo service nginx restart +``` + +初期アカウントのユーザー名は `root` で、パスワードは `123456` です。 + +### マニュアルデプロイ +1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする: + ```shell + git clone https://github.com/songquanpeng/one-api.git + + # フロントエンドのビルド + cd one-api/web/default + npm install + npm run build + + # バックエンドのビルド + cd ../.. + go mod download + go build -ldflags "-s -w" -o one-api + ``` +2. 実行: + ```shell + chmod u+x one-api + ./one-api --port 3000 --log-dir ./logs + ``` +3. [http://localhost:3000/](http://localhost:3000/) にアクセスし、ログインする。初期アカウントのユーザー名は `root`、パスワードは `123456` である。 + +より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。 + +### マルチマシンデプロイ +1. すべてのサーバに同じ `SESSION_SECRET` を設定する。 +2. 
`SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。 +3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。 +4. データベースから定期的に設定を同期するサーバーには `SYNC_FREQUENCY` を設定する。 +5. マスター以外のノードでは、オプションで `FRONTEND_BASE_URL` を設定して、ページ要求をマスターサーバーにリダイレクトすることができます。 +6. マスター以外のノードには Redis を個別にインストールし、`REDIS_CONN_STRING` を設定して、キャッシュの有効期限が切れていないときにデータベースにゼロレイテンシーでアクセスできるようにする。 +7. メインサーバーでもデータベースへのアクセスが高レイテンシになる場合は、Redis を有効にし、`SYNC_FREQUENCY` を設定してデータベースから定期的に設定を同期する必要がある。 + +Please refer to the [environment variables](#environment-variables) section for details on using environment variables. + +### コントロールパネル(例: Baota)への展開 +詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。 + +配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。 + +### サードパーティプラットフォームへのデプロイ +
+Sealos へのデプロイ +
+ +> Sealos は、高い同時実行性、ダイナミックなスケーリング、数百万人のユーザーに対する安定した運用をサポートしています。 + +> 下のボタンをクリックすると、ワンクリックで展開できます。👇 + +[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) + + +
+
+ +
+Zeabur へのデプロイ +
+ +> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + +1. まず、コードをフォークする。 +2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 +3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 +4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。 +5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。 +6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `:@tcp(:)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。 +7. 再デプロイを選択します。 +8. Domains タブで、"my-one-api" のような適切なドメイン名の接頭辞を選択する。最終的なドメイン名は "my-one-api.zeabur.app" となります。独自のドメイン名を CNAME することもできます。 +9. デプロイが完了するのを待ち、生成されたドメイン名をクリックして One API にアクセスします。 + +
+
+ +## コンフィグ +システムは箱から出してすぐに使えます。 + +環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。 + +システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。 + +## 使用方法 +`Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。 + +アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。 + +OpenAI API が使用されている場所では、API Base に One API のデプロイアドレスを設定することを忘れないでください(例: `https://openai.justsong.cn`)。API Key は One API で生成されたトークンでなければなりません。 + +具体的な API Base のフォーマットは、使用しているクライアントに依存することに注意してください。 + +```mermaid +graph LR + A(ユーザ) + A --->|リクエスト| B(One API) + B -->|中継リクエスト| C(OpenAI) + B -->|中継リクエスト| D(Azure) + B -->|中継リクエスト| E(その他のダウンストリームチャンネル) +``` + +現在のリクエストにどのチャネルを使うかを指定するには、トークンの後に チャネル ID を追加します: 例えば、`Authorization: Bearer ONE_API_KEY-CHANNEL_ID` のようにします。 +チャンネル ID を指定するためには、トークンは管理者によって作成される必要があることに注意してください。 + +もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。 + +### 環境変数 +1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。 + + 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` +2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。 + + 例: `SESSION_SECRET=random_string` +3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 + + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` +4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 +5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 + + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` +6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 + + 例: `SYNC_FREQUENCY=60` +7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 + + 例: `NODE_TYPE=slave` +8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 + + 例: `CHANNEL_UPDATE_FREQUENCY=1440` +9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 + + 例: `CHANNEL_TEST_FREQUENCY=1440` +10. 
`POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 + + 例: `POLLING_INTERVAL=5` + +### コマンドラインパラメータ +1. `--port `: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。 + + 例: `--port 3000` +2. `--log-dir `: ログディレクトリを指定。設定しない場合、ログは保存されません。 + + 例: `--log-dir ./logs` +3. `--version`: システムのバージョン番号を表示して終了する。 +4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。 + +## スクリーンショット +![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png) +![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png) + +## FAQ +1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか? + + ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率) + + 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 + + ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 +2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか? + + トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 + + トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 +3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか? + + ユーザーとチャンネルグループの設定を確認してください。 + + チャンネルモデルの設定も確認してください。 +4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value" + + このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 + + ほとんどの場合、デプロイサイトのIPかプロキシのノードが CloudFlare によってブロックされています。 +5. 
ChatGPT Next Web でエラーが発生しました: "Failed to fetch" + + デプロイ時に `BASE_URL` を設定しないでください。 + + インターフェイスアドレスと API Key が正しいか再確認してください。 + +## 関連プロジェクト +* [FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム +* [CherryStudio](https://github.com/CherryHQ/cherry-studio): マルチプラットフォーム対応のAIクライアント。複数のサービスプロバイダーを統合管理し、ローカル知識ベースをサポートします。 +## 注 +本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。 + +このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。 + +このプロジェクトを基にした派生プロジェクトについても同様です。 + +帰属表示を含めたくない場合は、事前に許可を得なければなりません。 + +MIT ライセンスによると、このプロジェクトを利用するリスクと責任は利用者が負うべきであり、このオープンソースプロジェクトの開発者は責任を負いません。 diff --git a/README.md b/README.md new file mode 100644 index 0000000..5decf66 --- /dev/null +++ b/README.md @@ -0,0 +1,480 @@ +

+ 中文 | English | 日本語 +

+ + +

+ one-api logo +

+ +
+ +# One API + +_✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ + +
+ +

+ + license + + + release + + + docker pull + + + release + + + GoReportCard + +

+ +

+ 部署教程 + · + 使用方法 + · + 意见反馈 + · + 截图展示 + · + 在线演示 + · + 常见问题 + · + 相关项目 + · + 赞赏支持 +

+ +> [!NOTE] +> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 +> +> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 + +> [!NOTE] +> 稳定版 / 预览版镜像地址:[justsong/one-api](https://hub.docker.com/repository/docker/justsong/one-api) +> 或者 [ghcr.io/songquanpeng/one-api](https://github.com/songquanpeng/one-api/pkgs/container/one-api) +> +> alpha 版镜像地址:[justsong/one-api-alpha](https://hub.docker.com/repository/docker/justsong/one-api-alpha) +> 或者 [ghcr.io/songquanpeng/one-api-alpha](https://github.com/songquanpeng/one-api/pkgs/container/one-api-alpha) + +> [!WARNING] +> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! + +## 功能 +1. 支持多种大模型: + + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + + [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude) + + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + + [x] [Mistral 系列模型](https://mistral.ai/) + + [x] [字节跳动豆包大模型(火山引擎)](https://www.volcengine.com/experience/ark?utm_term=202502dsinvite&ac=DSASUQY5&rc=2QXCA1VI) + + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + + [x] [360 智脑](https://ai.360.cn) + + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + + [x] [Moonshot AI](https://platform.moonshot.cn/) + + [x] [百川大模型](https://platform.baichuan-ai.com) + + [x] [MINIMAX](https://api.minimax.chat/) + + [x] [Groq](https://wow.groq.com/) + + [x] [Ollama](https://github.com/ollama/ollama) + + [x] [零一万物](https://platform.lingyiwanwu.com/) + + [x] [阶跃星辰](https://platform.stepfun.com/) + + [x] [Coze](https://www.coze.com/) + + [x] 
[Cohere](https://cohere.com/) + + [x] [DeepSeek](https://www.deepseek.com/) + + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) + + [x] [DeepL](https://www.deepl.com/) + + [x] [together.ai](https://www.together.ai/) + + [x] [novita.ai](https://www.novita.ai/) + + [x] [硅基流动 SiliconCloud](https://cloud.siliconflow.cn/i/rKXmRobW) + + [x] [xAI](https://x.ai/) +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 +3. 支持通过**负载均衡**的方式访问多个渠道。 +4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 +5. 支持**多机部署**,[详见此处](#多机部署)。 +6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 +7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 +8. 支持**渠道管理**,批量创建渠道。 +9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 +10. 支持渠道**设置模型列表**。 +11. 支持**查看额度明细**。 +12. 支持**用户邀请奖励**。 +13. 支持以美元为单位显示额度。 +14. 支持发布公告,设置充值链接,设置新用户初始额度。 +15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 +16. 支持失败自动重试。 +17. 支持绘图接口。 +18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 +19. 支持丰富的**自定义**设置, + 1. 支持自定义系统名称,logo 以及页脚。 + 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 +20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。 +21. 支持 Cloudflare Turnstile 用户校验。 +22. 支持用户管理,支持**多种用户登录注册方式**: + + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 + + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 + + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 +23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 +24. 
配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 + +## 部署 +### 基于 Docker 进行部署 +```shell +# 使用 SQLite 的部署命令: +docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api +# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 +# 例如: +docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api +``` + +其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 + +数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 + +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 + +如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 + +如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 + +更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` + +Nginx 的参考配置: +``` +server{ + server_name openai.justsong.cn; # 请根据实际情况修改你的域名 + + location / { + client_max_body_size 64m; + proxy_http_version 1.1; + proxy_pass http://localhost:3000; # 请根据实际情况修改你的端口 + proxy_set_header Host $host; + proxy_set_header X-Forwarded-For $remote_addr; + proxy_cache_bypass $http_upgrade; + proxy_set_header Accept-Encoding gzip; + proxy_read_timeout 300s; # GPT-4 需要较长的超时时间,请自行调整 + } +} +``` + +之后使用 Let's Encrypt 的 certbot 配置 HTTPS: +```bash +# Ubuntu 安装 certbot: +sudo snap install --classic certbot +sudo ln -s /snap/bin/certbot /usr/bin/certbot +# 生成证书 & 修改 Nginx 配置 +sudo certbot --nginx +# 根据指示进行操作 +# 重启 Nginx +sudo service nginx restart +``` + +初始账号用户名为 `root`,密码为 `123456`。 + +### 通过宝塔面板进行一键部署 +1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; +2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; +3. 
安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; + +### 基于 Docker Compose 进行部署 + +> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 + +```shell +# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 +docker-compose up -d + +# 查看部署状态 +docker-compose ps +``` + +### 手动部署 +1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: + ```shell + git clone https://github.com/songquanpeng/one-api.git + + # 构建前端 + cd one-api/web/default + npm install + npm run build + + # 构建后端 + cd ../.. + go mod download + go build -ldflags "-s -w" -o one-api + ```` +2. 运行: + ```shell + chmod u+x one-api + ./one-api --port 3000 --log-dir ./logs + ``` +3. 访问 [http://localhost:3000/](http://localhost:3000/) 并登录。初始账号用户名为 `root`,密码为 `123456`。 + +更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 + +### 多机部署 +1. 所有服务器 `SESSION_SECRET` 设置一样的值。 +2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 +3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 +4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 +5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 +6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 +7. 
如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 + +环境变量的具体使用方法详见[此处](#环境变量)。 + +### 宝塔部署教程 + +详见 [#175](https://github.com/songquanpeng/one-api/issues/175)。 + +如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 + +### 部署第三方服务配合 One API 使用 +> 欢迎 PR 添加更多示例。 + +#### ChatGPT Next Web +项目主页:https://github.com/Yidadaa/ChatGPT-Next-Web + +```bash +docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web +``` + +注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 + +#### ChatGPT Web +项目主页:https://github.com/Chanzhaoyu/chatgpt-web + +```bash +docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://openai.justsong.cn -e OPENAI_API_KEY=sk-xxx chenzhaoyu94/chatgpt-web +``` + +注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 + +#### QChatGPT - QQ机器人 +项目主页:https://github.com/RockChinQ/QChatGPT + +根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 + +运行期间可以通过`!model`命令查看、切换可用模型。 + +### 部署到第三方平台 +
+部署到 Sealos +
+ +> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。 + +点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟): + +[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) + +
+
+ +
+部署到 Zeabur +
+ +> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + +1. 首先 fork 一份代码。 +2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 +3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 +4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 +5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 +6. Deploy 会自动开始,先取消。进入下方 Variable,添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `:@tcp(:)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。 +7. 选择 Redeploy。 +8. 进入下方 Domains,选择一个合适的域名前缀,如 "my-one-api",最终域名为 "my-one-api.zeabur.app",也可以 CNAME 自己的域名。 +9. 等待部署完成,点击生成的域名进入 One API。 + +
+
+ +
+部署到 Render +
+ +> Render 提供免费额度,绑卡后可以进一步提升额度 + +Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com + +
+
+ +## 配置 +系统本身开箱即用。 + +你可以通过设置环境变量或者命令行参数进行配置。 + +等到系统启动后,使用 `root` 用户登录系统并做进一步的配置。 + +**Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 + +## 使用方法 +在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 + +之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。 + +你需要在各种用到 OpenAI API 的地方设置 API Base 为你的 One API 的部署地址,例如:`https://openai.justsong.cn`,API Key 则为你在 One API 中生成的令牌。 + +注意,具体的 API Base 的格式取决于你所使用的客户端。 + +例如对于 OpenAI 的官方库: +```bash +OPENAI_API_KEY="sk-xxxxxx" +OPENAI_API_BASE="https://:/v1" +``` + +```mermaid +graph LR + A(用户) + A --->|使用 One API 分发的 key 进行请求| B(One API) + B -->|中继请求| C(OpenAI) + B -->|中继请求| D(Azure) + B -->|中继请求| E(其他 OpenAI API 格式下游渠道) + B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) +``` + +可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 +注意,需要是管理员用户创建的令牌才能指定渠道 ID。 + +不加的话将会使用负载均衡的方式使用多个渠道。 + +### 环境变量 +> One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 + + 如果需要使用哨兵或者集群模式: + + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 + + 除此之外还需要设置以下环境变量: + + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 + + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 +2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + + 例子:`SESSION_SECRET=random_string` +3. 
`SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 + + 例子: + + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` + + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) + + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 + + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 + + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 + + 请根据你的数据库配置修改下列参数(或者保持默认值): + + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 + + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 +4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 +5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` +6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`MEMORY_CACHE_ENABLED=true` +7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + + 例子:`SYNC_FREQUENCY=60` +8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + + 例子:`NODE_TYPE=slave` +9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` +10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 + +例子:`CHANNEL_TEST_FREQUENCY=1440` +11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + + 例子:`POLLING_INTERVAL=5` +12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` + + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 +13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` +14. 请求频率限制: + + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 +15. 编码器缓存设置: + + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 +16. 
`RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +17. `RELAY_PROXY`:设置后使用该代理来请求 API。 +18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 +19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 +20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +22. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 +23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 +28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 +29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 +30. `TEST_PROMPT`:测试模型时的用户 prompt,默认为 `Print your model name exactly and do not output without any other text.`。 + +### 命令行参数 +1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 + + 例子:`--port 3000` +2. `--log-dir `: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 + + 例子:`--log-dir ./logs` +3. `--version`: 打印系统版本号并退出。 +4. `--help`: 查看命令的使用帮助和参数说明。 + +## 演示 +### 在线演示 +注意,该演示站不提供对外服务: +https://openai.justsong.cn + +### 截图展示 +![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png) +![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png) + +## 常见问题 +1. 额度是什么?怎么计算的?One API 的额度计算有问题? + + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) + + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 + + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 + + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 +2. 账户额度足够为什么提示额度不足? + + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 + + 令牌额度仅供用户设置最大使用量,用户可自由设置。 +3. 提示无可用渠道? + + 请检查的用户分组和渠道分组设置。 + + 以及渠道的模型设置。 +4. 渠道测试报错:`invalid character '<' looking for beginning of value` + + 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 + + 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 +5. 
ChatGPT Next Web 报错:`Failed to fetch` + + 部署的时候不要设置 `BASE_URL`。 + + 检查你的接口地址和 API Key 有没有填对。 + + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 +6. 报错:`当前分组负载已饱和,请稍后再试` + + 上游渠道 429 了。 +7. 升级之后我的数据会丢失吗? + + 如果使用 MySQL,不会。 + + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 +8. 升级之前数据库需要做变更吗? + + 一般情况下不需要,系统将在初始化的时候自动调整。 + + 如果需要的话,我会在更新日志中说明,并给出脚本。 +9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? + + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 + + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 + +## 相关项目 +* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 +* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 +* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 +* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。 +* [CherryStudio](https://github.com/CherryHQ/cherry-studio): 全平台支持的AI客户端, 多服务商集成管理、本地知识库支持。 + +## 注意 + +本项目使用 MIT 协议进行开源,**在此基础上**,必须在页面底部保留署名以及指向本项目的链接。如果不想保留署名,必须首先获得授权。 + +同样适用于基于本项目的二开项目。 + +依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..e69de29 diff --git a/bin/migration_v0.2-v0.3.sql b/bin/migration_v0.2-v0.3.sql new file mode 100644 index 0000000..6b08d7b --- /dev/null +++ b/bin/migration_v0.2-v0.3.sql @@ -0,0 +1,6 @@ +UPDATE users +SET quota = quota + ( + SELECT SUM(remain_quota) + FROM tokens + WHERE tokens.user_id = users.id +) diff --git a/bin/migration_v0.3-v0.4.sql b/bin/migration_v0.3-v0.4.sql new file mode 100644 index 0000000..e6103c2 --- /dev/null +++ b/bin/migration_v0.3-v0.4.sql @@ -0,0 +1,17 @@ +INSERT INTO abilities (`group`, model, channel_id, enabled) +SELECT c.`group`, m.model, c.id, 1 +FROM channels c +CROSS JOIN ( + SELECT 'gpt-3.5-turbo' AS model UNION ALL + SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL + SELECT 'gpt-4' AS model UNION ALL + SELECT 'gpt-4-0314' AS model +) AS m +WHERE c.status = 1 + AND NOT EXISTS ( + SELECT 
1 + FROM abilities a + WHERE a.`group` = c.`group` + AND a.model = m.model + AND a.channel_id = c.id +); diff --git a/bin/time_test.sh b/bin/time_test.sh new file mode 100644 index 0000000..2cde4a6 --- /dev/null +++ b/bin/time_test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +if [ $# -lt 3 ]; then + echo "Usage: time_test.sh []" + exit 1 +fi + +domain=$1 +key=$2 +count=$3 +model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo + +total_time=0 +times=() + +for ((i=1; i<=count; i++)); do + result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \ + https://"$domain"/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $key" \ + -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}') + http_code=$(echo "$result" | awk '{print $1}') + time=$(echo "$result" | awk '{print $2}') + echo "HTTP status code: $http_code, Time taken: $time" + total_time=$(bc <<< "$total_time + $time") + times+=("$time") +done + +average_time=$(echo "scale=4; $total_time / $count" | bc) + +sum_of_squares=0 +for time in "${times[@]}"; do + difference=$(echo "scale=4; $time - $average_time" | bc) + square=$(echo "scale=4; $difference * $difference" | bc) + sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc) +done + +standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc) + +echo "Average time: $average_time±$standard_deviation" diff --git a/common/blacklist/main.go b/common/blacklist/main.go new file mode 100644 index 0000000..f84ce6a --- /dev/null +++ b/common/blacklist/main.go @@ -0,0 +1,29 @@ +package blacklist + +import ( + "fmt" + "sync" +) + +var blackList sync.Map + +func init() { + blackList = sync.Map{} +} + +func userId2Key(id int) string { + return fmt.Sprintf("userid_%d", id) +} + +func BanUser(id int) { + blackList.Store(userId2Key(id), true) +} + +func UnbanUser(id int) { + blackList.Delete(userId2Key(id)) +} + +func IsUserBanned(id int) bool { + _, ok := 
blackList.Load(userId2Key(id)) + return ok +} diff --git a/common/client/init.go b/common/client/init.go new file mode 100644 index 0000000..f803cbf --- /dev/null +++ b/common/client/init.go @@ -0,0 +1,60 @@ +package client + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "net/http" + "net/url" + "time" +) + +var HTTPClient *http.Client +var ImpatientHTTPClient *http.Client +var UserContentRequestHTTPClient *http.Client + +func Init() { + if config.UserContentRequestProxy != "" { + logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy)) + proxyURL, err := url.Parse(config.UserContentRequestProxy) + if err != nil { + logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + UserContentRequestHTTPClient = &http.Client{ + Transport: transport, + Timeout: time.Second * time.Duration(config.UserContentRequestTimeout), + } + } else { + UserContentRequestHTTPClient = &http.Client{} + } + var transport http.RoundTripper + if config.RelayProxy != "" { + logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy)) + proxyURL, err := url.Parse(config.RelayProxy) + if err != nil { + logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) + } + transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } + + if config.RelayTimeout == 0 { + HTTPClient = &http.Client{ + Transport: transport, + } + } else { + HTTPClient = &http.Client{ + Timeout: time.Duration(config.RelayTimeout) * time.Second, + Transport: transport, + } + } + + ImpatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: transport, + } +} diff --git a/common/config/config.go b/common/config/config.go new file mode 100644 index 0000000..a235a8d --- /dev/null +++ 
b/common/config/config.go @@ -0,0 +1,166 @@ +package config + +import ( + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/songquanpeng/one-api/common/env" + + "github.com/google/uuid" +) + +var SystemName = "One API" +var ServerAddress = "http://localhost:3000" +var Footer = "" +var Logo = "" +var TopUpLink = "" +var ChatLink = "" +var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +var DisplayInCurrencyEnabled = true +var DisplayTokenStatEnabled = true + +// Any options with "Secret", "Token" in its key won't be return by GetOptions + +var SessionSecret = uuid.New().String() + +var OptionMap map[string]string +var OptionMapRWMutex sync.RWMutex + +var ItemsPerPage = 10 +var MaxRecentItems = 100 + +var PasswordLoginEnabled = true +var PasswordRegisterEnabled = true +var EmailVerificationEnabled = false +var GitHubOAuthEnabled = false +var OidcEnabled = false +var WeChatAuthEnabled = false +var TurnstileCheckEnabled = false +var RegisterEnabled = true + +var EmailDomainRestrictionEnabled = false +var EmailDomainWhitelist = []string{ + "gmail.com", + "163.com", + "126.com", + "qq.com", + "outlook.com", + "hotmail.com", + "icloud.com", + "yahoo.com", + "foxmail.com", +} + +var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" +var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" +var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" + +var LogConsumeEnabled = true + +var SMTPServer = "" +var SMTPPort = 587 +var SMTPAccount = "" +var SMTPFrom = "" +var SMTPToken = "" + +var GitHubClientId = "" +var GitHubClientSecret = "" + +var LarkClientId = "" +var LarkClientSecret = "" + +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + +var WeChatServerAddress = "" +var WeChatServerToken = "" +var WeChatAccountQRCodeImageURL = "" + +var MessagePusherAddress = "" +var 
MessagePusherToken = "" + +var TurnstileSiteKey = "" +var TurnstileSecretKey = "" + +var QuotaForNewUser int64 = 0 +var QuotaForInviter int64 = 0 +var QuotaForInvitee int64 = 0 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false +var QuotaRemindThreshold int64 = 1000 +var PreConsumedQuota int64 = 500 +var ApproximateTokenEnabled = false +var RetryTimes = 0 + +var RootUserEmail = "" + +var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" + +var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second + +var BatchUpdateEnabled = false +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) + +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second + +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + +var Theme = env.String("THEME", "default") +var ValidThemes = map[string]bool{ + "default": true, + "berry": true, + "air": true, +} + +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 480) + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 240) + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + +var RateLimitKeyExpirationDuration = 20 * time.Minute + +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = 
env.Int("METRIC_FAIL_CHAN_SIZE", 128) + +var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") + +var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN") + +var GeminiVersion = env.String("GEMINI_VERSION", "v1") + +var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) + +var RelayProxy = env.String("RELAY_PROXY", "") +var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") +var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) + +var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) +var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.") diff --git a/common/constants.go b/common/constants.go new file mode 100644 index 0000000..87221b6 --- /dev/null +++ b/common/constants.go @@ -0,0 +1,6 @@ +package common + +import "time" + +var StartTime = time.Now().Unix() // unit: second +var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change diff --git a/common/conv/any.go b/common/conv/any.go new file mode 100644 index 0000000..467e8bb --- /dev/null +++ b/common/conv/any.go @@ -0,0 +1,6 @@ +package conv + +func AsString(v any) string { + str, _ := v.(string) + return str +} diff --git a/common/crypto.go b/common/crypto.go new file mode 100644 index 0000000..4522841 --- /dev/null +++ b/common/crypto.go @@ -0,0 +1,14 @@ +package common + +import "golang.org/x/crypto/bcrypt" + +func Password2Hash(password string) (string, error) { + passwordBytes := []byte(password) + hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) + return string(hashedPassword), err +} + +func ValidatePasswordAndHash(password string, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go new file mode 100644 index 0000000..115558a --- /dev/null +++ b/common/ctxkey/key.go @@ -0,0 
+1,24 @@ +package ctxkey + +const ( + Config = "config" + Id = "id" + Username = "username" + Role = "role" + Status = "status" + Channel = "channel" + ChannelId = "channel_id" + SpecificChannelId = "specific_channel_id" + RequestModel = "request_model" + ConvertedRequest = "converted_request" + OriginalModel = "original_model" + Group = "group" + ModelMapping = "model_mapping" + ChannelName = "channel_name" + TokenId = "token_id" + TokenName = "token_name" + BaseURL = "base_url" + AvailableModels = "available_models" + KeyRequestBody = "key_request_body" + SystemPrompt = "system_prompt" +) diff --git a/common/custom-event.go b/common/custom-event.go new file mode 100644 index 0000000..69da4bc --- /dev/null +++ b/common/custom-event.go @@ -0,0 +1,82 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package common + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +type stringWriter interface { + io.Writer + writeString(string) (int, error) +} + +type stringWrapper struct { + io.Writer +} + +func (w stringWrapper) writeString(str string) (int, error) { + return w.Writer.Write([]byte(str)) +} + +func checkWriter(writer io.Writer) stringWriter { + if w, ok := writer.(stringWriter); ok { + return w + } else { + return stringWrapper{writer} + } +} + +// Server-Sent Events +// W3C Working Draft 29 October 2009 +// http://www.w3.org/TR/2009/WD-eventsource-20091029/ + +var contentType = []string{"text/event-stream"} +var noCache = []string{"no-cache"} + +var fieldReplacer = strings.NewReplacer( + "\n", "\\n", + "\r", "\\r") + +var dataReplacer = strings.NewReplacer( + "\n", "\ndata:", + "\r", "\\r") + +type CustomEvent struct { + Event string + Id string + Retry uint + Data interface{} +} + +func encode(writer io.Writer, event CustomEvent) error { + w := checkWriter(writer) + return writeData(w, event.Data) +} + +func writeData(w stringWriter, 
data interface{}) error { + dataReplacer.WriteString(w, fmt.Sprint(data)) + if strings.HasPrefix(data.(string), "data") { + w.writeString("\n\n") + } + return nil +} + +func (r CustomEvent) Render(w http.ResponseWriter) error { + r.WriteContentType(w) + return encode(w, r) +} + +func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + header := w.Header() + header["Content-Type"] = contentType + + if _, exist := header["Cache-Control"]; !exist { + header["Cache-Control"] = noCache + } +} diff --git a/common/database.go b/common/database.go new file mode 100644 index 0000000..f2db759 --- /dev/null +++ b/common/database.go @@ -0,0 +1,12 @@ +package common + +import ( + "github.com/songquanpeng/one-api/common/env" +) + +var UsingSQLite = false +var UsingPostgreSQL = false +var UsingMySQL = false + +var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/embed-file-system.go b/common/embed-file-system.go new file mode 100644 index 0000000..7c0e4b4 --- /dev/null +++ b/common/embed-file-system.go @@ -0,0 +1,29 @@ +package common + +import ( + "embed" + "github.com/gin-contrib/static" + "io/fs" + "net/http" +) + +// Credit: https://github.com/gin-contrib/static/issues/19 + +type embedFileSystem struct { + http.FileSystem +} + +func (e embedFileSystem) Exists(prefix string, path string) bool { + _, err := e.Open(path) + return err == nil +} + +func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { + efs, err := fs.Sub(fsEmbed, targetPath) + if err != nil { + panic(err) + } + return embedFileSystem{ + FileSystem: http.FS(efs), + } +} diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 0000000..fdb9f82 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} 
+ +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/gin.go b/common/gin.go new file mode 100644 index 0000000..e3281fe --- /dev/null +++ b/common/gin.go @@ -0,0 +1,53 @@ +package common + +import ( + "bytes" + "encoding/json" + "io" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" +) + +func GetRequestBody(c *gin.Context) ([]byte, error) { + requestBody, _ := c.Get(ctxkey.KeyRequestBody) + if requestBody != nil { + return requestBody.([]byte), nil + } + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + return nil, err + } + _ = c.Request.Body.Close() + c.Set(ctxkey.KeyRequestBody, requestBody) + return requestBody.([]byte), nil +} + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := GetRequestBody(c) + if err != nil { + return err + } + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + err = json.Unmarshal(requestBody, &v) + } else { + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = c.ShouldBind(&v) + } + if err != nil { + return err + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil +} + +func SetEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + 
c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/common/helper/helper.go b/common/helper/helper.go new file mode 100644 index 0000000..65f4fd2 --- /dev/null +++ b/common/helper/helper.go @@ -0,0 +1,174 @@ +package helper + +import ( + "context" + "fmt" + "html/template" + "log" + "net" + "os/exec" + "runtime" + "strconv" + "strings" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/random" +) + +func OpenBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + } + if err != nil { + log.Println(err) + } +} + +func GetIp() (ip string) { + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return ip + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip = ipNet.IP.String() + if strings.HasPrefix(ip, "10") { + return + } + if strings.HasPrefix(ip, "172") { + return + } + if strings.HasPrefix(ip, "192.168") { + return + } + ip = "" + } + } + } + return +} + +var sizeKB = 1024 +var sizeMB = sizeKB * 1024 +var sizeGB = sizeMB * 1024 + +func Bytes2Size(num int64) string { + numStr := "" + unit := "B" + if num/int64(sizeGB) > 1 { + numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) + unit = "GB" + } else if num/int64(sizeMB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) + unit = "MB" + } else if num/int64(sizeKB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) + unit = "KB" + } else { + numStr = fmt.Sprintf("%d", num) + } + return numStr + " " + unit +} + +func Interface2String(inter interface{}) string { + switch inter := inter.(type) { + case string: + 
return inter + case int: + return fmt.Sprintf("%d", inter) + case float64: + return fmt.Sprintf("%f", inter) + } + return "Not Implemented" +} + +func UnescapeHTML(x string) interface{} { + return template.HTML(x) +} + +func IntMax(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GenRequestID() string { + return GetTimeString() + random.GetRandomNumberString(8) +} + +func SetRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, RequestIdKey, id) +} + +func GetRequestID(ctx context.Context) string { + rawRequestId := ctx.Value(RequestIdKey) + if rawRequestId == nil { + return "" + } + return rawRequestId.(string) +} + +func GetResponseID(c *gin.Context) string { + logID := c.GetString(RequestIdKey) + return fmt.Sprintf("chatcmpl-%s", logID) +} + +func Max(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func AssignOrDefault(value string, defaultValue string) string { + if len(value) != 0 { + return value + } + return defaultValue +} + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} + +func Float64PtrMax(p *float64, maxValue float64) *float64 { + if p == nil { + return nil + } + if *p > maxValue { + return &maxValue + } + return p +} + +func Float64PtrMin(p *float64, minValue float64) *float64 { + if p == nil { + return nil + } + if *p < minValue { + return &minValue + } + return p +} diff --git a/common/helper/key.go b/common/helper/key.go new file mode 100644 index 0000000..17aee2e --- /dev/null +++ b/common/helper/key.go @@ -0,0 +1,5 @@ +package helper + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) diff --git a/common/helper/time.go b/common/helper/time.go new file mode 100644 index 0000000..f0bc602 --- /dev/null +++ b/common/helper/time.go @@ -0,0 +1,20 @@ +package 
helper + +import ( + "fmt" + "time" +) + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + +// CalcElapsedTime return the elapsed time in milliseconds (ms) +func CalcElapsedTime(start time.Time) int64 { + return time.Now().Sub(start).Milliseconds() +} diff --git a/common/i18n/i18n.go b/common/i18n/i18n.go new file mode 100644 index 0000000..dfad6ea --- /dev/null +++ b/common/i18n/i18n.go @@ -0,0 +1,72 @@ +package i18n + +import ( + "embed" + "encoding/json" + "strings" + + "github.com/gin-gonic/gin" +) + +//go:embed locales/*.json +var localesFS embed.FS + +var ( + translations = make(map[string]map[string]string) + defaultLang = "en" + ContextKey = "i18n" +) + +// Init loads all translation files from embedded filesystem +func Init() error { + entries, err := localesFS.ReadDir("locales") + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + + langCode := strings.TrimSuffix(entry.Name(), ".json") + content, err := localesFS.ReadFile("locales/" + entry.Name()) + if err != nil { + return err + } + + var translation map[string]string + if err := json.Unmarshal(content, &translation); err != nil { + return err + } + translations[langCode] = translation + } + + return nil +} + +func GetLang(c *gin.Context) string { + rawLang, ok := c.Get(ContextKey) + if !ok { + return defaultLang + } + lang, _ := rawLang.(string) + if lang != "" { + return lang + } + return defaultLang +} + +func Translate(c *gin.Context, message string) string { + lang := GetLang(c) + return translateHelper(lang, message) +} + +func translateHelper(lang, message string) string { + if trans, ok := translations[lang]; ok { + if translated, exists := trans[message]; exists { + return translated + } + } + return message +} diff --git a/common/i18n/locales/en.json 
b/common/i18n/locales/en.json new file mode 100644 index 0000000..4b24dea --- /dev/null +++ b/common/i18n/locales/en.json @@ -0,0 +1,5 @@ +{ + "invalid_input": "Invalid input, please check your input", + "send_email_failed": "failed to send email: ", + "invalid_parameter": "invalid parameter" +} diff --git a/common/i18n/locales/zh-CN.json b/common/i18n/locales/zh-CN.json new file mode 100644 index 0000000..805d5c5 --- /dev/null +++ b/common/i18n/locales/zh-CN.json @@ -0,0 +1,5 @@ +{ + "invalid_input": "无效的输入,请检查您的输入", + "send_email_failed": "发送邮件失败:", + "invalid_parameter": "无效的参数" +} diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 0000000..beebd0c --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,112 @@ +package image + +import ( + "bytes" + "encoding/base64" + "github.com/songquanpeng/one-api/common/client" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + "sync" + + _ "golang.org/x/image/webp" +) + +// Regex to match data URL pattern +var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) + +func IsImageUrl(url string) (bool, error) { + resp, err := client.UserContentRequestHTTPClient.Head(url) + if err != nil { + return false, err + } + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { + return false, nil + } + return true, nil +} + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := client.UserContentRequestHTTPClient.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageFromUrl(url string) (mimeType string, data string, err error) { + // Check if the URL is a data URL + matches := dataURLPattern.FindStringSubmatch(url) + if len(matches) == 3 { + // URL is a data URL + mimeType = "image/" + 
matches[1] + data = matches[2] + return + } + + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + buffer := bytes.NewBuffer(nil) + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return + } + mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) + return +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +var readerPool = sync.Pool{ + New: func() interface{} { + return &bytes.Reader{} + }, +} + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, "")) + if err != nil { + return 0, 0, err + } + + reader := readerPool.Get().(*bytes.Reader) + defer readerPool.Put(reader) + reader.Reset(decoded) + + img, _, err := image.DecodeConfig(reader) + if err != nil { + return 0, 0, err + } + + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 0000000..5b669b5 --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,177 @@ +package image_test + +import ( + "encoding/base64" + "github.com/songquanpeng/one-api/common/client" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "github.com/songquanpeng/one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + 
width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestMain(m *testing.M) { + client.Init() + m.Run() +} + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // 
gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} + +func TestGetImageSizeFromBase64(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + 
encoded := base64.StdEncoding.EncodeToString(data) + width, height, err := img.GetImageSizeFromBase64(encoded) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/common/init.go b/common/init.go new file mode 100644 index 0000000..6fd8476 --- /dev/null +++ b/common/init.go @@ -0,0 +1,64 @@ +package common + +import ( + "flag" + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "log" + "os" + "path/filepath" +) + +var ( + Port = flag.Int("port", 3000, "the listening port") + PrintVersion = flag.Bool("version", false, "print version and exit") + PrintHelp = flag.Bool("help", false, "print help and exit") + LogDir = flag.String("log-dir", "./logs", "specify the log directory") +) + +func printHelp() { + fmt.Println("One API " + Version + " - All in one API service for OpenAI API.") + fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.") + fmt.Println("GitHub: https://github.com/songquanpeng/one-api") + fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") +} + +func Init() { + flag.Parse() + + if *PrintVersion { + fmt.Println(Version) + os.Exit(0) + } + + if *PrintHelp { + printHelp() + os.Exit(0) + } + + if os.Getenv("SESSION_SECRET") != "" { + if os.Getenv("SESSION_SECRET") == "random_string" { + logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + } else { + config.SessionSecret = os.Getenv("SESSION_SECRET") + } + } + if os.Getenv("SQLITE_PATH") != "" { + SQLitePath = os.Getenv("SQLITE_PATH") + } + if *LogDir != "" { + var err error + *LogDir, err = filepath.Abs(*LogDir) + if err != nil { + log.Fatal(err) + } + if _, err := os.Stat(*LogDir); os.IsNotExist(err) { + err = os.Mkdir(*LogDir, 0777) + if err != nil { + log.Fatal(err) + } + } + logger.LogDir = *LogDir + } +} diff --git a/common/logger/constants.go b/common/logger/constants.go new file mode 100644 index 
0000000..49df31e --- /dev/null +++ b/common/logger/constants.go @@ -0,0 +1,3 @@ +package logger + +var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go new file mode 100644 index 0000000..724bc02 --- /dev/null +++ b/common/logger/logger.go @@ -0,0 +1,160 @@ +package logger + +import ( + "context" + "fmt" + "io" + "log" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" +) + +type loggerLevel string + +const ( + loggerDEBUG loggerLevel = "DEBUG" + loggerINFO loggerLevel = "INFO" + loggerWarn loggerLevel = "WARN" + loggerError loggerLevel = "ERROR" + loggerFatal loggerLevel = "FATAL" +) + +var setupLogOnce sync.Once + +func SetupLogger() { + setupLogOnce.Do(func() { + if LogDir != "" { + var logPath string + if config.OnlyOneLogFile { + logPath = filepath.Join(LogDir, "oneapi.log") + } else { + logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + } + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) + } + }) +} + +func SysLog(s string) { + logHelper(nil, loggerINFO, s) +} + +func SysLogf(format string, a ...any) { + logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) +} + +func SysWarn(s string) { + logHelper(nil, loggerWarn, s) +} + +func SysWarnf(format string, a ...any) { + logHelper(nil, loggerWarn, fmt.Sprintf(format, a...)) +} + +func SysError(s string) { + logHelper(nil, loggerError, s) +} + +func SysErrorf(format string, a ...any) { + logHelper(nil, loggerError, fmt.Sprintf(format, a...)) +} + +func Debug(ctx context.Context, msg string) { + if !config.DebugEnabled { + return + } + logHelper(ctx, loggerDEBUG, msg) +} + +func Info(ctx context.Context, 
msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func Warn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func Error(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func Debugf(ctx context.Context, format string, a ...any) { + if !config.DebugEnabled { + return + } + logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) +} + +func Infof(ctx context.Context, format string, a ...any) { + logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) +} + +func Warnf(ctx context.Context, format string, a ...any) { + logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) +} + +func Errorf(ctx context.Context, format string, a ...any) { + logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) +} + +func FatalLog(s string) { + logHelper(nil, loggerFatal, s) +} + +func FatalLogf(format string, a ...any) { + logHelper(nil, loggerFatal, fmt.Sprintf(format, a...)) +} + +func logHelper(ctx context.Context, level loggerLevel, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + var requestId string + if ctx != nil { + rawRequestId := helper.GetRequestID(ctx) + if rawRequestId != "" { + requestId = fmt.Sprintf(" | %s", rawRequestId) + } + } + lineInfo, funcName := getLineInfo() + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg) + SetupLogger() + if level == loggerFatal { + os.Exit(1) + } +} + +func getLineInfo() (string, string) { + funcName := "[unknown] " + pc, file, line, ok := runtime.Caller(3) + if ok { + if fn := runtime.FuncForPC(pc); fn != nil { + parts := strings.Split(fn.Name(), ".") + funcName = "[" + parts[len(parts)-1] + "] " + } + } else { + file = "unknown" + line = 0 + } + parts := strings.Split(file, "one-api/") + if len(parts) > 1 { + file = parts[1] + } + return fmt.Sprintf(" | %s:%d", file, line), funcName +} diff --git a/common/message/email.go 
b/common/message/email.go new file mode 100644 index 0000000..85a83d6 --- /dev/null +++ b/common/message/email.go @@ -0,0 +1,111 @@ +package message + +import ( + "crypto/rand" + "crypto/tls" + "encoding/base64" + "fmt" + "net" + "net/smtp" + "strings" + "time" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" +) + +func shouldAuth() bool { + return config.SMTPAccount != "" || config.SMTPToken != "" +} + +func SendEmail(subject string, receiver string, content string) error { + if receiver == "" { + return fmt.Errorf("receiver is empty") + } + if config.SMTPFrom == "" { // for compatibility + config.SMTPFrom = config.SMTPAccount + } + encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + + // Extract domain from SMTPFrom + parts := strings.Split(config.SMTPFrom, "@") + var domain string + if len(parts) > 1 { + domain = parts[1] + } + // Generate a unique Message-ID + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + return err + } + messageId := fmt.Sprintf("<%x@%s>", buf, domain) + + mail := []byte(fmt.Sprintf("To: %s\r\n"+ + "From: %s<%s>\r\n"+ + "Subject: %s\r\n"+ + "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 + "Date: %s\r\n"+ + "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", + receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) + addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) + to := strings.Split(receiver, ";") + + if config.SMTPPort == 465 || !shouldAuth() { + // need advanced client + var conn net.Conn + var err error + if config.SMTPPort == 465 { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: config.SMTPServer, + } + conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, 
config.SMTPPort), tlsConfig) + } else { + conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) + } + if err != nil { + return err + } + client, err := smtp.NewClient(conn, config.SMTPServer) + if err != nil { + return err + } + defer client.Close() + if shouldAuth() { + if err = client.Auth(auth); err != nil { + return err + } + } + if err = client.Mail(config.SMTPFrom); err != nil { + return err + } + receiverEmails := strings.Split(receiver, ";") + for _, receiver := range receiverEmails { + if err = client.Rcpt(receiver); err != nil { + return err + } + } + w, err := client.Data() + if err != nil { + return err + } + _, err = w.Write(mail) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + return nil + } + err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) + if err != nil && strings.Contains(err.Error(), "short response") { // 部分提供商返回该错误,但实际上邮件已经发送成功 + logger.SysWarnf("short response from SMTP server, return nil instead of error: %s", err.Error()) + return nil + } + return err +} diff --git a/common/message/main.go b/common/message/main.go new file mode 100644 index 0000000..5ce82a6 --- /dev/null +++ b/common/message/main.go @@ -0,0 +1,22 @@ +package message + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" +) + +const ( + ByAll = "all" + ByEmail = "email" + ByMessagePusher = "message_pusher" +) + +func Notify(by string, title string, description string, content string) error { + if by == ByEmail { + return SendEmail(title, config.RootUserEmail, content) + } + if by == ByMessagePusher { + return SendMessage(title, description, content) + } + return fmt.Errorf("unknown notify method: %s", by) +} diff --git a/common/message/message-pusher.go b/common/message/message-pusher.go new file mode 100644 index 0000000..69949b4 --- /dev/null +++ b/common/message/message-pusher.go @@ -0,0 +1,53 @@ +package message + +import ( + "bytes" + "encoding/json" + "errors" + 
"github.com/songquanpeng/one-api/common/config" + "net/http" +) + +type request struct { + Title string `json:"title"` + Description string `json:"description"` + Content string `json:"content"` + URL string `json:"url"` + Channel string `json:"channel"` + Token string `json:"token"` +} + +type response struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +func SendMessage(title string, description string, content string) error { + if config.MessagePusherAddress == "" { + return errors.New("message pusher address is not set") + } + req := request{ + Title: title, + Description: description, + Content: content, + Token: config.MessagePusherToken, + } + data, err := json.Marshal(req) + if err != nil { + return err + } + resp, err := http.Post(config.MessagePusherAddress, + "application/json", bytes.NewBuffer(data)) + if err != nil { + return err + } + var res response + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return err + } + if !res.Success { + return errors.New(res.Message) + } + return nil +} diff --git a/common/message/template.go b/common/message/template.go new file mode 100644 index 0000000..5573372 --- /dev/null +++ b/common/message/template.go @@ -0,0 +1,34 @@ +package message + +import ( + "fmt" + + "github.com/songquanpeng/one-api/common/config" +) + +// EmailTemplate 生成美观的 HTML 邮件内容 +func EmailTemplate(title, content string) string { + return fmt.Sprintf(` + + + + + + + +
+
+

%s

+
+
+ %s +
+
+

此邮件由系统自动发送,请勿直接回复

+

%s

+
+
+ + +`, title, content, config.SystemName) +} diff --git a/common/network/ip.go b/common/network/ip.go new file mode 100644 index 0000000..0fbe5e6 --- /dev/null +++ b/common/network/ip.go @@ -0,0 +1,52 @@ +package network + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "net" + "strings" +) + +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + if err != nil { + return fmt.Errorf("failed to parse subnet: %w", err) + } + return nil +} + +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { + _, ipNet, err := net.ParseCIDR(subnet) + if err != nil { + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) + return false + } + return ipNet.Contains(net.ParseIP(ip)) +} + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go new file mode 100644 index 0000000..6c59345 --- /dev/null +++ b/common/network/ip_test.go @@ -0,0 +1,19 @@ +package network + +import ( + "context" + "testing" + + . 
"github.com/smartystreets/goconvey/convey" +) + +func TestIsIpInSubnet(t *testing.T) { + ctx := context.Background() + ip1 := "192.168.0.5" + ip2 := "125.216.250.89" + subnet := "192.168.0.0/24" + Convey("TestIsIpInSubnet", t, func() { + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + }) +} diff --git a/common/random/main.go b/common/random/main.go new file mode 100644 index 0000000..dbb772c --- /dev/null +++ b/common/random/main.go @@ -0,0 +1,61 @@ +package random + +import ( + "github.com/google/uuid" + "math/rand" + "strings" + "time" +) + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const keyNumbers = "0123456789" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetRandomNumberString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyNumbers[rand.Intn(len(keyNumbers))] + } + return string(key) +} + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/common/rate-limit.go b/common/rate-limit.go new file mode 100644 index 0000000..301c101 --- /dev/null +++ b/common/rate-limit.go @@ -0,0 +1,70 @@ 
+package common + +import ( + "sync" + "time" +) + +type InMemoryRateLimiter struct { + store map[string]*[]int64 + mutex sync.Mutex + expirationDuration time.Duration +} + +func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) { + if l.store == nil { + l.mutex.Lock() + if l.store == nil { + l.store = make(map[string]*[]int64) + l.expirationDuration = expirationDuration + if expirationDuration > 0 { + go l.clearExpiredItems() + } + } + l.mutex.Unlock() + } +} + +func (l *InMemoryRateLimiter) clearExpiredItems() { + for { + time.Sleep(l.expirationDuration) + l.mutex.Lock() + now := time.Now().Unix() + for key := range l.store { + queue := l.store[key] + size := len(*queue) + if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) { + delete(l.store, key) + } + } + l.mutex.Unlock() + } +} + +// Request parameter duration's unit is seconds +func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool { + l.mutex.Lock() + defer l.mutex.Unlock() + // [old <-- new] + queue, ok := l.store[key] + now := time.Now().Unix() + if ok { + if len(*queue) < maxRequestNum { + *queue = append(*queue, now) + return true + } else { + if now-(*queue)[0] >= duration { + *queue = (*queue)[1:] + *queue = append(*queue, now) + return true + } else { + return false + } + } + } else { + s := make([]int64, 0, maxRequestNum) + l.store[key] = &s + *(l.store[key]) = append(*(l.store[key]), now) + } + return true +} diff --git a/common/redis.go b/common/redis.go new file mode 100644 index 0000000..55d4931 --- /dev/null +++ b/common/redis.go @@ -0,0 +1,81 @@ +package common + +import ( + "context" + "os" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "github.com/songquanpeng/one-api/common/logger" +) + +var RDB redis.Cmdable +var RedisEnabled = true + +// InitRedisClient This function is called after init() +func InitRedisClient() (err error) { + if os.Getenv("REDIS_CONN_STRING") == "" { + RedisEnabled = false + 
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + return nil + } + if os.Getenv("SYNC_FREQUENCY") == "" { + RedisEnabled = false + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") + return nil + } + redisConnString := os.Getenv("REDIS_CONN_STRING") + if os.Getenv("REDIS_MASTER_NAME") == "" { + logger.SysLog("Redis is enabled") + opt, err := redis.ParseURL(redisConnString) + if err != nil { + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + } + RDB = redis.NewClient(opt) + } else { + // cluster mode + logger.SysLog("Redis cluster mode enabled") + RDB = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: strings.Split(redisConnString, ","), + Password: os.Getenv("REDIS_PASSWORD"), + MasterName: os.Getenv("REDIS_MASTER_NAME"), + }) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err = RDB.Ping(ctx).Result() + if err != nil { + logger.FatalLog("Redis ping test failed: " + err.Error()) + } + return err +} + +func ParseRedisOption() *redis.Options { + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + } + return opt +} + +func RedisSet(key string, value string, expiration time.Duration) error { + ctx := context.Background() + return RDB.Set(ctx, key, value, expiration).Err() +} + +func RedisGet(key string) (string, error) { + ctx := context.Background() + return RDB.Get(ctx, key).Result() +} + +func RedisDel(key string) error { + ctx := context.Background() + return RDB.Del(ctx, key).Err() +} + +func RedisDecrease(key string, value int64) error { + ctx := context.Background() + return RDB.DecrBy(ctx, key, value).Err() +} diff --git a/common/render/render.go b/common/render/render.go new file mode 100644 index 0000000..e565c0b --- /dev/null +++ b/common/render/render.go @@ -0,0 +1,30 @@ +package render + +import ( + "encoding/json" + "fmt" + "strings" + + 
"github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" +) + +func StringData(c *gin.Context, str string) { + str = strings.TrimPrefix(str, "data: ") + str = strings.TrimSuffix(str, "\r") + c.Render(-1, common.CustomEvent{Data: "data: " + str}) + c.Writer.Flush() +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +} diff --git a/common/utils.go b/common/utils.go new file mode 100644 index 0000000..ecee2c8 --- /dev/null +++ b/common/utils.go @@ -0,0 +1,14 @@ +package common + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" +) + +func LogQuota(quota int64) string { + if config.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) + } else { + return fmt.Sprintf("%d 点额度", quota) + } +} diff --git a/common/utils/array.go b/common/utils/array.go new file mode 100644 index 0000000..72437d4 --- /dev/null +++ b/common/utils/array.go @@ -0,0 +1,13 @@ +package utils + +func DeDuplication(slice []string) []string { + m := make(map[string]bool) + for _, v := range slice { + m[v] = true + } + result := make([]string, 0, len(m)) + for v := range m { + result = append(result, v) + } + return result +} diff --git a/common/validate.go b/common/validate.go new file mode 100644 index 0000000..b3c7859 --- /dev/null +++ b/common/validate.go @@ -0,0 +1,9 @@ +package common + +import "github.com/go-playground/validator/v10" + +var Validate *validator.Validate + +func init() { + Validate = validator.New() +} diff --git a/common/verification.go b/common/verification.go new file mode 100644 index 0000000..d8ccd6e --- /dev/null +++ b/common/verification.go @@ -0,0 +1,77 @@ +package common + +import ( + "github.com/google/uuid" + "strings" + "sync" + "time" +) + +type 
verificationValue struct { + code string + time time.Time +} + +const ( + EmailVerificationPurpose = "v" + PasswordResetPurpose = "r" +) + +var verificationMutex sync.Mutex +var verificationMap map[string]verificationValue +var verificationMapMaxSize = 10 +var VerificationValidMinutes = 10 + +func GenerateVerificationCode(length int) string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + if length == 0 { + return code + } + return code[:length] +} + +func RegisterVerificationCodeWithKey(key string, code string, purpose string) { + verificationMutex.Lock() + defer verificationMutex.Unlock() + verificationMap[purpose+key] = verificationValue{ + code: code, + time: time.Now(), + } + if len(verificationMap) > verificationMapMaxSize { + removeExpiredPairs() + } +} + +func VerifyCodeWithKey(key string, code string, purpose string) bool { + verificationMutex.Lock() + defer verificationMutex.Unlock() + value, okay := verificationMap[purpose+key] + now := time.Now() + if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 { + return false + } + return code == value.code +} + +func DeleteKey(key string, purpose string) { + verificationMutex.Lock() + defer verificationMutex.Unlock() + delete(verificationMap, purpose+key) +} + +// no lock inside, so the caller must lock the verificationMap before calling! 
+func removeExpiredPairs() { + now := time.Now() + for key := range verificationMap { + if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 { + delete(verificationMap, key) + } + } +} + +func init() { + verificationMutex.Lock() + defer verificationMutex.Unlock() + verificationMap = make(map[string]verificationValue) +} diff --git a/controller/auth/github.go b/controller/auth/github.go new file mode 100644 index 0000000..ecdd183 --- /dev/null +++ b/controller/auth/github.go @@ -0,0 +1,240 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" +) + +type GitHubOAuthResponse struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +type GitHubUser struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") + } + defer 
res.Body.Close() + var oAuthResponse GitHubOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") + } + defer res2.Body.Close() + var githubUser GitHubUser + err = json.NewDecoder(res2.Body).Decode(&githubUser) + if err != nil { + return nil, err + } + if githubUser.Login == "" { + return nil, errors.New("返回值非法,用户字段为空,请稍后重试!") + } + return &githubUser, nil +} + +func GitHubOAuth(c *gin.Context) { + ctx := c.Request.Context() + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + GitHubBind(c) + return + } + + if !config.GitHubOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 GitHub 登录以及注册", + }) + return + } + code := c.Query("code") + githubUser, err := getGitHubUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GitHubId: githubUser.Login, + } + if model.IsGitHubIdAlreadyTaken(user.GitHubId) { + err := user.FillUserByGitHubId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) + if githubUser.Name != "" { + user.DisplayName = githubUser.Name + } else { + user.DisplayName = "GitHub User" + } 
+ user.Email = githubUser.Email + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(ctx, 0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func GitHubBind(c *gin.Context) { + if !config.GitHubOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 GitHub 登录以及注册", + }) + return + } + code := c.Query("code") + githubUser, err := getGitHubUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GitHubId: githubUser.Login, + } + if model.IsGitHubIdAlreadyTaken(user.GitHubId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 GitHub 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! 
+ user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.GitHubId = githubUser.Login + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} + +func GenerateOAuthCode(c *gin.Context) { + session := sessions.Default(c) + state := random.GetRandomString(12) + session.Set("oauth_state", state) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": state, + }) +} diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file mode 100644 index 0000000..651d587 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,203 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", 
"https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + ctx := c.Request.Context() + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = 
"lark_" + strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(ctx, 0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! 
+ user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 0000000..1c4eedb --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,228 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + 
req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + ctx := c.Request.Context() + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != 
"" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(ctx, 0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! 
+ user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/auth/wechat.go b/controller/auth/wechat.go new file mode 100644 index 0000000..9c30b8f --- /dev/null +++ b/controller/auth/wechat.go @@ -0,0 +1,169 @@ +package auth + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" +) + +type wechatLoginResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data string `json:"data"` +} + +func getWeChatIdByCode(code string) (string, error) { + if code == "" { + return "", errors.New("无效的参数") + } + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", config.WeChatServerToken) + client := http.Client{ + Timeout: 5 * time.Second, + } + httpResponse, err := client.Do(req) + if err != nil { + return "", err + } + defer httpResponse.Body.Close() + var res wechatLoginResponse + err = json.NewDecoder(httpResponse.Body).Decode(&res) + if err != nil { + return "", err + } + if !res.Success { + return "", errors.New(res.Message) + } + if res.Data == "" { + return "", errors.New("验证码错误或已过期") + } + return res.Data, nil +} + +func WeChatAuth(c *gin.Context) { + ctx := c.Request.Context() + if !config.WeChatAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员未开启通过微信登录以及注册", + 
"success": false, + }) + return + } + code := c.Query("code") + wechatId, err := getWeChatIdByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + user := model.User{ + WeChatId: wechatId, + } + if model.IsWeChatIdAlreadyTaken(wechatId) { + err := user.FillUserByWeChatId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) + user.DisplayName = "WeChat User" + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(ctx, 0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func WeChatBind(c *gin.Context) { + if !config.WeChatAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员未开启通过微信登录以及注册", + "success": false, + }) + return + } + code := c.Query("code") + wechatId, err := getWeChatIdByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + if model.IsWeChatIdAlreadyTaken(wechatId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该微信账号已被绑定", + }) + return + } + id := c.GetInt(ctxkey.Id) + user := model.User{ + Id: id, + } + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.WeChatId = wechatId + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + 
c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/controller/billing.go b/controller/billing.go new file mode 100644 index 0000000..e837157 --- /dev/null +++ b/controller/billing.go @@ -0,0 +1,97 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func GetSubscription(c *gin.Context) { + var remainQuota int64 + var usedQuota int64 + var err error + var token *model.Token + var expiredTime int64 + if config.DisplayTokenStatEnabled { + tokenId := c.GetInt(ctxkey.TokenId) + token, err = model.GetTokenById(tokenId) + if err == nil { + expiredTime = token.ExpiredTime + remainQuota = token.RemainQuota + usedQuota = token.UsedQuota + } + } else { + userId := c.GetInt(ctxkey.Id) + remainQuota, err = model.GetUserQuota(userId) + if err != nil { + usedQuota, err = model.GetUserUsedQuota(userId) + } + } + if expiredTime <= 0 { + expiredTime = 0 + } + if err != nil { + Error := relaymodel.Error{ + Message: err.Error(), + Type: "upstream_error", + } + c.JSON(200, gin.H{ + "error": Error, + }) + return + } + quota := remainQuota + usedQuota + amount := float64(quota) + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit + } + if token != nil && token.UnlimitedQuota { + amount = 100000000 + } + subscription := OpenAISubscriptionResponse{ + Object: "billing_subscription", + HasPaymentMethod: true, + SoftLimitUSD: amount, + HardLimitUSD: amount, + SystemHardLimitUSD: amount, + AccessUntil: expiredTime, + } + c.JSON(200, subscription) + return +} + +func GetUsage(c *gin.Context) { + var quota int64 + var err error + var token *model.Token + if config.DisplayTokenStatEnabled { + tokenId := c.GetInt(ctxkey.TokenId) + token, err = model.GetTokenById(tokenId) + quota = token.UsedQuota + } else { + userId := 
c.GetInt(ctxkey.Id) + quota, err = model.GetUserUsedQuota(userId) + } + if err != nil { + Error := relaymodel.Error{ + Message: err.Error(), + Type: "one_api_error", + } + c.JSON(200, gin.H{ + "error": Error, + }) + return + } + amount := float64(quota) + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit + } + usage := OpenAIUsageResponse{ + Object: "list", + TotalUsage: amount * 100, + } + c.JSON(200, usage) + return +} diff --git a/controller/channel-billing.go b/controller/channel-billing.go new file mode 100644 index 0000000..9f7ca18 --- /dev/null +++ b/controller/channel-billing.go @@ -0,0 +1,459 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/channeltype" + + "github.com/gin-gonic/gin" +) + +// https://github.com/songquanpeng/one-api/issues/79 + +type OpenAISubscriptionResponse struct { + Object string `json:"object"` + HasPaymentMethod bool `json:"has_payment_method"` + SoftLimitUSD float64 `json:"soft_limit_usd"` + HardLimitUSD float64 `json:"hard_limit_usd"` + SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` + AccessUntil int64 `json:"access_until"` +} + +type OpenAIUsageDailyCost struct { + Timestamp float64 `json:"timestamp"` + LineItems []struct { + Name string `json:"name"` + Cost float64 `json:"cost"` + } +} + +type OpenAICreditGrants struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalAvailable float64 `json:"total_available"` +} + +type OpenAIUsageResponse struct { + Object string `json:"object"` + //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"` + TotalUsage float64 `json:"total_usage"` // unit: 
0.01 dollar +} + +type OpenAISBUsageResponse struct { + Msg string `json:"msg"` + Data *struct { + Credit string `json:"credit"` + } `json:"data"` +} + +type AIProxyUserOverviewResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + ErrorCode int `json:"error_code"` + Data struct { + TotalPoints float64 `json:"totalPoints"` + } `json:"data"` +} + +type API2GPTUsageResponse struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` +} + +type APGC2DGPTUsageResponse struct { + //Grants interface{} `json:"grants"` + Object string `json:"object"` + TotalAvailable float64 `json:"total_available"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` +} + +type SiliconFlowUsageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Status bool `json:"status"` + Data struct { + ID string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + Email string `json:"email"` + IsAdmin bool `json:"isAdmin"` + Balance string `json:"balance"` + Status string `json:"status"` + Introduction string `json:"introduction"` + Role string `json:"role"` + ChargeBalance string `json:"chargeBalance"` + TotalBalance string `json:"totalBalance"` + Category string `json:"category"` + } `json:"data"` +} + +type DeepSeekUsageResponse struct { + IsAvailable bool `json:"is_available"` + BalanceInfos []struct { + Currency string `json:"currency"` + TotalBalance string `json:"total_balance"` + GrantedBalance string `json:"granted_balance"` + ToppedUpBalance string `json:"topped_up_balance"` + } `json:"balance_infos"` +} + +type OpenRouterResponse struct { + Data struct { + TotalCredits float64 `json:"total_credits"` + TotalUsage float64 `json:"total_usage"` + } `json:"data"` +} + +// GetAuthHeader get auth header +func GetAuthHeader(token string) http.Header { + h := http.Header{} + 
h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + return h +} + +func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { + req, err := http.NewRequest(method, url, nil) + if err != nil { + return nil, err + } + for k := range headers { + req.Header.Add(k, headers.Get(k)) + } + res, err := client.HTTPClient.Do(req) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status code: %d", res.StatusCode) + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + err = res.Body.Close() + if err != nil { + return nil, err + } + return body, nil +} + +func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { + url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + + if err != nil { + return 0, err + } + response := OpenAICreditGrants{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalAvailable) + return response.TotalAvailable, nil +} + +func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { + url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := OpenAISBUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Data == nil { + return 0, errors.New(response.Msg) + } + balance, err := strconv.ParseFloat(response.Data.Credit, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { + url := "https://aiproxy.io/api/report/getUserOverview" + headers := http.Header{} + headers.Add("Api-Key", channel.Key) + 
body, err := GetResponseBody("GET", url, channel, headers) + if err != nil { + return 0, err + } + response := AIProxyUserOverviewResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if !response.Success { + return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + } + channel.UpdateBalance(response.Data.TotalPoints) + return response.Data.TotalPoints, nil +} + +func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { + url := "https://api.api2gpt.com/dashboard/billing/credit_grants" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + + if err != nil { + return 0, err + } + response := API2GPTUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalRemaining) + return response.TotalRemaining, nil +} + +func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { + url := "https://api.aigc2d.com/dashboard/billing/credit_grants" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := APGC2DGPTUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalAvailable) + return response.TotalAvailable, nil +} + +func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { + url := "https://api.siliconflow.cn/v1/user/info" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := SiliconFlowUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Code != 20000 { + return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) + } + balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + 
return balance, nil +} + +func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { + url := "https://api.deepseek.com/user/balance" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := DeepSeekUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + index := -1 + for i, balanceInfo := range response.BalanceInfos { + if balanceInfo.Currency == "CNY" { + index = i + break + } + } + if index == -1 { + return 0, errors.New("currency CNY not found") + } + balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) { + url := "https://openrouter.ai/api/v1/credits" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := OpenRouterResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + balance := response.Data.TotalCredits - response.Data.TotalUsage + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelBalance(channel *model.Channel) (float64, error) { + baseURL := channeltype.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() == "" { + channel.BaseURL = &baseURL + } + switch channel.Type { + case channeltype.OpenAI: + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + case channeltype.Azure: + return 0, errors.New("尚未实现") + case channeltype.Custom: + baseURL = channel.GetBaseURL() + case channeltype.CloseAI: + return updateChannelCloseAIBalance(channel) + case channeltype.OpenAISB: + return updateChannelOpenAISBBalance(channel) + case channeltype.AIProxy: + return updateChannelAIProxyBalance(channel) + case channeltype.API2GPT: + return updateChannelAPI2GPTBalance(channel) + case 
channeltype.AIGC2D: + return updateChannelAIGC2DBalance(channel) + case channeltype.SiliconFlow: + return updateChannelSiliconFlowBalance(channel) + case channeltype.DeepSeek: + return updateChannelDeepSeekBalance(channel) + case channeltype.OpenRouter: + return updateChannelOpenRouterBalance(channel) + default: + return 0, errors.New("尚未实现") + } + url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) + + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + subscription := OpenAISubscriptionResponse{} + err = json.Unmarshal(body, &subscription) + if err != nil { + return 0, err + } + now := time.Now() + startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) + endDate := now.Format("2006-01-02") + if !subscription.HasPaymentMethod { + startDate = now.AddDate(0, 0, -100).Format("2006-01-02") + } + url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + usage := OpenAIUsageResponse{} + err = json.Unmarshal(body, &usage) + if err != nil { + return 0, err + } + balance := subscription.HardLimitUSD - usage.TotalUsage/100 + channel.UpdateBalance(balance) + return balance, nil +} + +func UpdateChannelBalance(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + balance, err := updateChannelBalance(channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "balance": balance, + }) + return +} + +func 
updateAllChannelsBalance() error { + channels, err := model.GetAllChannels(0, 0, "all") + if err != nil { + return err + } + for _, channel := range channels { + if channel.Status != model.ChannelStatusEnabled { + continue + } + // TODO: support Azure + if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { + continue + } + balance, err := updateChannelBalance(channel) + if err != nil { + continue + } else { + // err is nil & balance <= 0 means quota is used up + if balance <= 0 { + monitor.DisableChannel(channel.Id, channel.Name, "余额不足") + } + } + time.Sleep(config.RequestInterval) + } + return nil +} + +func UpdateAllChannelsBalance(c *gin.Context) { + //err := updateAllChannelsBalance() + //if err != nil { + // c.JSON(http.StatusOK, gin.H{ + // "success": false, + // "message": err.Error(), + // }) + // return + //} + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func AutomaticallyUpdateChannels(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Minute) + logger.SysLog("updating all channels") + _ = updateAllChannelsBalance() + logger.SysLog("channels update done") + } +} diff --git a/controller/channel-test.go b/controller/channel-test.go new file mode 100644 index 0000000..3894691 --- /dev/null +++ b/controller/channel-test.go @@ -0,0 +1,305 @@ +package controller + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/middleware" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay" + 
"github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { + if model == "" { + model = "gpt-3.5-turbo" + } + testRequest := &relaymodel.GeneralOpenAIRequest{ + Model: model, + } + testMessage := relaymodel.Message{ + Role: "user", + Content: config.TestPrompt, + } + testRequest.Messages = append(testRequest.Messages, testMessage) + return testRequest +} + +func parseTestResponse(resp string) (*openai.TextResponse, string, error) { + var response openai.TextResponse + err := json.Unmarshal([]byte(resp), &response) + if err != nil { + return nil, "", err + } + if len(response.Choices) == 0 { + return nil, "", errors.New("response has no choices") + } + stringContent, ok := response.Choices[0].Content.(string) + if !ok { + return nil, "", errors.New("response content is not string") + } + return &response, stringContent, nil +} + +func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) { + startTime := time.Now() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/v1/chat/completions"}, + Body: nil, + Header: make(http.Header), + } + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(ctxkey.Channel, channel.Type) + c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() + c.Set(ctxkey.Config, cfg) + middleware.SetupContextForSelectedChannel(c, channel, "") + meta := meta.GetByContext(c) + apiType := channeltype.ToAPIType(channel.Type) + adaptor 
:= relay.GetAdaptor(apiType) + if adaptor == nil { + return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + } + adaptor.Init(meta) + modelName := request.Model + modelMap := channel.GetModelMapping() + if modelName == "" || !strings.Contains(channel.Models, modelName) { + modelNames := strings.Split(channel.Models, ",") + if len(modelNames) > 0 { + modelName = modelNames[0] + } + } + if modelMap != nil && modelMap[modelName] != "" { + modelName = modelMap[modelName] + } + meta.OriginModelName, meta.ActualModelName = request.Model, modelName + request.Model = modelName + convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) + if err != nil { + return "", err, nil + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return "", err, nil + } + defer func() { + logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage) + if err != nil || openaiErr != nil { + errorMessage := "" + if err != nil { + errorMessage = err.Error() + } else { + errorMessage = openaiErr.Message + } + logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage) + } + go model.RecordTestLog(ctx, &model.Log{ + ChannelId: channel.Id, + ModelName: modelName, + Content: logContent, + ElapsedTime: helper.CalcElapsedTime(startTime), + }) + }() + logger.SysLog(string(jsonData)) + requestBody := bytes.NewBuffer(jsonData) + c.Request.Body = io.NopCloser(requestBody) + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + return "", err, nil + } + if resp != nil && resp.StatusCode != http.StatusOK { + err := controller.RelayErrorHandler(resp) + errorMessage := err.Error.Message + if errorMessage != "" { + errorMessage = ", error message: " + errorMessage + } + return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error + } + usage, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + return "", fmt.Errorf("%s", respErr.Error.Message), 
&respErr.Error + } + if usage == nil { + return "", errors.New("usage is nil"), nil + } + rawResponse := w.Body.String() + _, responseMessage, err = parseTestResponse(rawResponse) + if err != nil { + logger.SysError(fmt.Sprintf("failed to parse error: %s, \nresponse: %s", err.Error(), rawResponse)) + return "", err, nil + } + result := w.Result() + // print result.Body + respBody, err := io.ReadAll(result.Body) + if err != nil { + return "", err, nil + } + logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + return responseMessage, nil, nil +} + +func TestChannel(c *gin.Context) { + ctx := c.Request.Context() + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + modelName := c.Query("model") + testRequest := buildTestRequest(modelName) + tik := time.Now() + responseMessage, err, _ := testChannel(ctx, channel, testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if err != nil { + milliseconds = 0 + } + go channel.UpdateResponseTime(milliseconds) + consumedTime := float64(milliseconds) / 1000.0 + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "time": consumedTime, + "modelName": modelName, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": responseMessage, + "time": consumedTime, + "modelName": modelName, + }) + return +} + +var testAllChannelsLock sync.Mutex +var testAllChannelsRunning bool = false + +func testChannels(ctx context.Context, notify bool, scope string) error { + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() + } + testAllChannelsLock.Lock() + if testAllChannelsRunning { + testAllChannelsLock.Unlock() + return 
errors.New("测试已在运行中") + } + testAllChannelsRunning = true + testAllChannelsLock.Unlock() + channels, err := model.GetAllChannels(0, 0, scope) + if err != nil { + return err + } + var disableThreshold = int64(config.ChannelDisableThreshold * 1000) + if disableThreshold == 0 { + disableThreshold = 10000000 // a impossible value + } + go func() { + for _, channel := range channels { + isChannelEnabled := channel.Status == model.ChannelStatusEnabled + tik := time.Now() + testRequest := buildTestRequest("") + _, err, openaiErr := testChannel(ctx, channel, testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if isChannelEnabled && milliseconds > disableThreshold { + err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) + if config.AutomaticDisableChannelEnabled { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) + } else { + _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) + } + } + if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) + } + if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { + monitor.EnableChannel(channel.Id, channel.Name) + } + channel.UpdateResponseTime(milliseconds) + time.Sleep(config.RequestInterval) + } + testAllChannelsLock.Lock() + testAllChannelsRunning = false + testAllChannelsLock.Unlock() + if notify { + err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") + if err != nil { + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } + } + }() + return nil +} + +func TestChannels(c *gin.Context) { + ctx := c.Request.Context() + scope := c.Query("scope") + if scope == "" { + scope = "all" + } + err := testChannels(ctx, true, scope) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + 
c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func AutomaticallyTestChannels(frequency int) { + ctx := context.Background() + for { + time.Sleep(time.Duration(frequency) * time.Minute) + logger.SysLog("testing all channels") + _ = testChannels(ctx, false, "all") + logger.SysLog("channel test finished") + } +} diff --git a/controller/channel.go b/controller/channel.go new file mode 100644 index 0000000..37bfb99 --- /dev/null +++ b/controller/channel.go @@ -0,0 +1,172 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "strings" +) + +func GetAllChannels(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channels, + }) + return +} + +func SearchChannels(c *gin.Context) { + keyword := c.Query("keyword") + channels, err := model.SearchChannels(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channels, + }) + return +} + +func GetChannel(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channel, + }) + return +} + +func AddChannel(c 
*gin.Context) { + channel := model.Channel{} + err := c.ShouldBindJSON(&channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel.CreatedTime = helper.GetTimestamp() + keys := strings.Split(channel.Key, "\n") + channels := make([]model.Channel, 0, len(keys)) + for _, key := range keys { + if key == "" { + continue + } + localChannel := channel + localChannel.Key = key + channels = append(channels, localChannel) + } + err = model.BatchInsertChannels(channels) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteChannel(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + channel := model.Channel{Id: id} + err := channel.Delete() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteDisabledChannel(c *gin.Context) { + rows, err := model.DeleteDisabledChannel() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": rows, + }) + return +} + +func UpdateChannel(c *gin.Context) { + channel := model.Channel{} + err := c.ShouldBindJSON(&channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = channel.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channel, + }) + return +} diff --git a/controller/group.go b/controller/group.go new file mode 100644 index 0000000..6f02394 --- /dev/null +++ b/controller/group.go @@ -0,0 +1,19 @@ 
+package controller + +import ( + "github.com/gin-gonic/gin" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "net/http" +) + +func GetGroups(c *gin.Context) { + groupNames := make([]string, 0) + for groupName := range billingratio.GroupRatio { + groupNames = append(groupNames, groupName) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": groupNames, + }) +} diff --git a/controller/log.go b/controller/log.go new file mode 100644 index 0000000..665f49b --- /dev/null +++ b/controller/log.go @@ -0,0 +1,169 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" +) + +func GetAllLogs(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + username := c.Query("username") + tokenName := c.Query("token_name") + modelName := c.Query("model_name") + channel, _ := strconv.Atoi(c.Query("channel")) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": logs, + }) + return +} + +func GetUserLogs(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + userId := c.GetInt(ctxkey.Id) + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + modelName := 
c.Query("model_name") + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": logs, + }) + return +} + +func SearchAllLogs(c *gin.Context) { + keyword := c.Query("keyword") + logs, err := model.SearchAllLogs(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": logs, + }) + return +} + +func SearchUserLogs(c *gin.Context) { + keyword := c.Query("keyword") + userId := c.GetInt(ctxkey.Id) + logs, err := model.SearchUserLogs(userId, keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": logs, + }) + return +} + +func GetLogsStat(c *gin.Context) { + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + username := c.Query("username") + modelName := c.Query("model_name") + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "quota": quotaNum, + //"token": tokenNum, + }, + }) + return +} + +func GetLogsSelfStat(c *gin.Context) { + username := c.GetString(ctxkey.Username) + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := 
strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + modelName := c.Query("model_name") + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "quota": quotaNum, + //"token": tokenNum, + }, + }) + return +} + +func DeleteHistoryLogs(c *gin.Context) { + targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) + if targetTimestamp == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "target timestamp is required", + }) + return + } + count, err := model.DeleteOldLog(targetTimestamp) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": count, + }) + return +} diff --git a/controller/misc.go b/controller/misc.go new file mode 100644 index 0000000..75fec8f --- /dev/null +++ b/controller/misc.go @@ -0,0 +1,232 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/i18n" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/model" + + "github.com/gin-gonic/gin" +) + +func GetStatus(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + 
"lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, + }, + }) + return +} + +func GetNotice(c *gin.Context) { + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": config.OptionMap["Notice"], + }) + return +} + +func GetAbout(c *gin.Context) { + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": config.OptionMap["About"], + }) + return +} + +func GetHomePageContent(c *gin.Context) { + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": config.OptionMap["HomePageContent"], + }) + return +} + +func SendEmailVerification(c *gin.Context) { + email := c.Query("email") + if err := common.Validate.Var(email, "required,email"); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if config.EmailDomainRestrictionEnabled { + allowed := false + for _, domain := range config.EmailDomainWhitelist { + if strings.HasSuffix(email, "@"+domain) { + 
allowed = true + break + } + } + if !allowed { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", + }) + return + } + } + if model.IsEmailAlreadyTaken(email) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "邮箱地址已被占用", + }) + return + } + code := common.GenerateVerificationCode(6) + common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) + subject := fmt.Sprintf("%s 邮箱验证邮件", config.SystemName) + content := message.EmailTemplate( + subject, + fmt.Sprintf(` +

您好!

+

您正在进行 %s 邮箱验证。

+

您的验证码为:

+

%s

+

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

+ `, config.SystemName, code, common.VerificationValidMinutes), + ) + err := message.SendEmail(subject, email, content) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func SendPasswordResetEmail(c *gin.Context) { + email := c.Query("email") + if err := common.Validate.Var(email, "required,email"); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if !model.IsEmailAlreadyTaken(email) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该邮箱地址未注册", + }) + return + } + code := common.GenerateVerificationCode(0) + common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) + subject := fmt.Sprintf("%s 密码重置", config.SystemName) + content := message.EmailTemplate( + subject, + fmt.Sprintf(` +

您好!

+

您正在进行 %s 密码重置。

+

请点击下面的按钮进行密码重置:

+

+ 重置密码 +

+

如果按钮无法点击,请复制以下链接到浏览器中打开:

+

%s

+

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

+ `, config.SystemName, link, link, common.VerificationValidMinutes), + ) + err := message.SendEmail(subject, email, content) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("%s%s", i18n.Translate(c, "send_email_failed"), err.Error()), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type PasswordResetRequest struct { + Email string `json:"email"` + Token string `json:"token"` +} + +func ResetPassword(c *gin.Context) { + var req PasswordResetRequest + err := json.NewDecoder(c.Request.Body).Decode(&req) + if req.Email == "" || req.Token == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "重置链接非法或已过期", + }) + return + } + password := common.GenerateVerificationCode(12) + err = model.ResetUserPasswordByEmail(req.Email, password) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + common.DeleteKey(req.Email, common.PasswordResetPurpose) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": password, + }) + return +} diff --git a/controller/model.go b/controller/model.go new file mode 100644 index 0000000..dcbe709 --- /dev/null +++ b/controller/model.go @@ -0,0 +1,213 @@ +package controller + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "net/http" + "strings" 
+) + +// https://platform.openai.com/docs/api-reference/models/list + +type OpenAIModelPermission struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group *string `json:"group"` + IsBlocking bool `json:"is_blocking"` +} + +type OpenAIModels struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []OpenAIModelPermission `json:"permission"` + Root string `json:"root"` + Parent *string `json:"parent"` +} + +var models []OpenAIModels +var modelsMap map[string]OpenAIModels +var channelId2Models map[int][]string + +func init() { + var permission []OpenAIModelPermission + permission = append(permission, OpenAIModelPermission{ + Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", + Object: "model_permission", + Created: 1626777600, + AllowCreateEngine: true, + AllowSampling: true, + AllowLogprobs: true, + AllowSearchIndices: false, + AllowView: true, + AllowFineTuning: false, + Organization: "*", + Group: nil, + IsBlocking: false, + }) + // https://platform.openai.com/docs/models/model-endpoint-compatibility + for i := 0; i < apitype.Dummy; i++ { + if i == apitype.AIProxyLibrary { + continue + } + adaptor := relay.GetAdaptor(i) + channelName := adaptor.GetChannelName() + modelNames := adaptor.GetModelList() + for _, modelName := range modelNames { + models = append(models, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } + } + for _, channelType := range openai.CompatibleChannels { + if channelType == 
channeltype.Azure { + continue + } + channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) + for _, modelName := range channelModelList { + models = append(models, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } + } + modelsMap = make(map[string]OpenAIModels) + for _, model := range models { + modelsMap[model.Id] = model + } + channelId2Models = make(map[int][]string) + for i := 1; i < channeltype.Dummy; i++ { + adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) + meta := &meta.Meta{ + ChannelType: i, + } + adaptor.Init(meta) + channelId2Models[i] = adaptor.GetModelList() + } +} + +func DashboardListModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channelId2Models, + }) +} + +func ListAllModels(c *gin.Context) { + c.JSON(200, gin.H{ + "object": "list", + "data": models, + }) +} + +func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString(ctxkey.AvailableModels) != "" { + availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") + } else { + userId := c.GetInt(ctxkey.Id) + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + availableOpenAIModels := make([]OpenAIModels, 0) + for _, model := range models { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } + c.JSON(200, gin.H{ + "object": 
"list",
+		"data": availableOpenAIModels,
+	})
+}
+
+// RetrieveModel implements GET /v1/models/:model in the OpenAI style,
+// answering from the static modelsMap built in init().
+func RetrieveModel(c *gin.Context) {
+	modelId := c.Param("model")
+	if model, ok := modelsMap[modelId]; ok {
+		c.JSON(200, model)
+	} else {
+		Error := relaymodel.Error{
+			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
+			Type:    "invalid_request_error",
+			Param:   "model",
+			Code:    "model_not_found",
+		}
+		// NOTE(review): OpenAI answers 404 for unknown models; this returns
+		// 200 with an error body — kept as-is since clients may rely on it.
+		c.JSON(200, gin.H{
+			"error": Error,
+		})
+	}
+}
+
+// GetUserAvailableModels lists the model names available to the calling
+// user's group, resolved through the cache layer.
+func GetUserAvailableModels(c *gin.Context) {
+	ctx := c.Request.Context()
+	id := c.GetInt(ctxkey.Id)
+	userGroup, err := model.CacheGetUserGroup(id)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	models, err := model.CacheGetGroupModels(ctx, userGroup)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    models,
+	})
+	return
+}
diff --git a/controller/option.go b/controller/option.go
new file mode 100644
index 0000000..310086e
--- /dev/null
+++ b/controller/option.go
@@ -0,0 +1,102 @@
+package controller
+
+import (
+	"encoding/json"
+	"net/http"
+	"strings"
+
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/i18n"
+	"github.com/songquanpeng/one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+// GetOptions returns all non-secret system options for the admin UI.
+func GetOptions(c *gin.Context) {
+	var options []*model.Option
+	// FIX: read-only traversal — take the read lock. The original grabbed
+	// the exclusive write lock, needlessly serializing concurrent readers;
+	// every other OptionMap reader in this package uses RLock.
+	config.OptionMapRWMutex.RLock()
+	for k, v := range config.OptionMap {
+		// Never leak secret-bearing options to the client.
+		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
+			continue
+		}
+		options = append(options, &model.Option{
+			Key:   k,
+			Value: helper.Interface2String(v),
+		})
+	}
+	config.OptionMapRWMutex.RUnlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    options,
+	})
+	return
+}
+
+// UpdateOption validates and persists a single option key/value pair.
+func UpdateOption(c *gin.Context) {
+	var option model.Option
+	err := 
json.NewDecoder(c.Request.Body).Decode(&option) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + switch option.Key { + case "Theme": + if !config.ValidThemes[option.Value] { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的主题", + }) + return + } + case "GitHubOAuthEnabled": + if option.Value == "true" && config.GitHubClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", + }) + return + } + case "EmailDomainRestrictionEnabled": + if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", + }) + return + } + case "WeChatAuthEnabled": + if option.Value == "true" && config.WeChatServerAddress == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用微信登录,请先填入微信登录相关配置信息!", + }) + return + } + case "TurnstileCheckEnabled": + if option.Value == "true" && config.TurnstileSiteKey == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", + }) + return + } + } + err = model.UpdateOption(option.Key, option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/controller/redemption.go b/controller/redemption.go new file mode 100644 index 0000000..1d0ffba --- /dev/null +++ b/controller/redemption.go @@ -0,0 +1,195 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/model" + "net/http" + 
"strconv" +) + +func GetAllRedemptions(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemptions, + }) + return +} + +func SearchRedemptions(c *gin.Context) { + keyword := c.Query("keyword") + redemptions, err := model.SearchRedemptions(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemptions, + }) + return +} + +func GetRedemption(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + redemption, err := model.GetRedemptionById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemption, + }) + return +} + +func AddRedemption(c *gin.Context) { + redemption := model.Redemption{} + err := c.ShouldBindJSON(&redemption) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if len(redemption.Name) == 0 || len(redemption.Name) > 20 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "兑换码名称长度必须在1-20之间", + }) + return + } + if redemption.Count <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "兑换码个数必须大于0", + }) + return + } + if redemption.Count > 100 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "一次兑换码批量生成的个数不能大于 100", + }) + return + } + var keys []string + for i := 0; i < redemption.Count; i++ { + key := random.GetUUID() + cleanRedemption 
:= model.Redemption{ + UserId: c.GetInt(ctxkey.Id), + Name: redemption.Name, + Key: key, + CreatedTime: helper.GetTimestamp(), + Quota: redemption.Quota, + } + err = cleanRedemption.Insert() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "data": keys, + }) + return + } + keys = append(keys, key) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": keys, + }) + return +} + +func DeleteRedemption(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + err := model.DeleteRedemptionById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateRedemption(c *gin.Context) { + statusOnly := c.Query("status_only") + redemption := model.Redemption{} + err := c.ShouldBindJSON(&redemption) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + cleanRedemption, err := model.GetRedemptionById(redemption.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if statusOnly != "" { + cleanRedemption.Status = redemption.Status + } else { + // If you add more fields, please also update redemption.Update() + cleanRedemption.Name = redemption.Name + cleanRedemption.Quota = redemption.Quota + } + err = cleanRedemption.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanRedemption, + }) + return +} diff --git a/controller/relay.go b/controller/relay.go new file mode 100644 index 0000000..038123b --- /dev/null +++ b/controller/relay.go @@ -0,0 +1,156 @@ +package controller + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + 
"github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/middleware" + dbmodel "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +// https://platform.openai.com/docs/api-reference/chat + +func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { + var err *model.ErrorWithStatusCode + switch relayMode { + case relaymode.ImagesGenerations: + err = controller.RelayImageHelper(c, relayMode) + case relaymode.AudioSpeech: + fallthrough + case relaymode.AudioTranslation: + fallthrough + case relaymode.AudioTranscription: + err = controller.RelayAudioHelper(c, relayMode) + case relaymode.Proxy: + err = controller.RelayProxyHelper(c, relayMode) + default: + err = controller.RelayTextHelper(c) + } + return err +} + +func Relay(c *gin.Context) { + ctx := c.Request.Context() + relayMode := relaymode.GetByPath(c.Request.URL.Path) + if config.DebugEnabled { + requestBody, _ := common.GetRequestBody(c) + logger.Debugf(ctx, "request body: %s", string(requestBody)) + } + channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt(ctxkey.Id) + bizErr := relayHelper(c, relayMode) + if bizErr == nil { + monitor.Emit(channelId, true) + return + } + lastFailedChannelId := channelId + channelName := c.GetString(ctxkey.ChannelName) + group := c.GetString(ctxkey.Group) + originalModel := c.GetString(ctxkey.OriginalModel) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) + requestId := c.GetString(helper.RequestIdKey) + retryTimes := config.RetryTimes + if !shouldRetry(c, bizErr.StatusCode) { + logger.Errorf(ctx, "relay error happen, status 
code is %d, won't retry in this case", bizErr.StatusCode) + retryTimes = 0 + } + for i := retryTimes; i > 0; i-- { + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) + if err != nil { + logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err) + break + } + logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) + if channel.Id == lastFailedChannelId { + continue + } + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + bizErr = relayHelper(c, relayMode) + if bizErr == nil { + return + } + channelId := c.GetInt(ctxkey.ChannelId) + lastFailedChannelId = channelId + channelName := c.GetString(ctxkey.ChannelName) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) + } + if bizErr != nil { + if bizErr.StatusCode == http.StatusTooManyRequests { + bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" + } + + // BUG: bizErr is in race condition + bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) + c.JSON(bizErr.StatusCode, gin.H{ + "error": bizErr.Error, + }) + } +} + +func shouldRetry(c *gin.Context, statusCode int) bool { + if _, ok := c.Get(ctxkey.SpecificChannelId); ok { + return false + } + if statusCode == http.StatusTooManyRequests { + return true + } + if statusCode/100 == 5 { + return true + } + if statusCode == http.StatusBadRequest { + return false + } + if statusCode/100 == 2 { + return false + } + return true +} + +func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { + logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) + // https://platform.openai.com/docs/guides/error-codes/api-errors + if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { + monitor.DisableChannel(channelId, 
channelName, err.Message) + } else { + monitor.Emit(channelId, false) + } +} + +func RelayNotImplemented(c *gin.Context) { + err := model.Error{ + Message: "API not implemented", + Type: "one_api_error", + Param: "", + Code: "api_not_implemented", + } + c.JSON(http.StatusNotImplemented, gin.H{ + "error": err, + }) +} + +func RelayNotFound(c *gin.Context) { + err := model.Error{ + Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), + Type: "invalid_request_error", + Param: "", + Code: "", + } + c.JSON(http.StatusNotFound, gin.H{ + "error": err, + }) +} diff --git a/controller/token.go b/controller/token.go new file mode 100644 index 0000000..668ccd9 --- /dev/null +++ b/controller/token.go @@ -0,0 +1,257 @@ +package controller + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" +) + +func GetAllTokens(c *gin.Context) { + userId := c.GetInt(ctxkey.Id) + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + order := c.Query("order") + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": tokens, + }) + return +} + +func SearchTokens(c *gin.Context) { + userId := c.GetInt(ctxkey.Id) + keyword := c.Query("keyword") + tokens, err := model.SearchUserTokens(userId, keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": tokens, + }) + return +} + +func 
GetToken(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + userId := c.GetInt(ctxkey.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + token, err := model.GetTokenByIds(id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": token, + }) + return +} + +func GetTokenStatus(c *gin.Context) { + tokenId := c.GetInt(ctxkey.TokenId) + userId := c.GetInt(ctxkey.Id) + token, err := model.GetTokenByIds(tokenId, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + expiredAt := token.ExpiredTime + if expiredAt == -1 { + expiredAt = 0 + } + c.JSON(http.StatusOK, gin.H{ + "object": "credit_summary", + "total_granted": token.RemainQuota, + "total_used": 0, // not supported currently + "total_available": token.RemainQuota, + "expires_at": expiredAt * 1000, + }) +} + +func validateToken(c *gin.Context, token model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnets(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + return nil +} + +func AddToken(c *gin.Context) { + token := model.Token{} + err := c.ShouldBindJSON(&token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = validateToken(c, token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("参数错误:%s", err.Error()), + }) + return + } + + cleanToken := model.Token{ + UserId: c.GetInt(ctxkey.Id), + Name: token.Name, + Key: random.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: token.ExpiredTime, + RemainQuota: token.RemainQuota, + 
UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, + Subnet: token.Subnet, + } + err = cleanToken.Insert() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanToken, + }) + return +} + +func DeleteToken(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + userId := c.GetInt(ctxkey.Id) + err := model.DeleteTokenById(id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateToken(c *gin.Context) { + userId := c.GetInt(ctxkey.Id) + statusOnly := c.Query("status_only") + token := model.Token{} + err := c.ShouldBindJSON(&token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = validateToken(c, token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("参数错误:%s", err.Error()), + }) + return + } + cleanToken, err := model.GetTokenByIds(token.Id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if token.Status == model.TokenStatusEnabled { + if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", + }) + return + } + if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", + }) + return + } + } + if statusOnly != "" { + cleanToken.Status = token.Status + } else { + // If you add more fields, please also update token.Update() + 
cleanToken.Name = token.Name + cleanToken.ExpiredTime = token.ExpiredTime + cleanToken.RemainQuota = token.RemainQuota + cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models + cleanToken.Subnet = token.Subnet + } + err = cleanToken.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanToken, + }) + return +} diff --git a/controller/user.go b/controller/user.go new file mode 100644 index 0000000..d7fd8d7 --- /dev/null +++ b/controller/user.go @@ -0,0 +1,816 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/i18n" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/model" +) + +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +func Login(c *gin.Context) { + if !config.PasswordLoginEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了密码登录", + "success": false, + }) + return + } + var loginRequest LoginRequest + err := json.NewDecoder(c.Request.Body).Decode(&loginRequest) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": i18n.Translate(c, "invalid_parameter"), + "success": false, + }) + return + } + username := loginRequest.Username + password := loginRequest.Password + if username == "" || password == "" { + c.JSON(http.StatusOK, gin.H{ + "message": i18n.Translate(c, "invalid_parameter"), + "success": false, + }) + return + } + user := model.User{ + Username: username, + Password: password, + } + err = user.ValidateAndFill() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": 
err.Error(), + "success": false, + }) + return + } + SetupLogin(&user, c) +} + +// setup session & cookies and then return user info +func SetupLogin(user *model.User, c *gin.Context) { + session := sessions.Default(c) + session.Set("id", user.Id) + session.Set("username", user.Username) + session.Set("role", user.Role) + session.Set("status", user.Status) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + cleanUser := model.User{ + Id: user.Id, + Username: user.Username, + DisplayName: user.DisplayName, + Role: user.Role, + Status: user.Status, + } + c.JSON(http.StatusOK, gin.H{ + "message": "", + "success": true, + "data": cleanUser, + }) +} + +func Logout(c *gin.Context) { + session := sessions.Default(c) + session.Clear() + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "message": "", + "success": true, + }) +} + +func Register(c *gin.Context) { + ctx := c.Request.Context() + if !config.RegisterEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了新用户注册", + "success": false, + }) + return + } + if !config.PasswordRegisterEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", + "success": false, + }) + return + } + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_input"), + }) + return + } + if config.EmailVerificationEnabled { + if user.Email == "" || user.VerificationCode == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员开启了邮箱验证,请输入邮箱地址和验证码", + }) + return + } + if 
!common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码错误或已过期", + }) + return + } + } + affCode := user.AffCode // this code is the inviter's code, not the user's own code + inviterId, _ := model.GetUserIdByAffCode(affCode) + cleanUser := model.User{ + Username: user.Username, + Password: user.Password, + DisplayName: user.Username, + InviterId: inviterId, + } + if config.EmailVerificationEnabled { + cleanUser.Email = user.Email + } + if err := cleanUser.Insert(ctx, inviterId); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func GetAllUsers(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) +} + +func SearchUsers(c *gin.Context) { + keyword := c.Query("keyword") + users, err := model.SearchUsers(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) + return +} + +func GetUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt(ctxkey.Role) + if myRole <= user.Role && myRole != model.RoleRootUser { + 
c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权获取同级或更高等级用户的信息", + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user, + }) + return +} + +func GetUserDashboard(c *gin.Context) { + id := c.GetInt(ctxkey.Id) + now := time.Now() + startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() + endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() + + dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法获取统计信息", + "data": nil, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": dashboards, + }) + return +} + +func GenerateAccessToken(c *gin.Context) { + id := c.GetInt(ctxkey.Id) + user, err := model.GetUserById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.AccessToken = random.GetUUID() + + if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "请重试,系统生成的 UUID 竟然重复了!", + }) + return + } + + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user.AccessToken, + }) + return +} + +func GetAffCode(c *gin.Context) { + id := c.GetInt(ctxkey.Id) + user, err := model.GetUserById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if user.AffCode == "" { + user.AffCode = random.GetRandomString(4) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": 
user.AffCode, + }) + return +} + +func GetSelf(c *gin.Context) { + id := c.GetInt(ctxkey.Id) + user, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user, + }) + return +} + +func UpdateUser(c *gin.Context) { + ctx := c.Request.Context() + var updatedUser model.User + err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) + if err != nil || updatedUser.Id == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if updatedUser.Password == "" { + updatedUser.Password = "$I_LOVE_U" // make Validator happy :) + } + if err := common.Validate.Struct(&updatedUser); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_input"), + }) + return + } + originUser, err := model.GetUserById(updatedUser.Id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt(ctxkey.Role) + if myRole <= originUser.Role && myRole != model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权更新同权限等级或更高权限等级的用户信息", + }) + return + } + if myRole <= updatedUser.Role && myRole != model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", + }) + return + } + if updatedUser.Password == "$I_LOVE_U" { + updatedUser.Password = "" // rollback to what it should be + } + updatePassword := updatedUser.Password != "" + if err := updatedUser.Update(updatePassword); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if originUser.Quota != updatedUser.Quota { + model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), 
common.LogQuota(updatedUser.Quota))) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateSelf(c *gin.Context) { + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if user.Password == "" { + user.Password = "$I_LOVE_U" // make Validator happy :) + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "输入不合法 " + err.Error(), + }) + return + } + + cleanUser := model.User{ + Id: c.GetInt(ctxkey.Id), + Username: user.Username, + Password: user.Password, + DisplayName: user.DisplayName, + } + if user.Password == "$I_LOVE_U" { + user.Password = "" // rollback to what it should be + cleanUser.Password = "" + } + updatePassword := user.Password != "" + if err := cleanUser.Update(updatePassword); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + originUser, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt("role") + if myRole <= originUser.Role { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权删除同权限等级或更高权限等级的用户", + }) + return + } + err = model.DeleteUserById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return + } +} + +func DeleteSelf(c *gin.Context) { + id := c.GetInt("id") + user, _ := model.GetUserById(id, false) + + if user.Role == model.RoleRootUser { + 
c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不能删除超级管理员账户", + }) + return + } + + err := model.DeleteUserById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func CreateUser(c *gin.Context) { + ctx := c.Request.Context() + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil || user.Username == "" || user.Password == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_input"), + }) + return + } + if user.DisplayName == "" { + user.DisplayName = user.Username + } + myRole := c.GetInt("role") + if user.Role >= myRole { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法创建权限大于等于自己的用户", + }) + return + } + // Even for admin users, we cannot fully trust them! 
+ cleanUser := model.User{ + Username: user.Username, + Password: user.Password, + DisplayName: user.DisplayName, + } + if err := cleanUser.Insert(ctx, 0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type ManageRequest struct { + Username string `json:"username"` + Action string `json:"action"` +} + +// ManageUser Only admin user can do this +func ManageUser(c *gin.Context) { + var req ManageRequest + err := json.NewDecoder(c.Request.Body).Decode(&req) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": i18n.Translate(c, "invalid_parameter"), + }) + return + } + user := model.User{ + Username: req.Username, + } + // Fill attributes + model.DB.Where(&user).First(&user) + if user.Id == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户不存在", + }) + return + } + myRole := c.GetInt("role") + if myRole <= user.Role && myRole != model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权更新同权限等级或更高权限等级的用户信息", + }) + return + } + switch req.Action { + case "disable": + user.Status = model.UserStatusDisabled + if user.Role == model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法禁用超级管理员用户", + }) + return + } + case "enable": + user.Status = model.UserStatusEnabled + case "delete": + if user.Role == model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法删除超级管理员用户", + }) + return + } + if err := user.Delete(); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + case "promote": + if myRole != model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "普通管理员用户无法提升其他用户为管理员", + }) + return + } + if user.Role >= model.RoleAdminUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": 
"该用户已经是管理员", + }) + return + } + user.Role = model.RoleAdminUser + case "demote": + if user.Role == model.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法降级超级管理员用户", + }) + return + } + if user.Role == model.RoleCommonUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该用户已经是普通用户", + }) + return + } + user.Role = model.RoleCommonUser + } + + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + clearUser := model.User{ + Role: user.Role, + Status: user.Status, + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": clearUser, + }) + return +} + +func EmailBind(c *gin.Context) { + email := c.Query("email") + code := c.Query("code") + if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码错误或已过期", + }) + return + } + id := c.GetInt("id") + user := model.User{ + Id: id, + } + err := user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.Email = email + // no need to check if this email already taken, because we have used verification code to check it + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if user.Role == model.RoleRootUser { + config.RootUserEmail = email + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type topUpRequest struct { + Key string `json:"key"` +} + +func TopUp(c *gin.Context) { + ctx := c.Request.Context() + req := topUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + id := c.GetInt("id") + quota, err := model.Redeem(ctx, req.Key, id) + if err != nil { + 
c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": quota, + }) + return +} + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + ctx := c.Request.Context() + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1325a81 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,49 @@ +version: '3.4' + +services: + one-api: + image: "${REGISTRY:-docker.io}/justsong/one-api:latest" + container_name: one-api + restart: always + command: --log-dir /app/logs + ports: + - "3000:3000" + volumes: + - ./data/oneapi:/data + - ./logs:/app/logs + environment: + - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 + - REDIS_CONN_STRING=redis://redis + - SESSION_SECRET=random_string # 修改为随机字符串 + - TZ=Asia/Shanghai +# - NODE_TYPE=slave # 多机部署时从节点取消注释该行 +# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 +# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 + depends_on: + - redis + - db + healthcheck: + test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] + interval: 30s + timeout: 10s + retries: 3 + + redis: + image: 
"${REGISTRY:-docker.io}/redis:latest" + container_name: redis + restart: always + + db: + image: "${REGISTRY:-docker.io}/mysql:8.2.0" + restart: always + container_name: mysql + volumes: + - ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储 + ports: + - '3306:3306' + environment: + TZ: Asia/Shanghai # 设置时区 + MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 + MYSQL_USER: oneapi # 创建专用用户 + MYSQL_PASSWORD: '123456' # 设置专用用户密码 + MYSQL_DATABASE: one-api # 自动创建数据库 \ No newline at end of file diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..0b7ddf5 --- /dev/null +++ b/docs/API.md @@ -0,0 +1,53 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 PR 在此放上你的拓展项目。 + +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 + +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` + +## 其他 +### 充值链接上的附加参数 +One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: +`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` + +你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 + +注意,不是所有主题都支持该功能,欢迎 PR 补齐。 \ No newline at end of file diff --git a/docs/SAAS-PLAN.md b/docs/SAAS-PLAN.md new file mode 100644 index 0000000..8c45f5e --- /dev/null +++ b/docs/SAAS-PLAN.md @@ -0,0 +1,539 @@ +# One-API 多租户 SaaS 系统二开方案 + +## 📋 项目概述 + +### 目标 +在 one-api 
基础上进行二次开发,构建多租户 SaaS 系统,支持: +- 主系统统一管理上游 API 渠道 +- 多个代理站点独立部署(独立数据库) +- 月套餐令牌系统(日限/周限/月限额度控制) +- 主系统统一计费扣费 + +### 核心原则 +1. **保持可升级性**:插件化开发,最小化核心代码改动 +2. **完全独立部署**:每个代理站点独立运行,互不影响 +3. **统一计费管理**:主系统集中管理渠道和计费 + +--- + +## 🏗️ 系统架构设计 + +### 整体架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 主系统 (Master) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ 上游渠道池 │ │ 计费中心 │ │ 代理管理 │ │ +│ │ OpenAI/Claude│ │ 统计报表 │ │ 额度分配 │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ ▲ ▲ ▲ │ +└─────────┼──────────────────┼──────────────────┼──────────────┘ + │ │ │ + │ 渠道请求 │ 计费回传 │ 配置同步 + │ │ │ + ┌─────┴──────────────────┴──────────────────┴─────┐ + │ │ +┌───▼────────┐ ┌─────────────┐ ┌─────────▼────┐ +│代理站点 A │ │ 代理站点 B │ │ 代理站点 C │ +│独立数据库 │ │ 独立数据库 │ │ 独立数据库 │ +│独立域名 │ │ 独立域名 │ │ 独立域名 │ +└────────────┘ └─────────────┘ └──────────────┘ + ▲ ▲ ▲ + │ │ │ + 终端用户 终端用户 终端用户 +``` + +### 核心交互流程 + +``` +1. 用户请求 → 代理站点 +2. 代理站点验证令牌(本地) +3. 代理站点 → 主系统(请求渠道服务) +4. 主系统验证额度 → 转发上游 API +5. 主系统记录消耗 → 返回结果 +6. 代理站点 → 返回用户 +``` + +--- + +## 💾 数据库设计 + +### 主系统新增表 + +#### 1. agent_sites(代理站点表) +```sql +CREATE TABLE agent_sites ( + id INT PRIMARY KEY AUTO_INCREMENT, + name VARCHAR(100) NOT NULL, -- 站点名称 + domain VARCHAR(255) UNIQUE, -- 站点域名 + api_key VARCHAR(64) UNIQUE NOT NULL, -- 站点 API 密钥(用于主从通信) + status INT DEFAULT 1, -- 状态:1启用 2禁用 + total_quota BIGINT DEFAULT 0, -- 总分配额度 + used_quota BIGINT DEFAULT 0, -- 已使用额度 + created_time BIGINT, + updated_time BIGINT, + INDEX idx_api_key (api_key), + INDEX idx_status (status) +); +``` + +#### 2. 
agent_billing_logs(代理计费日志) +```sql +CREATE TABLE agent_billing_logs ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + agent_site_id INT NOT NULL, -- 代理站点ID + user_id INT, -- 代理站点的用户ID + token_name VARCHAR(100), -- 令牌名称 + model_name VARCHAR(100), -- 模型名称 + prompt_tokens INT, -- 输入token数 + completion_tokens INT, -- 输出token数 + quota INT, -- 消耗额度 + channel_id INT, -- 使用的渠道ID + created_time BIGINT, + INDEX idx_agent_site (agent_site_id), + INDEX idx_created_time (created_time) +); +``` + +### 代理站点扩展字段 + +#### 扩展 tokens 表(月套餐令牌) +```sql +-- 在现有 tokens 表基础上新增字段 +ALTER TABLE tokens ADD COLUMN subscription_type VARCHAR(20); -- 套餐类型:daily/weekly/monthly +ALTER TABLE tokens ADD COLUMN daily_quota_limit BIGINT DEFAULT 0; -- 日额度限制 +ALTER TABLE tokens ADD COLUMN weekly_quota_limit BIGINT DEFAULT 0; -- 周额度限制 +ALTER TABLE tokens ADD COLUMN monthly_quota_limit BIGINT DEFAULT 0; -- 月额度限制 +ALTER TABLE tokens ADD COLUMN daily_used_quota BIGINT DEFAULT 0; -- 日已用额度 +ALTER TABLE tokens ADD COLUMN weekly_used_quota BIGINT DEFAULT 0; -- 周已用额度 +ALTER TABLE tokens ADD COLUMN monthly_used_quota BIGINT DEFAULT 0; -- 月已用额度 +ALTER TABLE tokens ADD COLUMN last_reset_daily BIGINT DEFAULT 0; -- 上次日重置时间 +ALTER TABLE tokens ADD COLUMN last_reset_weekly BIGINT DEFAULT 0; -- 上次周重置时间 +ALTER TABLE tokens ADD COLUMN last_reset_monthly BIGINT DEFAULT 0; -- 上次月重置时间 +``` + +--- + +## 🔧 核心功能实现 + +### 1. 
月套餐令牌系统 + +#### 套餐定义 +```go +// common/subscription/plans.go (新文件) +package subscription + +const ( + PlanBasic = "basic" // 100美金/月 + PlanStandard = "standard" // 200美金/月 + PlanPremium = "premium" // 500美金/月 +) + +type SubscriptionPlan struct { + Name string + MonthlyQuota int64 // 月总额度(点数) + DailyQuota int64 // 日额度限制 + WeeklyQuota int64 // 周额度限制 +} + +var Plans = map[string]SubscriptionPlan{ + PlanBasic: { + Name: "基础版", + MonthlyQuota: 50000000, // 100美金 * 500,000 + DailyQuota: 2000000, // ~4美金/天 + WeeklyQuota: 15000000, // ~30美金/周 + }, + PlanStandard: { + Name: "标准版", + MonthlyQuota: 100000000, // 200美金 * 500,000 + DailyQuota: 5000000, // ~10美金/天 + WeeklyQuota: 35000000, // ~70美金/周 + }, + PlanPremium: { + Name: "高级版", + MonthlyQuota: 250000000, // 500美金 * 500,000 + DailyQuota: 15000000, // ~30美金/天 + WeeklyQuota: 90000000, // ~180美金/周 + }, +} +``` + +#### 额度检查逻辑 +```go +// model/token.go 扩展 +func (token *Token) CheckSubscriptionQuota() error { + now := time.Now().Unix() + + // 1. 检查并重置日额度 + if shouldResetDaily(token.LastResetDaily, now) { + token.DailyUsedQuota = 0 + token.LastResetDaily = getDayStart(now) + } + + // 2. 检查并重置周额度 + if shouldResetWeekly(token.LastResetWeekly, now) { + token.WeeklyUsedQuota = 0 + token.LastResetWeekly = getWeekStart(now) + } + + // 3. 检查并重置月额度 + if shouldResetMonthly(token.LastResetMonthly, now) { + token.MonthlyUsedQuota = 0 + token.LastResetMonthly = getMonthStart(now) + } + + // 4. 
验证额度 + if token.DailyQuotaLimit > 0 && token.DailyUsedQuota >= token.DailyQuotaLimit { + return errors.New("日额度已用尽") + } + if token.WeeklyQuotaLimit > 0 && token.WeeklyUsedQuota >= token.WeeklyQuotaLimit { + return errors.New("周额度已用尽") + } + if token.MonthlyQuotaLimit > 0 && token.MonthlyUsedQuota >= token.MonthlyQuotaLimit { + return errors.New("月额度已用尽") + } + + return nil +} +``` + +#### 消费扣费 +```go +// model/token.go 扩展 +func (token *Token) ConsumeSubscriptionQuota(quota int64) error { + token.DailyUsedQuota += quota + token.WeeklyUsedQuota += quota + token.MonthlyUsedQuota += quota + token.UsedQuota += quota + + return DB.Model(token).Updates(map[string]interface{}{ + "daily_used_quota": token.DailyUsedQuota, + "weekly_used_quota": token.WeeklyUsedQuota, + "monthly_used_quota": token.MonthlyUsedQuota, + "used_quota": token.UsedQuota, + }).Error +} +``` + +### 2. 主从计费系统 + +#### 代理站点配置 +```go +// common/config/agent.go (新文件) +package config + +var ( + IsAgentSite = os.Getenv("AGENT_MODE") == "true" + MasterSystemURL = os.Getenv("MASTER_SYSTEM_URL") // 主系统地址 + AgentSiteAPIKey = os.Getenv("AGENT_SITE_API_KEY") // 站点密钥 +) +``` + +#### 代理站点中继拦截 +```go +// relay/proxy/master_proxy.go (新文件) +package proxy + +// 拦截所有中继请求,转发到主系统 +func RelayToMaster(c *gin.Context, meta *relay.RelayMeta) (*http.Response, error) { + // 1. 构造请求到主系统 + masterURL := config.MasterSystemURL + "/api/agent/relay" + + req, _ := http.NewRequest(c.Request.Method, masterURL, c.Request.Body) + + // 2. 添加认证头 + req.Header.Set("X-Agent-Site-Key", config.AgentSiteAPIKey) + req.Header.Set("X-Agent-User-Id", strconv.Itoa(meta.UserId)) + req.Header.Set("X-Agent-Token-Name", meta.TokenName) + + // 3. 转发原始请求头 + for k, v := range c.Request.Header { + req.Header[k] = v + } + + // 4. 
发送请求 + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Do(req) + + return resp, err +} +``` + +#### 主系统代理中继接口 +```go +// controller/agent_relay.go (新文件) +package controller + +func AgentRelay(c *gin.Context) { + // 1. 验证代理站点身份 + agentKey := c.GetHeader("X-Agent-Site-Key") + agentSite, err := model.GetAgentSiteByKey(agentKey) + if err != nil { + c.JSON(403, gin.H{"error": "invalid agent site"}) + return + } + + // 2. 检查代理站点额度 + if agentSite.UsedQuota >= agentSite.TotalQuota { + c.JSON(403, gin.H{"error": "agent quota exhausted"}) + return + } + + // 3. 选择渠道并转发请求 + modelName := c.GetString("model") + channel, err := model.CacheGetRandomSatisfiedChannel("default", modelName, false) + + // 4. 调用正常的 relay 流程 + relayRequest(c, channel, agentSite) +} + +func relayRequest(c *gin.Context, channel *model.Channel, agentSite *model.AgentSite) { + // ... 使用现有的 relay 逻辑 + // 计费时记录到 agent_billing_logs 表 +} +``` + +--- + +## 📁 文件结构(插件化设计) + +``` +one-api/ +├── common/ +│ ├── subscription/ # 新增:套餐管理 +│ │ ├── plans.go # 套餐定义 +│ │ └── quota.go # 额度计算 +│ └── config/ +│ └── agent.go # 新增:代理站点配置 +├── model/ +│ ├── agent_site.go # 新增:代理站点模型 +│ ├── agent_billing.go # 新增:代理计费日志 +│ └── token.go # 扩展:添加套餐相关方法 +├── controller/ +│ ├── agent_site.go # 新增:代理站点管理接口 +│ ├── agent_relay.go # 新增:代理中继接口 +│ └── subscription.go # 新增:套餐管理接口 +├── middleware/ +│ └── subscription_check.go # 新增:套餐额度检查中间件 +├── relay/ +│ └── proxy/ +│ └── master_proxy.go # 新增:主系统代理转发 +└── docs/ + └── SAAS-PLAN.md # 本文档 +``` + +--- + +## 🔄 工作流程 + +### 代理站点请求流程 + +``` +1. 用户 → 代理站点 /v1/chat/completions + ↓ +2. 代理站点中间件验证 Token + - 检查 Token 状态 + - 检查套餐额度(日/周/月) + ↓ +3. 通过 → 转发到主系统 /api/agent/relay + Header: X-Agent-Site-Key, X-Agent-User-Id + ↓ +4. 主系统验证代理站点身份和额度 + ↓ +5. 主系统选择渠道 → 调用上游 API + ↓ +6. 主系统计费 + - 扣除代理站点额度 + - 记录到 agent_billing_logs + ↓ +7. 返回结果 → 代理站点 → 用户 + ↓ +8. 
代理站点更新本地 Token 统计 + - 更新 daily/weekly/monthly_used_quota +``` + +### 代理站点部署配置 + +**环境变量:** +```bash +# .env +AGENT_MODE=true +MASTER_SYSTEM_URL=https://master.example.com +AGENT_SITE_API_KEY=ask-xxxxxxxxxxxx +SQL_DSN=agent_user:password@tcp(localhost:3306)/agent_db +PORT=3000 +``` + +--- + +## 🛠️ 实施步骤 + +### 阶段一:基础架构(Week 1-2) +- [ ] 创建 agent_sites 表和模型 +- [ ] 实现代理站点注册和管理接口 +- [ ] 开发主系统代理中继接口 +- [ ] 实现代理站点转发逻辑 + +### 阶段二:套餐系统(Week 3-4) +- [ ] 扩展 tokens 表字段 +- [ ] 实现套餐定义和管理 +- [ ] 开发套餐额度检查逻辑 +- [ ] 实现日/周/月自动重置 + +### 阶段三:计费系统(Week 5-6) +- [ ] 创建 agent_billing_logs 表 +- [ ] 实现主系统计费记录 +- [ ] 开发代理站点额度同步 +- [ ] 实现统计报表接口 + +### 阶段四:前端界面(Week 7-8) +- [ ] 主系统:代理站点管理页面 +- [ ] 主系统:计费统计报表 +- [ ] 代理站点:套餐令牌管理页面 +- [ ] 代理站点:使用统计页面 + +### 阶段五:测试优化(Week 9-10) +- [ ] 单元测试 +- [ ] 压力测试 +- [ ] 安全测试 +- [ ] 性能优化 + +--- + +## ⚠️ 技术难点与解决方案 + +### 1. 如何保持可升级性? + +**问题:** 二次开发后如何继续跟随 one-api 上游更新? + +**解决方案:** +- **插件化设计**:新功能尽量在新文件中实现,减少修改核心文件 +- **扩展而非修改**:使用 Go 的组合而非继承,扩展现有结构体 +- **Hook 机制**:在关键位置注入 Hook,避免修改主流程 +- **独立分支管理**: + ```bash + # 主分支跟随上游 + git remote add upstream https://github.com/songquanpeng/one-api.git + + # 开发分支 + git checkout -b saas-dev + + # 定期合并上游 + git fetch upstream + git merge upstream/main + ``` + +### 2. 额度透支问题 + +**问题:** 代理站点可能恶意超额使用 + +**解决方案:** +- **预扣费机制**:主系统在转发前先扣除预估额度 +- **实时额度检查**:每次请求都验证代理站点剩余额度 +- **熔断机制**:超额后立即停止服务 +- **告警通知**:额度接近用尽时提前通知 + +### 3. 性能问题 + +**问题:** 每次请求都要经过主系统,增加延迟 + +**解决方案:** +- **连接池**:复用 HTTP 连接 +- **异步计费**:返回结果后异步记录日志 +- **批量提交**:计费数据批量写入 +- **Redis 缓存**:缓存代理站点信息和额度 + +### 4. 高可用性 + +**问题:** 主系统故障影响所有代理站点 + +**解决方案:** +- **主系统多节点部署**:负载均衡 +- **降级策略**:主系统故障时代理站点使用本地渠道(如果配置) +- **健康检查**:定期检查主系统状态 +- **限流保护**:防止单个代理站点占用过多资源 + +--- + +## 🔐 安全考虑 + +### 1. 认证安全 +- 代理站点 API Key 使用强随机生成 +- 通信使用 HTTPS 加密 +- API Key 定期轮换 + +### 2. 防刷防滥用 +- 请求频率限制(Rate Limit) +- 异常流量检测 +- IP 白名单 + +### 3. 
数据安全 +- 敏感数据加密存储 +- 日志脱敏处理 +- 定期备份 + +--- + +## 📊 监控与运维 + +### 关键指标 +- 代理站点请求量 +- 渠道使用分布 +- 额度消耗趋势 +- 错误率 +- 响应时间 + +### 告警规则 +- 代理站点额度不足(< 10%) +- 请求失败率异常(> 5%) +- 主系统响应超时 +- 数据库连接数过高 + +--- + +## 💰 成本估算 + +### 开发成本 +- 后端开发:6-8 周 +- 前端开发:2-3 周 +- 测试优化:2 周 +- 总计:10-13 周 + +### 运维成本(月) +- 主系统服务器:$50-100 +- 数据库:$30-50 +- Redis:$20-30 +- 带宽:$50-100 +- 总计:$150-280/月 + +--- + +## 📚 参考资料 + +### One-API 原项目 +- GitHub: https://github.com/songquanpeng/one-api +- 文档: README.md + +### 关键代码位置 +- 令牌管理:`model/token.go` +- 计费逻辑:`relay/billing/billing.go` +- 渠道分发:`middleware/distributor.go` +- 认证中间件:`middleware/auth.go` + +--- + +## 🎯 下一步行动 + +1. **确认方案**:与团队评审本方案 +2. **环境准备**:搭建开发测试环境 +3. **数据库设计**:创建数据表和索引 +4. **接口设计**:定义 API 接口规范 +5. **开始编码**:按阶段实施开发 + +--- + +**文档版本:** v1.0 +**创建时间:** 2025-12-29 +**作者:** Claude +**状态:** 待评审 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b726ed4 --- /dev/null +++ b/go.mod @@ -0,0 +1,110 @@ +module github.com/songquanpeng/one-api + +go 1.20 + +require ( + cloud.google.com/go/iam v1.1.10 + github.com/aws/aws-sdk-go-v2 v1.27.0 + github.com/aws/aws-sdk-go-v2/credentials v1.17.15 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 + github.com/gin-contrib/cors v1.7.2 + github.com/gin-contrib/gzip v1.0.1 + github.com/gin-contrib/sessions v1.0.1 + github.com/gin-contrib/static v1.1.2 + github.com/gin-gonic/gin v1.10.0 + github.com/go-playground/validator/v10 v10.20.0 + github.com/go-redis/redis/v8 v8.11.5 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.1 + github.com/jinzhu/copier v0.4.0 + github.com/joho/godotenv v1.5.1 + github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pkg/errors v0.9.1 + github.com/pkoukk/tiktoken-go v0.1.7 + github.com/smartystreets/goconvey v1.8.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.31.0 + golang.org/x/image v0.18.0 + golang.org/x/sync v0.10.0 + google.golang.org/api v0.187.0 + 
gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/driver/sqlite v1.5.1 + gorm.io/gorm v1.25.10 +) + +require ( + cloud.google.com/go/auth v0.6.1 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect + cloud.google.com/go/compute/metadata v0.3.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect + github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/s2a-go v0.1.7 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect + github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/gorilla/context v1.1.2 // indirect + github.com/gorilla/securecookie v1.1.2 // 
indirect + github.com/gorilla/sessions v1.2.2 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.24 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect + go.opentelemetry.io/otel v1.24.0 // indirect + go.opentelemetry.io/otel/metric v1.24.0 // indirect + go.opentelemetry.io/otel/trace v1.24.0 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/net v0.26.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.5.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect + google.golang.org/grpc v1.64.1 // indirect + google.golang.org/protobuf v1.34.2 // 
indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3954136 --- /dev/null +++ b/go.sum @@ -0,0 +1,317 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38= +cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4= +cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= +cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= +cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= +github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo= +github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= 
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/creack/pty v1.1.9/go.mod 
h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= +github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= +github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= +github.com/gin-contrib/gzip 
v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= +github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI= +github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4= +github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 
v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf 
v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= +github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= +github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= +github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= +github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= +github.com/gopherjs/gopherjs v1.17.2 
h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= +github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= +github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 
h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd 
h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= +github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= 
+github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= 
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= +go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= +go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= +go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= +go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= +go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= +go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod 
h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo= +google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto 
v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= +google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 
+google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/sqlite v1.5.1 h1:hYyrLkAWE71bcarJDPdZNTLWtr8XrSjOWyjUYI6xdL4= +gorm.io/driver/sqlite v1.5.1/go.mod h1:7MZZ2Z8bqyfSQA1gYEV6MagQWj3cpUkJj9Z+d1HEMEQ= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools 
v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go new file mode 100644 index 0000000..35c2764 --- /dev/null +++ b/main.go @@ -0,0 +1,124 @@ +package main + +import ( + "embed" + "fmt" + "os" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" + _ "github.com/joho/godotenv/autoload" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/i18n" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/middleware" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/router" +) + +//go:embed web/build/* +var buildFS embed.FS + +func main() { + common.Init() + logger.SetupLogger() + logger.SysLogf("One API %s started", common.Version) + + if os.Getenv("GIN_MODE") != gin.DebugMode { + gin.SetMode(gin.ReleaseMode) + } + if config.DebugEnabled { + logger.SysLog("running in debug mode") + } + + // Initialize SQL Database + model.InitDB() + model.InitLogDB() + + var err error + err = model.CreateRootAccountIfNeed() + if err != nil { + logger.FatalLog("database init error: " + err.Error()) + } + defer func() { + err := model.CloseDB() + if err != nil { + logger.FatalLog("failed to close database: " + err.Error()) + } + }() + + // Initialize Redis + err = common.InitRedisClient() + if err != nil { + logger.FatalLog("failed to initialize Redis: " + err.Error()) + } + + // Initialize options + model.InitOptionMap() + logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) + if common.RedisEnabled 
{ + // for compatibility with old versions + config.MemoryCacheEnabled = true + } + if config.MemoryCacheEnabled { + logger.SysLog("memory cache enabled") + logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) + model.InitChannelCache() + } + if config.MemoryCacheEnabled { + go model.SyncOptions(config.SyncFrequency) + go model.SyncChannelCache(config.SyncFrequency) + } + if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) + if err != nil { + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + } + go controller.AutomaticallyTestChannels(frequency) + } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + config.BatchUpdateEnabled = true + logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } + if config.EnableMetric { + logger.SysLog("metric enabled, will disable channel if too much request failed") + } + openai.InitTokenEncoders() + client.Init() + + // Initialize i18n + if err := i18n.Init(); err != nil { + logger.FatalLog("failed to initialize i18n: " + err.Error()) + } + + // Initialize HTTP server + server := gin.New() + server.Use(gin.Recovery()) + // This will cause SSE not to work!!! 
+ //server.Use(gzip.Gzip(gzip.DefaultCompression)) + server.Use(middleware.RequestId()) + server.Use(middleware.Language()) + middleware.SetUpLogger(server) + // Initialize session store + store := cookie.NewStore([]byte(config.SessionSecret)) + server.Use(sessions.Sessions("session", store)) + + router.SetRouter(server, buildFS) + var port = os.Getenv("PORT") + if port == "" { + port = strconv.Itoa(*common.Port) + } + logger.SysLogf("server started on http://localhost:%s", port) + err = server.Run(":" + port) + if err != nil { + logger.FatalLog("failed to start HTTP server: " + err.Error()) + } +} diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..e001983 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,167 @@ +package middleware + +import ( + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/blacklist" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/network" + "github.com/songquanpeng/one-api/model" + "net/http" + "strings" +) + +func authHelper(c *gin.Context, minRole int) { + session := sessions.Default(c) + username := session.Get("username") + role := session.Get("role") + id := session.Get("id") + status := session.Get("status") + if username == nil { + // Check access token + accessToken := c.Request.Header.Get("Authorization") + if accessToken == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "无权进行此操作,未登录且未提供 access token", + }) + c.Abort() + return + } + user := model.ValidateAccessToken(accessToken) + if user != nil && user.Username != "" { + // Token is valid + username = user.Username + role = user.Role + id = user.Id + status = user.Status + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权进行此操作,access token 无效", + }) + c.Abort() + return + } + } + if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { + 
c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户已被封禁", + }) + session := sessions.Default(c) + session.Clear() + _ = session.Save() + c.Abort() + return + } + if role.(int) < minRole { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权进行此操作,权限不足", + }) + c.Abort() + return + } + c.Set("username", username) + c.Set("role", role) + c.Set("id", id) + c.Next() +} + +func UserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, model.RoleCommonUser) + } +} + +func AdminAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, model.RoleAdminUser) + } +} + +func RootAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, model.RoleRootUser) + } +} + +func TokenAuth() func(c *gin.Context) { + return func(c *gin.Context) { + ctx := c.Request.Context() + key := c.Request.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + token, err := model.ValidateUserToken(key) + if err != nil { + abortWithMessage(c, http.StatusUnauthorized, err.Error()) + return + } + if token.Subnet != nil && *token.Subnet != "" { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + return + } + } + userEnabled, err := model.CacheIsUserEnabled(token.UserId) + if err != nil { + abortWithMessage(c, http.StatusInternalServerError, err.Error()) + return + } + if !userEnabled || blacklist.IsUserBanned(token.UserId) { + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + return + } + requestModel, err := getRequestModel(c) + if err != nil && shouldCheckModel(c) { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set(ctxkey.RequestModel, requestModel) + if token.Models != nil && *token.Models != "" { + c.Set(ctxkey.AvailableModels, *token.Models) + 
if requestModel != "" && !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } + c.Set(ctxkey.Id, token.UserId) + c.Set(ctxkey.TokenId, token.Id) + c.Set(ctxkey.TokenName, token.Name) + if len(parts) > 1 { + if model.IsAdmin(token.UserId) { + c.Set(ctxkey.SpecificChannelId, parts[1]) + } else { + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + return + } + } + + // set channel id for proxy relay + if channelId := c.Param("channelid"); channelId != "" { + c.Set(ctxkey.SpecificChannelId, channelId) + } + + c.Next() + } +} + +func shouldCheckModel(c *gin.Context) bool { + if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images") { + return true + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + return true + } + return false +} diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..979734a --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" +) + +func Cache() func(c *gin.Context) { + return func(c *gin.Context) { + if c.Request.RequestURI == "/" { + c.Header("Cache-Control", "no-cache") + } else { + c.Header("Cache-Control", "max-age=604800") // one week + } + c.Next() + } +} diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..d2a109a --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + +func CORS() gin.HandlerFunc { + config := cors.DefaultConfig() + config.AllowAllOrigins = true + config.AllowCredentials = true + config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} + config.AllowHeaders = []string{"*"} + return cors.New(config) 
+} diff --git a/middleware/distributor.go b/middleware/distributor.go new file mode 100644 index 0000000..58ae055 --- /dev/null +++ b/middleware/distributor.go @@ -0,0 +1,102 @@ +package middleware + +import ( + "fmt" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" +) + +type ModelRequest struct { + Model string `json:"model" form:"model"` +} + +func Distribute() func(c *gin.Context) { + return func(c *gin.Context) { + ctx := c.Request.Context() + userId := c.GetInt(ctxkey.Id) + userGroup, _ := model.CacheGetUserGroup(userId) + c.Set(ctxkey.Group, userGroup) + var requestModel string + var channel *model.Channel + channelId, ok := c.Get(ctxkey.SpecificChannelId) + if ok { + id, err := strconv.Atoi(channelId.(string)) + if err != nil { + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + return + } + channel, err = model.GetChannelById(id, true) + if err != nil { + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + return + } + if channel.Status != model.ChannelStatusEnabled { + abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") + return + } + } else { + requestModel = c.GetString(ctxkey.RequestModel) + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) + if channel != nil { + logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + message = "数据库一致性已被破坏,请联系管理员" + } + abortWithMessage(c, http.StatusServiceUnavailable, message) + return + } + } + logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id) + SetupContextForSelectedChannel(c, channel, requestModel) + c.Next() + } +} + +func SetupContextForSelectedChannel(c *gin.Context, channel 
*model.Channel, modelName string) { + c.Set(ctxkey.Channel, channel.Type) + c.Set(ctxkey.ChannelId, channel.Id) + c.Set(ctxkey.ChannelName, channel.Name) + if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { + c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) + } + c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) + c.Set(ctxkey.OriginalModel, modelName) // for retry + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() + // this is for backward compatibility + if channel.Other != nil { + switch channel.Type { + case channeltype.Azure: + if cfg.APIVersion == "" { + cfg.APIVersion = *channel.Other + } + case channeltype.Xunfei: + if cfg.APIVersion == "" { + cfg.APIVersion = *channel.Other + } + case channeltype.Gemini: + if cfg.APIVersion == "" { + cfg.APIVersion = *channel.Other + } + case channeltype.AIProxyLibrary: + if cfg.LibraryID == "" { + cfg.LibraryID = *channel.Other + } + case channeltype.Ali: + if cfg.Plugin == "" { + cfg.Plugin = *channel.Other + } + } + } + c.Set(ctxkey.Config, cfg) +} diff --git a/middleware/gzip.go b/middleware/gzip.go new file mode 100644 index 0000000..4d4ce0c --- /dev/null +++ b/middleware/gzip.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "compress/gzip" + "github.com/gin-gonic/gin" + "io" + "net/http" +) + +func GzipDecodeMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader("Content-Encoding") == "gzip" { + gzipReader, err := gzip.NewReader(c.Request.Body) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + defer gzipReader.Close() + + // Replace the request body with the decompressed data + c.Request.Body = io.NopCloser(gzipReader) + } + + // Continue processing the request + c.Next() + } +} diff --git a/middleware/language.go b/middleware/language.go new file mode 100644 index 0000000..c8585ba --- /dev/null +++ b/middleware/language.go @@ -0,0 +1,25 
@@ +package middleware + +import ( + "strings" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/i18n" +) + +func Language() gin.HandlerFunc { + return func(c *gin.Context) { + lang := c.GetHeader("Accept-Language") + if lang == "" { + lang = "en" + } + if strings.HasPrefix(strings.ToLower(lang), "zh") { + lang = "zh-CN" + } else { + lang = "en" + } + c.Set(i18n.ContextKey, lang) + c.Next() + } +} diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 0000000..191364f --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" +) + +func SetUpLogger(server *gin.Engine) { + server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { + var requestID string + if param.Keys != nil { + requestID = param.Keys[helper.RequestIdKey].(string) + } + return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + requestID, + param.StatusCode, + param.Latency, + param.ClientIP, + param.Method, + param.Path, + ) + })) +} diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go new file mode 100644 index 0000000..63d7d54 --- /dev/null +++ b/middleware/rate-limit.go @@ -0,0 +1,111 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" +) + +var timeFormat = "2006-01-02T15:04:05.000Z" + +var inMemoryRateLimiter common.InMemoryRateLimiter + +func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { + ctx := context.Background() + rdb := common.RDB + key := "rateLimit:" + mark + c.ClientIP() + listLength, err := rdb.LLen(ctx, key).Result() + if err != nil { + fmt.Println(err.Error()) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + if 
listLength < int64(maxRequestNum) { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) + } else { + oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() + oldTime, err := time.Parse(timeFormat, oldTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + nowTimeStr := time.Now().Format(timeFormat) + nowTime, err := time.Parse(timeFormat, nowTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + // time.Since will return negative number! + // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows + if int64(nowTime.Sub(oldTime).Seconds()) < duration { + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } else { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) + } + } +} + +func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { + key := mark + c.ClientIP() + if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } +} + +func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { + if maxRequestNum == 0 || config.DebugEnabled { + return func(c *gin.Context) { + c.Next() + } + } + if common.RedisEnabled { + return func(c *gin.Context) { + redisRateLimiter(c, maxRequestNum, duration, mark) + } + } else { + // It's safe to call multi times. 
+ inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) + return func(c *gin.Context) { + memoryRateLimiter(c, maxRequestNum, duration, mark) + } + } +} + +func GlobalWebRateLimit() func(c *gin.Context) { + return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") +} + +func GlobalAPIRateLimit() func(c *gin.Context) { + return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") +} + +func CriticalRateLimit() func(c *gin.Context) { + return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") +} + +func DownloadRateLimit() func(c *gin.Context) { + return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") +} + +func UploadRateLimit() func(c *gin.Context) { + return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") +} diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 0000000..cfc3f82 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/logger" + "net/http" + "runtime/debug" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + ctx := c.Request.Context() + logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) + logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + body, _ := common.GetRequestBody(c) + logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. 
Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), + "type": "one_api_panic", + }, + }) + c.Abort() + } + }() + c.Next() + } +} diff --git a/middleware/request-id.go b/middleware/request-id.go new file mode 100644 index 0000000..973a63f --- /dev/null +++ b/middleware/request-id.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/helper" +) + +func RequestId() func(c *gin.Context) { + return func(c *gin.Context) { + id := helper.GenRequestID() + c.Set(helper.RequestIdKey, id) + ctx := helper.SetRequestID(c.Request.Context(), id) + c.Request = c.Request.WithContext(ctx) + c.Header(helper.RequestIdKey, id) + c.Next() + } +} diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go new file mode 100644 index 0000000..403bcb3 --- /dev/null +++ b/middleware/turnstile-check.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "encoding/json" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "net/http" + "net/url" +) + +type turnstileCheckResponse struct { + Success bool `json:"success"` +} + +func TurnstileCheck() gin.HandlerFunc { + return func(c *gin.Context) { + if config.TurnstileCheckEnabled { + session := sessions.Default(c) + turnstileChecked := session.Get("turnstile") + if turnstileChecked != nil { + c.Next() + return + } + response := c.Query("turnstile") + if response == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Turnstile token 为空", + }) + c.Abort() + return + } + rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ + "secret": {config.TurnstileSecretKey}, + "response": {response}, + "remoteip": {c.ClientIP()}, + }) + if err != nil { + logger.SysError(err.Error()) + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": 
err.Error(), + }) + c.Abort() + return + } + defer rawRes.Body.Close() + var res turnstileCheckResponse + err = json.NewDecoder(rawRes.Body).Decode(&res) + if err != nil { + logger.SysError(err.Error()) + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + c.Abort() + return + } + if !res.Success { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Turnstile 校验失败,请刷新重试!", + }) + c.Abort() + return + } + session.Set("turnstile", true) + err = session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + } + c.Next() + } +} diff --git a/middleware/utils.go b/middleware/utils.go new file mode 100644 index 0000000..4d2f809 --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,60 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "strings" +) + +func abortWithMessage(c *gin.Context, statusCode int, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() + logger.Error(c.Request.Context(), message) +} + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + 
} + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/ability.go b/model/ability.go new file mode 100644 index 0000000..5cfb994 --- /dev/null +++ b/model/ability.go @@ -0,0 +1,112 @@ +package model + +import ( + "context" + "sort" + "strings" + + "gorm.io/gorm" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/utils" +) + +type Ability struct { + Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` + Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` + ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` + Enabled bool `json:"enabled"` + Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` +} + +func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { + ability := Ability{} + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + + var err error = nil + var channelQuery *gorm.DB + if ignoreFirstPriority { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + } else { + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + channelQuery = DB.Where(groupCol+" = ? and model = ? 
and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + } + if common.UsingSQLite || common.UsingPostgreSQL { + err = channelQuery.Order("RANDOM()").First(&ability).Error + } else { + err = channelQuery.Order("RAND()").First(&ability).Error + } + if err != nil { + return nil, err + } + channel := Channel{} + channel.Id = ability.ChannelId + err = DB.First(&channel, "id = ?", ability.ChannelId).Error + return &channel, err +} + +func (channel *Channel) AddAbilities() error { + models_ := strings.Split(channel.Models, ",") + models_ = utils.DeDuplication(models_) + groups_ := strings.Split(channel.Group, ",") + abilities := make([]Ability, 0, len(models_)) + for _, model := range models_ { + for _, group := range groups_ { + ability := Ability{ + Group: group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == ChannelStatusEnabled, + Priority: channel.Priority, + } + abilities = append(abilities, ability) + } + } + return DB.Create(&abilities).Error +} + +func (channel *Channel) DeleteAbilities() error { + return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error +} + +// UpdateAbilities updates abilities of this channel. +// Make sure the channel is completed before calling this function. 
+func (channel *Channel) UpdateAbilities() error { + // A quick and dirty way to update abilities + // First delete all abilities of this channel + err := channel.DeleteAbilities() + if err != nil { + return err + } + // Then add new abilities + err = channel.AddAbilities() + if err != nil { + return err + } + return nil +} + +func UpdateAbilityStatus(channelId int, status bool) error { + return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error +} + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/cache.go b/model/cache.go new file mode 100644 index 0000000..cfb0f8a --- /dev/null +++ b/model/cache.go @@ -0,0 +1,255 @@ +package model + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "math/rand" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +var ( + TokenCacheSeconds = config.SyncFrequency + UserId2GroupCacheSeconds = config.SyncFrequency + UserId2QuotaCacheSeconds = config.SyncFrequency + UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency +) + +func CacheGetTokenByKey(key string) (*Token, error) { + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } + var token Token + if !common.RedisEnabled { + err := DB.Where(keyCol+" = ?", key).First(&token).Error + return &token, err + } + tokenObjectString, err := 
common.RedisGet(fmt.Sprintf("token:%s", key)) + if err != nil { + err := DB.Where(keyCol+" = ?", key).First(&token).Error + if err != nil { + return nil, err + } + jsonBytes, err := json.Marshal(token) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set token error: " + err.Error()) + } + return &token, nil + } + err = json.Unmarshal([]byte(tokenObjectString), &token) + return &token, err +} + +func CacheGetUserGroup(id int) (group string, err error) { + if !common.RedisEnabled { + return GetUserGroup(id) + } + group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) + if err != nil { + group, err = GetUserGroup(id) + if err != nil { + return "", err + } + err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set user group error: " + err.Error()) + } + } + return group, err +} + +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + logger.Error(ctx, "Redis set user quota error: "+err.Error()) + } + return +} + +func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { + if !common.RedisEnabled { + return GetUserQuota(id) + } + quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) + if err != nil { + return fetchAndUpdateUserQuota(ctx, id) + } + quota, err = strconv.ParseInt(quotaString, 10, 64) + if err != nil { + return 0, nil + } + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + logger.Infof(ctx, "user %d's cached quota is too low: %d, 
refreshing from db", quota, id) + return fetchAndUpdateUserQuota(ctx, id) + } + return quota, nil +} + +func CacheUpdateUserQuota(ctx context.Context, id int) error { + if !common.RedisEnabled { + return nil + } + quota, err := CacheGetUserQuota(ctx, id) + if err != nil { + return err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + return err +} + +func CacheDecreaseUserQuota(id int, quota int64) error { + if !common.RedisEnabled { + return nil + } + err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) + return err +} + +func CacheIsUserEnabled(userId int) (bool, error) { + if !common.RedisEnabled { + return IsUserEnabled(userId) + } + enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) + if err == nil { + return enabled == "1", nil + } + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err +} + +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + +var group2model2channels map[string]map[string][]*Channel +var channelSyncLock sync.RWMutex + 
+func InitChannelCache() { + newChannelId2channel := make(map[int]*Channel) + var channels []*Channel + DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) + for _, channel := range channels { + newChannelId2channel[channel.Id] = channel + } + var abilities []*Ability + DB.Find(&abilities) + groups := make(map[string]bool) + for _, ability := range abilities { + groups[ability.Group] = true + } + newGroup2model2channels := make(map[string]map[string][]*Channel) + for group := range groups { + newGroup2model2channels[group] = make(map[string][]*Channel) + } + for _, channel := range channels { + groups := strings.Split(channel.Group, ",") + for _, group := range groups { + models := strings.Split(channel.Models, ",") + for _, model := range models { + if _, ok := newGroup2model2channels[group][model]; !ok { + newGroup2model2channels[group][model] = make([]*Channel, 0) + } + newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) + } + } + } + + // sort by priority + for group, model2channels := range newGroup2model2channels { + for model, channels := range model2channels { + sort.Slice(channels, func(i, j int) bool { + return channels[i].GetPriority() > channels[j].GetPriority() + }) + newGroup2model2channels[group][model] = channels + } + } + + channelSyncLock.Lock() + group2model2channels = newGroup2model2channels + channelSyncLock.Unlock() + logger.SysLog("channels synced from database") +} + +func SyncChannelCache(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + logger.SysLog("syncing channels from database") + InitChannelCache() + } +} + +func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { + if !config.MemoryCacheEnabled { + return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority) + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + channels := group2model2channels[group][model] + if len(channels) == 0 
{ + return nil, errors.New("channel not found") + } + endIdx := len(channels) + // choose by priority + firstChannel := channels[0] + if firstChannel.GetPriority() > 0 { + for i := range channels { + if channels[i].GetPriority() != firstChannel.GetPriority() { + endIdx = i + break + } + } + } + idx := rand.Intn(endIdx) + if ignoreFirstPriority { + if endIdx < len(channels) { // which means there are more than one priority + idx = random.RandRange(endIdx, len(channels)) + } + } + return channels[idx], nil +} diff --git a/model/channel.go b/model/channel.go new file mode 100644 index 0000000..4b0f4b0 --- /dev/null +++ b/model/channel.go @@ -0,0 +1,224 @@ +package model + +import ( + "encoding/json" + "fmt" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "gorm.io/gorm" +) + +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 +) + +type Channel struct { + Id int `json:"id"` + Type int `json:"type" gorm:"default:0"` + Key string `json:"key" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Weight *uint `json:"weight" gorm:"default:0"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + TestTime int64 `json:"test_time" gorm:"bigint"` + ResponseTime int `json:"response_time"` // in milliseconds + BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` + Other *string `json:"other"` // DEPRECATED: please save config to field Config + Balance float64 `json:"balance"` // in USD + BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` + ModelMapping *string `json:"model_mapping" 
gorm:"type:varchar(1024);default:''"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` + Config string `json:"config"` + SystemPrompt *string `json:"system_prompt" gorm:"type:text"` +} + +type ChannelConfig struct { + Region string `json:"region,omitempty"` + SK string `json:"sk,omitempty"` + AK string `json:"ak,omitempty"` + UserID string `json:"user_id,omitempty"` + APIVersion string `json:"api_version,omitempty"` + LibraryID string `json:"library_id,omitempty"` + Plugin string `json:"plugin,omitempty"` + VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"` + VertexAIADC string `json:"vertex_ai_adc,omitempty"` +} + +func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { + var channels []*Channel + var err error + switch scope { + case "all": + err = DB.Order("id desc").Find(&channels).Error + case "disabled": + err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error + default: + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error + } + return channels, err +} + +func SearchChannels(keyword string) (channels []*Channel, err error) { + err = DB.Omit("key").Where("id = ? 
or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error + return channels, err +} + +func GetChannelById(id int, selectAll bool) (*Channel, error) { + channel := Channel{Id: id} + var err error = nil + if selectAll { + err = DB.First(&channel, "id = ?", id).Error + } else { + err = DB.Omit("key").First(&channel, "id = ?", id).Error + } + return &channel, err +} + +func BatchInsertChannels(channels []Channel) error { + var err error + err = DB.Create(&channels).Error + if err != nil { + return err + } + for _, channel_ := range channels { + err = channel_.AddAbilities() + if err != nil { + return err + } + } + return nil +} + +func (channel *Channel) GetPriority() int64 { + if channel.Priority == nil { + return 0 + } + return *channel.Priority +} + +func (channel *Channel) GetBaseURL() string { + if channel.BaseURL == nil { + return "" + } + return *channel.BaseURL +} + +func (channel *Channel) GetModelMapping() map[string]string { + if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { + return nil + } + modelMapping := make(map[string]string) + err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) + if err != nil { + logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) + return nil + } + return modelMapping +} + +func (channel *Channel) Insert() error { + var err error + err = DB.Create(channel).Error + if err != nil { + return err + } + err = channel.AddAbilities() + return err +} + +func (channel *Channel) Update() error { + var err error + err = DB.Model(channel).Updates(channel).Error + if err != nil { + return err + } + DB.Model(channel).First(channel, "id = ?", channel.Id) + err = channel.UpdateAbilities() + return err +} + +func (channel *Channel) UpdateResponseTime(responseTime int64) { + err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ + TestTime: helper.GetTimestamp(), + 
ResponseTime: int(responseTime), + }).Error + if err != nil { + logger.SysError("failed to update response time: " + err.Error()) + } +} + +func (channel *Channel) UpdateBalance(balance float64) { + err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ + BalanceUpdatedTime: helper.GetTimestamp(), + Balance: balance, + }).Error + if err != nil { + logger.SysError("failed to update balance: " + err.Error()) + } +} + +func (channel *Channel) Delete() error { + var err error + err = DB.Delete(channel).Error + if err != nil { + return err + } + err = channel.DeleteAbilities() + return err +} + +func (channel *Channel) LoadConfig() (ChannelConfig, error) { + var cfg ChannelConfig + if channel.Config == "" { + return cfg, nil + } + err := json.Unmarshal([]byte(channel.Config), &cfg) + if err != nil { + return cfg, err + } + return cfg, nil +} + +func UpdateChannelStatusById(id int, status int) { + err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) + if err != nil { + logger.SysError("failed to update ability status: " + err.Error()) + } + err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error + if err != nil { + logger.SysError("failed to update channel status: " + err.Error()) + } +} + +func UpdateChannelUsedQuota(id int, quota int64) { + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int64) { + err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error + if err != nil { + logger.SysError("failed to update channel used quota: " + err.Error()) + } +} + +func DeleteChannelByStatus(status int64) (int64, error) { + result := DB.Where("status = ?", status).Delete(&Channel{}) + return result.RowsAffected, result.Error +} + +func DeleteDisabledChannel() (int64, error) { + result := DB.Where("status = ? 
or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) + return result.RowsAffected, result.Error +} diff --git a/model/log.go b/model/log.go new file mode 100644 index 0000000..2c92065 --- /dev/null +++ b/model/log.go @@ -0,0 +1,251 @@ +package model + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" +) + +type Log struct { + Id int `json:"id"` + UserId int `json:"user_id" gorm:"index"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` + Type int `json:"type" gorm:"index:idx_created_at_type"` + Content string `json:"content"` + Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` + TokenName string `json:"token_name" gorm:"index;default:''"` + ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` + Quota int `json:"quota" gorm:"default:0"` + PromptTokens int `json:"prompt_tokens" gorm:"default:0"` + CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + ChannelId int `json:"channel" gorm:"index"` + RequestId string `json:"request_id" gorm:"default:''"` + ElapsedTime int64 `json:"elapsed_time" gorm:"default:0"` // unit is ms + IsStream bool `json:"is_stream" gorm:"default:false"` + SystemPromptReset bool `json:"system_prompt_reset" gorm:"default:false"` +} + +const ( + LogTypeUnknown = iota + LogTypeTopup + LogTypeConsume + LogTypeManage + LogTypeSystem + LogTypeTest +) + +func recordLogHelper(ctx context.Context, log *Log) { + requestId := helper.GetRequestID(ctx) + log.RequestId = requestId + err := LOG_DB.Create(log).Error + if err != nil { + logger.Error(ctx, "failed to record log: "+err.Error()) + return + } + logger.Infof(ctx, "record log: %+v", log) +} + +func RecordLog(ctx context.Context, userId 
int, logType int, content string) { + if logType == LogTypeConsume && !config.LogConsumeEnabled { + return + } + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: logType, + Content: content, + } + recordLogHelper(ctx, log) +} + +func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + recordLogHelper(ctx, log) +} + +func RecordConsumeLog(ctx context.Context, log *Log) { + if !config.LogConsumeEnabled { + return + } + log.Username = GetUsernameById(log.UserId) + log.CreatedAt = helper.GetTimestamp() + log.Type = LogTypeConsume + recordLogHelper(ctx, log) +} + +func RecordTestLog(ctx context.Context, log *Log) { + log.CreatedAt = helper.GetTimestamp() + log.Type = LogTypeTest + recordLogHelper(ctx, log) +} + +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { + var tx *gorm.DB + if logType == LogTypeUnknown { + tx = LOG_DB + } else { + tx = LOG_DB.Where("type = ?", logType) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + if username != "" { + tx = tx.Where("username = ?", username) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + if channel != 0 { + tx = tx.Where("channel_id = ?", channel) + } + err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error + return logs, err +} + +func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { + var tx 
*gorm.DB
+	if logType == LogTypeUnknown {
+		tx = LOG_DB.Where("user_id = ?", userId)
+	} else {
+		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
+	}
+	if modelName != "" {
+		tx = tx.Where("model_name = ?", modelName)
+	}
+	if tokenName != "" {
+		tx = tx.Where("token_name = ?", tokenName)
+	}
+	if startTimestamp != 0 {
+		tx = tx.Where("created_at >= ?", startTimestamp)
+	}
+	if endTimestamp != 0 {
+		tx = tx.Where("created_at <= ?", endTimestamp)
+	}
+	err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
+	return logs, err
+}
+
+func SearchAllLogs(keyword string) (logs []*Log, err error) {
+	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
+	return logs, err
+}
+
+func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
+	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
+	return logs, err
+}
+
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
+	ifnull := "ifnull"
+	if common.UsingPostgreSQL {
+		ifnull = "COALESCE"
+	}
+	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
+	if username != "" {
+		tx = tx.Where("username = ?", username)
+	}
+	if tokenName != "" {
+		tx = tx.Where("token_name = ?", tokenName)
+	}
+	if startTimestamp != 0 {
+		tx = tx.Where("created_at >= ?", startTimestamp)
+	}
+	if endTimestamp != 0 {
+		tx = tx.Where("created_at <= ?", endTimestamp)
+	}
+	if modelName != "" {
+		tx = tx.Where("model_name = ?", modelName)
+	}
+	if channel != 0 {
+		tx = tx.Where("channel_id = ?", channel)
+	}
+	tx.Where("type = ?", LogTypeConsume).Scan(&quota)
+	return quota
+}
+
+func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) 
{ + ifnull := "ifnull" + if common.UsingPostgreSQL { + ifnull = "COALESCE" + } + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) + if username != "" { + tx = tx.Where("username = ?", username) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + tx.Where("type = ?", LogTypeConsume).Scan(&token) + return token +} + +func DeleteOldLog(targetTimestamp int64) (int64, error) { + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + return result.RowsAffected, result.Error +} + +type LogStatistic struct { + Day string `gorm:"column:day"` + ModelName string `gorm:"column:model_name"` + RequestCount int `gorm:"column:request_count"` + Quota int `gorm:"column:quota"` + PromptTokens int `gorm:"column:prompt_tokens"` + CompletionTokens int `gorm:"column:completion_tokens"` +} + +func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) { + groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" + + if common.UsingPostgreSQL { + groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day" + } + + if common.UsingSQLite { + groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" + } + + err = LOG_DB.Raw(` + SELECT `+groupSelect+`, + model_name, count(1) as request_count, + sum(quota) as quota, + sum(prompt_tokens) as prompt_tokens, + sum(completion_tokens) as completion_tokens + FROM logs + WHERE type=2 + AND user_id= ? + AND created_at BETWEEN ? AND ? 
+ GROUP BY day, model_name + ORDER BY day, model_name + `, userId, start, end).Scan(&LogStatistics).Error + + return LogStatistics, err +} diff --git a/model/main.go b/model/main.go new file mode 100644 index 0000000..72e271a --- /dev/null +++ b/model/main.go @@ -0,0 +1,237 @@ +package model + +import ( + "database/sql" + "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "os" + "strings" + "time" +) + +var DB *gorm.DB +var LOG_DB *gorm.DB + +func CreateRootAccountIfNeed() error { + var user User + //if user.Status != util.UserStatusEnabled { + if err := DB.First(&user).Error; err != nil { + logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") + hashedPassword, err := common.Password2Hash("123456") + if err != nil { + return err + } + accessToken := random.GetUUID() + if config.InitialRootAccessToken != "" { + accessToken = config.InitialRootAccessToken + } + rootUser := User{ + Username: "root", + Password: hashedPassword, + Role: RoleRootUser, + Status: UserStatusEnabled, + DisplayName: "Root User", + AccessToken: accessToken, + Quota: 500000000000000, + } + DB.Create(&rootUser) + if config.InitialRootToken != "" { + logger.SysLog("creating initial root token as requested") + token := Token{ + Id: 1, + UserId: rootUser.Id, + Key: config.InitialRootToken, + Status: TokenStatusEnabled, + Name: "Initial Root Token", + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: -1, + RemainQuota: 500000000000000, + UnlimitedQuota: true, + } + DB.Create(&token) + } + } + return nil +} + +func chooseDB(envName string) (*gorm.DB, error) { + dsn := 
os.Getenv(envName) + + switch { + case strings.HasPrefix(dsn, "postgres://"): + // Use PostgreSQL + return openPostgreSQL(dsn) + case dsn != "": + // Use MySQL + return openMySQL(dsn) + default: + // Use SQLite + return openSQLite() + } +} + +func openPostgreSQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true + return gorm.Open(postgres.New(postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, // disables implicit prepared statement usage + }), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func openMySQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using MySQL as database") + common.UsingMySQL = true + return gorm.Open(mysql.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func openSQLite() (*gorm.DB, error) { + logger.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func InitDB() { + var err error + DB, err = chooseDB("SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize database: " + err.Error()) + return + } + + sqlDB := setDBConns(DB) + + if !config.IsMasterNode { + return + } + + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + } + + logger.SysLog("database migration started") + if err = migrateDB(); err != nil { + logger.FatalLog("failed to migrate database: " + err.Error()) + return + } + logger.SysLog("database migrated") +} + +func migrateDB() error { + var err error + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Token{}); err != nil { + return err + } + if err = DB.AutoMigrate(&User{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Option{}); err != nil { 
+ return err + } + if err = DB.AutoMigrate(&Redemption{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Ability{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Log{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + return nil +} + +func InitLogDB() { + if os.Getenv("LOG_SQL_DSN") == "" { + LOG_DB = DB + return + } + + logger.SysLog("using secondary database for table logs") + var err error + LOG_DB, err = chooseDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + return + } + + setDBConns(LOG_DB) + + if !config.IsMasterNode { + return + } + + logger.SysLog("secondary database migration started") + err = migrateLOGDB() + if err != nil { + logger.FatalLog("failed to migrate secondary database: " + err.Error()) + return + } + logger.SysLog("secondary database migrated") +} + +func migrateLOGDB() error { + var err error + if err = LOG_DB.AutoMigrate(&Log{}); err != nil { + return err + } + return nil +} + +func setDBConns(db *gorm.DB) *sql.DB { + if config.DebugSQLEnabled { + db = db.Debug() + } + + sqlDB, err := db.DB() + if err != nil { + logger.FatalLog("failed to connect database: " + err.Error()) + return nil + } + + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) + return sqlDB +} + +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + err = sqlDB.Close() + return err +} + +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} diff --git a/model/option.go b/model/option.go new file mode 100644 index 0000000..8fd30ae --- /dev/null +++ b/model/option.go @@ -0,0 +1,244 @@ +package model + +import ( + "github.com/songquanpeng/one-api/common/config" + 
"github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + config.OptionMapRWMutex.Lock() + config.OptionMap = make(map[string]string) + config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) + config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) + config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) + config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) + config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) + config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) + config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) + config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) + config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) + config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) + config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) + config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) + config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) + config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) + config.OptionMap["EmailDomainRestrictionEnabled"] = 
strconv.FormatBool(config.EmailDomainRestrictionEnabled) + config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") + config.OptionMap["SMTPServer"] = "" + config.OptionMap["SMTPFrom"] = "" + config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) + config.OptionMap["SMTPAccount"] = "" + config.OptionMap["SMTPToken"] = "" + config.OptionMap["Notice"] = "" + config.OptionMap["About"] = "" + config.OptionMap["HomePageContent"] = "" + config.OptionMap["Footer"] = config.Footer + config.OptionMap["SystemName"] = config.SystemName + config.OptionMap["Logo"] = config.Logo + config.OptionMap["ServerAddress"] = "" + config.OptionMap["GitHubClientId"] = "" + config.OptionMap["GitHubClientSecret"] = "" + config.OptionMap["WeChatServerAddress"] = "" + config.OptionMap["WeChatServerToken"] = "" + config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["MessagePusherAddress"] = "" + config.OptionMap["MessagePusherToken"] = "" + config.OptionMap["TurnstileSiteKey"] = "" + config.OptionMap["TurnstileSecretKey"] = "" + config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) + config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) + config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) + config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) + config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) + config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() + config.OptionMap["TopUpLink"] = config.TopUpLink + config.OptionMap["ChatLink"] = config.ChatLink + config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) + config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) + 
config.OptionMap["Theme"] = config.Theme + config.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + if option.Key == "ModelRatio" { + option.Value = billingratio.AddNewMissingRatio(option.Value) + } + err := updateOptionMap(option.Key, option.Value) + if err != nil { + logger.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + logger.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). 
+ DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + config.OptionMapRWMutex.Lock() + defer config.OptionMapRWMutex.Unlock() + config.OptionMap[key] = value + if strings.HasSuffix(key, "Enabled") { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + config.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + config.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + config.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue + case "WeChatAuthEnabled": + config.WeChatAuthEnabled = boolValue + case "TurnstileCheckEnabled": + config.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + config.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + config.EmailDomainRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + config.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + config.AutomaticEnableChannelEnabled = boolValue + case "ApproximateTokenEnabled": + config.ApproximateTokenEnabled = boolValue + case "LogConsumeEnabled": + config.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + config.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + config.DisplayTokenStatEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + config.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + config.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + config.SMTPPort = intValue + case "SMTPAccount": + config.SMTPAccount = value + case "SMTPFrom": + config.SMTPFrom = value + case "SMTPToken": + config.SMTPToken = value + case "ServerAddress": + config.ServerAddress = value + case "GitHubClientId": + config.GitHubClientId = value + case 
"GitHubClientSecret": + config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value + case "Footer": + config.Footer = value + case "SystemName": + config.SystemName = value + case "Logo": + config.Logo = value + case "WeChatServerAddress": + config.WeChatServerAddress = value + case "WeChatServerToken": + config.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + config.WeChatAccountQRCodeImageURL = value + case "MessagePusherAddress": + config.MessagePusherAddress = value + case "MessagePusherToken": + config.MessagePusherToken = value + case "TurnstileSiteKey": + config.TurnstileSiteKey = value + case "TurnstileSecretKey": + config.TurnstileSecretKey = value + case "QuotaForNewUser": + config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) + case "QuotaForInviter": + config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) + case "QuotaForInvitee": + config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) + case "QuotaRemindThreshold": + config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) + case "PreConsumedQuota": + config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) + case "RetryTimes": + config.RetryTimes, _ = strconv.Atoi(value) + case "ModelRatio": + err = billingratio.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = billingratio.UpdateGroupRatioByJSONString(value) + case "CompletionRatio": + err = billingratio.UpdateCompletionRatioByJSONString(value) + case "TopUpLink": + config.TopUpLink = value + case "ChatLink": + config.ChatLink = 
value + case "ChannelDisableThreshold": + config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "Theme": + config.Theme = value + } + return err +} diff --git a/model/redemption.go b/model/redemption.go new file mode 100644 index 0000000..957a33b --- /dev/null +++ b/model/redemption.go @@ -0,0 +1,126 @@ +package model + +import ( + "context" + "errors" + "fmt" + + "gorm.io/gorm" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" +) + +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! + RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + +type Redemption struct { + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(32);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Quota int64 `json:"quota" gorm:"bigint;default:100"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` + Count int `json:"count" gorm:"-:all"` // only for api request +} + +func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { + var redemptions []*Redemption + var err error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error + return redemptions, err +} + +func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { + err = DB.Where("id = ? 
or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error + return redemptions, err +} + +func GetRedemptionById(id int) (*Redemption, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + redemption := Redemption{Id: id} + var err error = nil + err = DB.First(&redemption, "id = ?", id).Error + return &redemption, err +} + +func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { + if key == "" { + return 0, errors.New("未提供兑换码") + } + if userId == 0 { + return 0, errors.New("无效的 user id") + } + redemption := &Redemption{} + + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } + + err = DB.Transaction(func(tx *gorm.DB) error { + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error + if err != nil { + return errors.New("无效的兑换码") + } + if redemption.Status != RedemptionCodeStatusEnabled { + return errors.New("该兑换码已被使用") + } + err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error + if err != nil { + return err + } + redemption.RedeemedTime = helper.GetTimestamp() + redemption.Status = RedemptionCodeStatusUsed + err = tx.Save(redemption).Error + return err + }) + if err != nil { + return 0, errors.New("兑换失败," + err.Error()) + } + RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) + return redemption.Quota, nil +} + +func (redemption *Redemption) Insert() error { + var err error + err = DB.Create(redemption).Error + return err +} + +func (redemption *Redemption) SelectUpdate() error { + // This can update zero values + return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error +} + +// Update Make sure your token's fields is completed, because this will update non-zero values +func (redemption *Redemption) Update() error { + var err error + err = DB.Model(redemption).Select("name", "status", "quota", 
"redeemed_time").Updates(redemption).Error + return err +} + +func (redemption *Redemption) Delete() error { + var err error + err = DB.Delete(redemption).Error + return err +} + +func DeleteRedemptionById(id int) (err error) { + if id == 0 { + return errors.New("id 为空!") + } + redemption := Redemption{Id: id} + err = DB.Where(redemption).First(&redemption).Error + if err != nil { + return err + } + return redemption.Delete() +} diff --git a/model/token.go b/model/token.go new file mode 100644 index 0000000..52ee63e --- /dev/null +++ b/model/token.go @@ -0,0 +1,303 @@ +package model + +import ( + "errors" + "fmt" + + "gorm.io/gorm" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" +) + +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + +type Token struct { + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"type:text"` // allowed models + Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet +} + +func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { + var tokens []*Token + var err error + query := 
DB.Where("user_id = ?", userId) + + switch order { + case "remain_quota": + query = query.Order("unlimited_quota desc, remain_quota desc") + case "used_quota": + query = query.Order("used_quota desc") + default: + query = query.Order("id desc") + } + + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error + return tokens, err +} + +func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) { + err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error + return tokens, err +} + +func ValidateUserToken(key string) (token *Token, err error) { + if key == "" { + return nil, errors.New("未提供令牌") + } + token, err = CacheGetTokenByKey(key) + if err != nil { + logger.SysError("CacheGetTokenByKey failed: " + err.Error()) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("无效的令牌") + } + return nil, errors.New("令牌验证失败") + } + if token.Status == TokenStatusExhausted { + return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) + } else if token.Status == TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } + if token.Status != TokenStatusEnabled { + return nil, errors.New("该令牌状态不可用") + } + if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { + if !common.RedisEnabled { + token.Status = TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + logger.SysError("failed to update token status" + err.Error()) + } + } + return nil, errors.New("该令牌已过期") + } + if !token.UnlimitedQuota && token.RemainQuota <= 0 { + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + logger.SysError("failed to update token status" + err.Error()) + } + } + return nil, errors.New("该令牌额度已用尽") + } + return token, nil +} + +func GetTokenByIds(id int, userId int) (*Token, error) { + if id == 0 || userId == 0 { + return nil, errors.New("id 或 userId 为空!") + } + token := 
Token{Id: id, UserId: userId} + var err error = nil + err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error + return &token, err +} + +func GetTokenById(id int) (*Token, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + token := Token{Id: id} + var err error = nil + err = DB.First(&token, "id = ?", id).Error + return &token, err +} + +func (t *Token) Insert() error { + var err error + err = DB.Create(t).Error + return err +} + +// Update Make sure your token's fields is completed, because this will update non-zero values +func (t *Token) Update() error { + var err error + err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error + return err +} + +func (t *Token) SelectUpdate() error { + // This can update zero values + return DB.Model(t).Select("accessed_time", "status").Updates(t).Error +} + +func (t *Token) Delete() error { + var err error + err = DB.Delete(t).Error + return err +} + +func (t *Token) GetModels() string { + if t == nil { + return "" + } + if t.Models == nil { + return "" + } + return *t.Models +} + +func DeleteTokenById(id int, userId int) (err error) { + // Why we need userId here? In case user want to delete other's token. 
+ if id == 0 || userId == 0 { + return errors.New("id 或 userId 为空!") + } + token := Token{Id: id, UserId: userId} + err = DB.Where(token).First(&token).Error + if err != nil { + return err + } + return token.Delete() +} + +func IncreaseTokenQuota(id int, quota int64) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int64) (err error) { + err = DB.Model(&Token{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": helper.GetTimestamp(), + }, + ).Error + return err +} + +func DecreaseTokenQuota(id int, quota int64) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int64) (err error) { + err = DB.Model(&Token{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": helper.GetTimestamp(), + }, + ).Error + return err +} + +func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + token, err := GetTokenById(tokenId) + if err != nil { + return err + } + if !token.UnlimitedQuota && token.RemainQuota < quota { + return errors.New("令牌额度不足") + } + userQuota, err := GetUserQuota(token.UserId) + if err != nil { + return err + } + if userQuota < quota { + return errors.New("用户额度不足") + } + quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold + noMoreQuota := userQuota-quota <= 0 + if quotaTooLow || 
noMoreQuota { + go func() { + email, err := GetUserEmail(token.UserId) + if err != nil { + logger.SysError("failed to fetch user email: " + err.Error()) + } + prompt := "额度提醒" + var contentText string + if noMoreQuota { + contentText = "您的额度已用尽" + } else { + contentText = "您的额度即将用尽" + } + if email != "" { + topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) + content := message.EmailTemplate( + prompt, + fmt.Sprintf(` +

您好!

+

%s,当前剩余额度为 %d

+

为了不影响您的使用,请及时充值。

+

+ 立即充值 +

+

如果按钮无法点击,请复制以下链接到浏览器中打开:

+

%s

+ `, contentText, userQuota, topUpLink, topUpLink), + ) + err = message.SendEmail(prompt, email, content) + if err != nil { + logger.SysError("failed to send email: " + err.Error()) + } + } + }() + } + if !token.UnlimitedQuota { + err = DecreaseTokenQuota(tokenId, quota) + if err != nil { + return err + } + } + err = DecreaseUserQuota(token.UserId, quota) + return err +} + +func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { + token, err := GetTokenById(tokenId) + if err != nil { + return err + } + if quota > 0 { + err = DecreaseUserQuota(token.UserId, quota) + } else { + err = IncreaseUserQuota(token.UserId, -quota) + } + if !token.UnlimitedQuota { + if quota > 0 { + err = DecreaseTokenQuota(tokenId, quota) + } else { + err = IncreaseTokenQuota(tokenId, -quota) + } + if err != nil { + return err + } + } + return nil +} diff --git a/model/user.go b/model/user.go new file mode 100644 index 0000000..021810c --- /dev/null +++ b/model/user.go @@ -0,0 +1,453 @@ +package model + +import ( + "context" + "errors" + "fmt" + "strings" + + "gorm.io/gorm" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" +) + +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 +) + +// User if you add sensitive fields, don't forget to clean them in setupLogin function. +// Otherwise, the sensitive information will be saved on local storage in plain text! 
+type User struct { + Id int `json:"id"` + Username string `json:"username" gorm:"unique;index" validate:"max=12"` + Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` + Role int `json:"role" gorm:"type:int;default:1"` // admin, util + Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled + Email string `json:"email" gorm:"index" validate:"max=50"` + GitHubId string `json:"github_id" gorm:"column:github_id;index"` + WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` + VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! + AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management + Quota int64 `json:"quota" gorm:"bigint;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota + RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` + InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` +} + +func GetMaxUserId() int { + var user User + DB.Last(&user) + return user.Id +} + +func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = 
query.Find(&users).Error + return users, err +} + +func SearchUsers(keyword string) (users []*User, err error) { + if !common.UsingPostgreSQL { + err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } else { + err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } + return users, err +} + +func GetUserById(id int, selectAll bool) (*User, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + user := User{Id: id} + var err error = nil + if selectAll { + err = DB.First(&user, "id = ?", id).Error + } else { + err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error + } + return &user, err +} + +func GetUserIdByAffCode(affCode string) (int, error) { + if affCode == "" { + return 0, errors.New("affCode 为空!") + } + var user User + err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error + return user.Id, err +} + +func DeleteUserById(id int) (err error) { + if id == 0 { + return errors.New("id 为空!") + } + user := User{Id: id} + return user.Delete() +} + +func (user *User) Insert(ctx context.Context, inviterId int) error { + var err error + if user.Password != "" { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + user.Quota = config.QuotaForNewUser + user.AccessToken = random.GetUUID() + user.AffCode = random.GetRandomString(4) + result := DB.Create(user) + if result.Error != nil { + return result.Error + } + if config.QuotaForNewUser > 0 { + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) + } + if inviterId != 0 { + if config.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) + } + 
if config.QuotaForInviter > 0 { + _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) + RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) + } + } + // create default token + cleanToken := Token{ + UserId: user.Id, + Name: "default", + Key: random.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: -1, + RemainQuota: -1, + UnlimitedQuota: true, + } + result.Error = cleanToken.Insert() + if result.Error != nil { + // do not block + logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error())) + } + return nil +} + +func (user *User) Update(updatePassword bool) error { + var err error + if updatePassword { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + if user.Status == UserStatusDisabled { + blacklist.BanUser(user.Id) + } else if user.Status == UserStatusEnabled { + blacklist.UnbanUser(user.Id) + } + err = DB.Model(user).Updates(user).Error + return err +} + +func (user *User) Delete() error { + if user.Id == 0 { + return errors.New("id 为空!") + } + blacklist.BanUser(user.Id) + user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) + user.Status = UserStatusDeleted + err := DB.Model(user).Updates(user).Error + return err +} + +// ValidateAndFill check password & user status +func (user *User) ValidateAndFill() (err error) { + // When querying with struct, GORM will only query with non-zero fields, + // that means if your field’s value is 0, '', false or other zero values, + // it won’t be used to build query conditions + password := user.Password + if user.Username == "" || password == "" { + return errors.New("用户名或密码为空") + } + err = DB.Where("username = ?", user.Username).First(user).Error + if err != nil { + // we must make sure check username firstly + // consider this case: a malicious user set his username as other's email + err := 
DB.Where("email = ?", user.Username).First(user).Error + if err != nil { + return errors.New("用户名或密码错误,或用户已被封禁") + } + } + okay := common.ValidatePasswordAndHash(password, user.Password) + if !okay || user.Status != UserStatusEnabled { + return errors.New("用户名或密码错误,或用户已被封禁") + } + return nil +} + +func (user *User) FillUserById() error { + if user.Id == 0 { + return errors.New("id 为空!") + } + DB.Where(User{Id: user.Id}).First(user) + return nil +} + +func (user *User) FillUserByEmail() error { + if user.Email == "" { + return errors.New("email 为空!") + } + DB.Where(User{Email: user.Email}).First(user) + return nil +} + +func (user *User) FillUserByGitHubId() error { + if user.GitHubId == "" { + return errors.New("GitHub id 为空!") + } + DB.Where(User{GitHubId: user.GitHubId}).First(user) + return nil +} + +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + +func (user *User) FillUserByWeChatId() error { + if user.WeChatId == "" { + return errors.New("WeChat id 为空!") + } + DB.Where(User{WeChatId: user.WeChatId}).First(user) + return nil +} + +func (user *User) FillUserByUsername() error { + if user.Username == "" { + return errors.New("username 为空!") + } + DB.Where(User{Username: user.Username}).First(user) + return nil +} + +func IsEmailAlreadyTaken(email string) bool { + return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1 +} + +func IsWeChatIdAlreadyTaken(wechatId string) bool { + return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 +} + +func IsGitHubIdAlreadyTaken(githubId string) bool { + return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + +func IsLarkIdAlreadyTaken(githubId string) bool { + return 
DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + +func IsUsernameAlreadyTaken(username string) bool { + return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 +} + +func ResetUserPasswordByEmail(email string, password string) error { + if email == "" || password == "" { + return errors.New("邮箱地址或密码为空!") + } + hashedPassword, err := common.Password2Hash(password) + if err != nil { + return err + } + err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error + return err +} + +func IsAdmin(userId int) bool { + if userId == 0 { + return false + } + var user User + err := DB.Where("id = ?", userId).Select("role").Find(&user).Error + if err != nil { + logger.SysError("no such user " + err.Error()) + return false + } + return user.Role >= RoleAdminUser +} + +func IsUserEnabled(userId int) (bool, error) { + if userId == 0 { + return false, errors.New("user id is empty") + } + var user User + err := DB.Where("id = ?", userId).Select("status").Find(&user).Error + if err != nil { + return false, err + } + return user.Status == UserStatusEnabled, nil +} + +func ValidateAccessToken(token string) (user *User) { + if token == "" { + return nil + } + token = strings.Replace(token, "Bearer ", "", 1) + user = &User{} + if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 { + return user + } + return nil +} + +func GetUserQuota(id int) (quota int64, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error + return quota, err +} + +func GetUserUsedQuota(id int) (quota int64, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error + return quota, err +} + +func GetUserEmail(id int) (email string, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error + return 
email, err +} + +func GetUserGroup(id int) (group string, err error) { + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + + err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error + return group, err +} + +func IncreaseUserQuota(id int, quota int64) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int64) (err error) { + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error + return err +} + +func DecreaseUserQuota(id int, quota int64) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int64) (err error) { + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error + return err +} + +func GetRootUserEmail() (email string) { + DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) + return email +} + +func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { + if config.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuota, id, quota) + addNewRecord(BatchUpdateTypeRequestCount, id, 1) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { + err := DB.Model(&User{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "used_quota": gorm.Expr("used_quota + ?", quota), + "request_count": gorm.Expr("request_count + ?", count), + }, + ).Error + if err != nil { + logger.SysError("failed to update user used quota and request count: " + err.Error()) + } +} + +func updateUserUsedQuota(id int, quota int64) { + err 
:= DB.Model(&User{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "used_quota": gorm.Expr("used_quota + ?", quota), + }, + ).Error + if err != nil { + logger.SysError("failed to update user used quota: " + err.Error()) + } +} + +func updateUserRequestCount(id int, count int) { + err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error + if err != nil { + logger.SysError("failed to update user request count: " + err.Error()) + } +} + +func GetUsernameById(id int) (username string) { + DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) + return username +} diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 0000000..a55eb4b --- /dev/null +++ b/model/utils.go @@ -0,0 +1,78 @@ +package model + +import ( + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "sync" + "time" +) + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuota + BatchUpdateTypeChannelUsedQuota + BatchUpdateTypeRequestCount + BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock +) + +var batchUpdateStores []map[int]int64 +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func addNewRecord(type_ int, id int, value int64) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + logger.SysLog("batch update started") + for i := 0; i < 
BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int64) + batchUpdateLocks[i].Unlock() + // TODO: maybe we can combine updates with same key? + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + logger.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + logger.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuota: + updateUserUsedQuota(key, value) + case BatchUpdateTypeRequestCount: + updateUserRequestCount(key, int(value)) + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + logger.SysLog("batch update finished") +} diff --git a/monitor/channel.go b/monitor/channel.go new file mode 100644 index 0000000..a375c8d --- /dev/null +++ b/monitor/channel.go @@ -0,0 +1,77 @@ +package monitor + +import ( + "fmt" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/model" +) + +func notifyRootUser(subject string, content string) { + if config.MessagePusherAddress != "" { + err := message.SendMessage(subject, content, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) + } else { + return + } + } + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() + } + err := message.SendEmail(subject, config.RootUserEmail, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + +// DisableChannel disable & notify +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) + 
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) + subject := fmt.Sprintf("渠道状态变更提醒") + content := message.EmailTemplate( + subject, + fmt.Sprintf(` +

您好!

+

渠道「%s」(#%d)已被禁用。

+

禁用原因:

+

%s

+ `, channelName, channelId, reason), + ) + notifyRootUser(subject, content) +} + +func MetricDisableChannel(channelId int, successRate float64) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) + subject := fmt.Sprintf("渠道状态变更提醒") + content := message.EmailTemplate( + subject, + fmt.Sprintf(` +

您好!

+

渠道 #%d 已被系统自动禁用。

+

禁用原因:

+

该渠道在最近 %d 次调用中成功率为 %.2f%%,低于系统阈值 %.2f%%

+ `, channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100), + ) + notifyRootUser(subject, content) +} + +// EnableChannel enable & notify +func EnableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) + logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) + subject := fmt.Sprintf("渠道状态变更提醒") + content := message.EmailTemplate( + subject, + fmt.Sprintf(` +

您好!

+

渠道「%s」(#%d)已被重新启用。

+

您现在可以继续使用该渠道了。

+ `, channelName, channelId), + ) + notifyRootUser(subject, content) +} diff --git a/monitor/manage.go b/monitor/manage.go new file mode 100644 index 0000000..338cd7c --- /dev/null +++ b/monitor/manage.go @@ -0,0 +1,57 @@ +package monitor + +import ( + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" +) + +func ShouldDisableChannel(err *model.Error, statusCode int) bool { + if !config.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + switch err.Type { + case "insufficient_quota", "authentication_error", "permission_error", "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + + lowerMessage := strings.ToLower(err.Message) + if strings.Contains(lowerMessage, "your access was terminated") || + strings.Contains(lowerMessage, "violation of our policies") || + strings.Contains(lowerMessage, "your credit balance is too low") || + strings.Contains(lowerMessage, "organization has been disabled") || + strings.Contains(lowerMessage, "credit") || + strings.Contains(lowerMessage, "balance") || + strings.Contains(lowerMessage, "permission denied") || + strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "api key not valid") || // gemini + strings.Contains(lowerMessage, "api key expired") || // gemini + strings.Contains(lowerMessage, "已欠费") { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *model.Error) bool { + if !config.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} diff --git a/monitor/metric.go b/monitor/metric.go new file mode 100644 index 0000000..98bc546 --- /dev/null +++ b/monitor/metric.go @@ -0,0 +1,79 @@ +package monitor + 
+import ( + "github.com/songquanpeng/one-api/common/config" +) + +var store = make(map[int][]bool) +var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) +var metricFailChan = make(chan int, config.MetricFailChanSize) + +func consumeSuccess(channelId int) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], true) +} + +func consumeFail(channelId int) (bool, float64) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], false) + successCount := 0 + for _, success := range store[channelId] { + if success { + successCount++ + } + } + successRate := float64(successCount) / float64(len(store[channelId])) + if len(store[channelId]) < config.MetricQueueSize { + return false, successRate + } + if successRate < config.MetricSuccessRateThreshold { + store[channelId] = make([]bool, 0) + return true, successRate + } + return false, successRate +} + +func metricSuccessConsumer() { + for { + select { + case channelId := <-metricSuccessChan: + consumeSuccess(channelId) + } + } +} + +func metricFailConsumer() { + for { + select { + case channelId := <-metricFailChan: + disable, successRate := consumeFail(channelId) + if disable { + go MetricDisableChannel(channelId, successRate) + } + } + } +} + +func init() { + if config.EnableMetric { + go metricSuccessConsumer() + go metricFailConsumer() + } +} + +func Emit(channelId int, success bool) { + if !config.EnableMetric { + return + } + go func() { + if success { + metricSuccessChan <- channelId + } else { + metricFailChan <- channelId + } + }() +} diff --git a/one-api.service b/one-api.service new file mode 100644 index 0000000..17e236b --- /dev/null +++ b/one-api.service @@ -0,0 +1,18 @@ +# File path: /etc/systemd/system/one-api.service +# sudo systemctl daemon-reload +# sudo systemctl start one-api +# sudo systemctl enable one-api 
+# sudo systemctl status one-api +[Unit] +Description=One API Service +After=network.target + +[Service] +User=ubuntu # 注意修改用户名 +WorkingDirectory=/path/to/one-api # 注意修改路径 +ExecStart=/path/to/one-api/one-api --port 3000 --log-dir /path/to/one-api/logs # 注意修改路径和端口号 +Restart=always +RestartSec=5 + +[Install] +WantedBy=multi-user.target diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 0000000..c630134 --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1,10 @@ +[//]: # (请按照以下格式关联 issue) +[//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) +[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) +[//]: # (开发者交流群:910657413) +[//]: # (请在提交 PR 之前删除上面的注释) + +close #issue_number + +我已确认该 PR 已自测通过,相关截图如下: +(此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) diff --git a/relay/adaptor.go b/relay/adaptor.go new file mode 100644 index 0000000..03e8390 --- /dev/null +++ b/relay/adaptor.go @@ -0,0 +1,69 @@ +package relay + +import ( + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/aiproxy" + "github.com/songquanpeng/one-api/relay/adaptor/ali" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws" + "github.com/songquanpeng/one-api/relay/adaptor/baidu" + "github.com/songquanpeng/one-api/relay/adaptor/cloudflare" + "github.com/songquanpeng/one-api/relay/adaptor/cohere" + "github.com/songquanpeng/one-api/relay/adaptor/coze" + "github.com/songquanpeng/one-api/relay/adaptor/deepl" + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/ollama" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/palm" + "github.com/songquanpeng/one-api/relay/adaptor/proxy" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" + "github.com/songquanpeng/one-api/relay/adaptor/tencent" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai" + 
"github.com/songquanpeng/one-api/relay/adaptor/xunfei" + "github.com/songquanpeng/one-api/relay/adaptor/zhipu" + "github.com/songquanpeng/one-api/relay/apitype" +) + +func GetAdaptor(apiType int) adaptor.Adaptor { + switch apiType { + case apitype.AIProxyLibrary: + return &aiproxy.Adaptor{} + case apitype.Ali: + return &ali.Adaptor{} + case apitype.Anthropic: + return &anthropic.Adaptor{} + case apitype.AwsClaude: + return &aws.Adaptor{} + case apitype.Baidu: + return &baidu.Adaptor{} + case apitype.Gemini: + return &gemini.Adaptor{} + case apitype.OpenAI: + return &openai.Adaptor{} + case apitype.PaLM: + return &palm.Adaptor{} + case apitype.Tencent: + return &tencent.Adaptor{} + case apitype.Xunfei: + return &xunfei.Adaptor{} + case apitype.Zhipu: + return &zhipu.Adaptor{} + case apitype.Ollama: + return &ollama.Adaptor{} + case apitype.Coze: + return &coze.Adaptor{} + case apitype.Cohere: + return &cohere.Adaptor{} + case apitype.Cloudflare: + return &cloudflare.Adaptor{} + case apitype.DeepL: + return &deepl.Adaptor{} + case apitype.VertexAI: + return &vertexai.Adaptor{} + case apitype.Proxy: + return &proxy.Adaptor{} + case apitype.Replicate: + return &replicate.Adaptor{} + } + return nil +} diff --git a/relay/adaptor/ai360/constants.go b/relay/adaptor/ai360/constants.go new file mode 100644 index 0000000..cfc3cb2 --- /dev/null +++ b/relay/adaptor/ai360/constants.go @@ -0,0 +1,8 @@ +package ai360 + +var ModelList = []string{ + "360GPT_S2_V9", + "embedding-bert-512-v1", + "embedding_s1_v1", + "semantic_similarity_s1_v1", +} diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go new file mode 100644 index 0000000..42d49c0 --- /dev/null +++ b/relay/adaptor/aiproxy/adaptor.go @@ -0,0 +1,67 @@ +package aiproxy + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + 
+type Adaptor struct { + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + aiProxyLibraryRequest := ConvertRequest(*request) + aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID + return aiProxyLibraryRequest, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "aiproxy" +} diff --git a/relay/adaptor/aiproxy/constants.go b/relay/adaptor/aiproxy/constants.go new file mode 100644 index 0000000..818d270 --- /dev/null +++ b/relay/adaptor/aiproxy/constants.go @@ -0,0 +1,9 @@ +package aiproxy + +import "github.com/songquanpeng/one-api/relay/adaptor/openai" + +var ModelList = []string{""} + +func init() { + ModelList = openai.ModelList +} diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go new 
file mode 100644 index 0000000..d64b680 --- /dev/null +++ b/relay/adaptor/aiproxy/main.go @@ -0,0 +1,189 @@ +package aiproxy + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 + +func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest { + query := "" + if len(request.Messages) != 0 { + query = request.Messages[len(request.Messages)-1].StringContent() + } + return &LibraryRequest{ + Model: request.Model, + Stream: request.Stream, + Query: query, + } +} + +func aiProxyDocuments2Markdown(documents []LibraryDocument) string { + if len(documents) == 0 { + return "" + } + content := "\n\n参考文档:\n" + for i, document := range documents { + content += fmt.Sprintf("%d. 
[%s](%s)\n", i+1, document.Title, document.URL) + } + return content +} + +func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse { + content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = aiProxyDocuments2Markdown(documents) + choice.FinishReason = &constant.StopFinishReason + return &openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } +} + +func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = response.Content + return &openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: response.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + var documents []LibraryDocument + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := 
strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue + } + data = data[5:] + + var AIProxyLibraryResponse LibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + response := documentsAIProxyLibrary(documents) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + render.Done(c) + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var AIProxyLibraryResponse LibraryResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if AIProxyLibraryResponse.ErrCode != 0 { + return &model.ErrorWithStatusCode{ + Error: 
model.Error{ + Message: AIProxyLibraryResponse.Message, + Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), + Code: AIProxyLibraryResponse.ErrCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &fullTextResponse.Usage +} diff --git a/relay/adaptor/aiproxy/model.go b/relay/adaptor/aiproxy/model.go new file mode 100644 index 0000000..39689b3 --- /dev/null +++ b/relay/adaptor/aiproxy/model.go @@ -0,0 +1,32 @@ +package aiproxy + +type LibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type LibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type LibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type LibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []LibraryDocument `json:"documents"` + LibraryError +} + +type LibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []LibraryDocument `json:"documents"` +} diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go new file mode 100644 index 0000000..4aa8a11 --- /dev/null +++ b/relay/adaptor/ali/adaptor.go @@ -0,0 +1,105 @@ +package ali + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + 
"github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" +) + +// https://help.aliyun.com/zh/dashscope/developer-reference/api-details + +type Adaptor struct { + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + fullRequestURL := "" + switch meta.Mode { + case relaymode.Embeddings: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) + case relaymode.ImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) + default: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) + } + + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.IsStream { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-DashScope-SSE", "enable") + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + + if meta.Mode == relaymode.ImagesGenerations { + req.Header.Set("X-DashScope-Async", "enable") + } + if a.meta.Config.Plugin != "" { + req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin) + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + aliEmbeddingRequest := ConvertEmbeddingRequest(*request) + return aliEmbeddingRequest, nil + default: + aliRequest := ConvertRequest(*request) + return aliRequest, nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + 
+ aliRequest := ConvertImageRequest(*request) + return aliRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ali" +} diff --git a/relay/adaptor/ali/constants.go b/relay/adaptor/ali/constants.go new file mode 100644 index 0000000..7c25325 --- /dev/null +++ b/relay/adaptor/ali/constants.go @@ -0,0 +1,27 @@ +package ali + +var ModelList = []string{ + "qwen-turbo", "qwen-turbo-latest", + "qwen-plus", "qwen-plus-latest", + "qwen-max", "qwen-max-latest", + "qwen-max-longcontext", + "qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", + "qwen-vl-ocr", "qwen-vl-ocr-latest", + "qwen-audio-turbo", + "qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", + "qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", + "qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", + "qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", + "qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", + "qwen-72b-chat", "qwen-14b-chat", 
"qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", + "qvq-72b-preview", + "qwen2.5-vl-72b-instruct", "qwen2.5-vl-7b-instruct", "qwen2.5-vl-2b-instruct", "qwen2.5-vl-1b-instruct", "qwen2.5-vl-0.5b-instruct", + "qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", + "qwen2-audio-instruct", "qwen-audio-chat", + "qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", + "qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", + "text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", + "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", + "qwen-mt-plus", "qwen-mt-turbo", + "deepseek-r1", "deepseek-v3", "deepseek-r1-distill-qwen-1.5b", "deepseek-r1-distill-qwen-7b", "deepseek-r1-distill-qwen-14b", "deepseek-r1-distill-qwen-32b", "deepseek-r1-distill-llama-8b", "deepseek-r1-distill-llama-70b", +} diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go new file mode 100644 index 0000000..8261803 --- /dev/null +++ b/relay/adaptor/ali/image.go @@ -0,0 +1,192 @@ +package ali + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" + "time" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err 
:= io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := 
client.Do(req) + if err != nil { + logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { + waitSeconds := 2 + step := 0 + maxStep := 20 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageResponse := openai.ImageResponse{ + Created: helper.GetTimestamp(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + 
}) + } + return &imageResponse +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go new file mode 100644 index 0000000..6a73c70 --- /dev/null +++ b/relay/adaptor/ali/main.go @@ -0,0 +1,267 @@ +package ali + +import ( + "bufio" + "encoding/json" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r + +const EnableSearchModelSuffix = "-internet" + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + messages = append(messages, Message{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) + } + enableSearch := false + aliModel := request.Model + if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { + enableSearch = true + aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) + } + request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) + return &ChatRequest{ + Model: aliModel, + Input: Input{ + Messages: messages, + }, + Parameters: Parameters{ + EnableSearch: enableSearch, + IncrementalOutput: request.Stream, + 
Seed:              uint64(request.Seed),
+			// NOTE(review): request.Seed is cast straight to uint64 — a
+			// negative seed would wrap around; confirm upstream accepts that.
+			MaxTokens:         request.MaxTokens,
+			Temperature:       request.Temperature,
+			TopP:              request.TopP,
+			TopK:              request.TopK,
+			ResultFormat:      "message",
+			Tools:             request.Tools,
+		},
+	}
+}
+
+// ConvertEmbeddingRequest maps an OpenAI embedding request onto the DashScope
+// embedding payload; ParseInput normalizes the input field to a string slice.
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	return &EmbeddingRequest{
+		Model: request.Model,
+		Input: struct {
+			Texts []string `json:"texts"`
+		}{
+			Texts: request.ParseInput(),
+		},
+	}
+}
+
+// ConvertImageRequest maps an OpenAI image request onto the DashScope
+// image-synthesis payload; DashScope sizes use "*" instead of "x".
+func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
+	var imageRequest ImageRequest
+	imageRequest.Input.Prompt = request.Prompt
+	imageRequest.Model = request.Model
+	imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
+	imageRequest.Parameters.N = request.N
+	imageRequest.ResponseFormat = request.ResponseFormat
+
+	return &imageRequest
+}
+
+// EmbeddingHandler decodes a DashScope embedding response, maps upstream
+// errors into the relay error envelope, rewrites the model name to the one
+// originally requested, and writes the OpenAI-format response to the client.
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var aliResponse EmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		// DashScope signals failure via a non-empty code field in the body;
+		// surface it with the upstream HTTP status code.
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	requestModel := c.GetString(ctxkey.RequestModel)
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	fullTextResponse.Model = requestModel
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, 
err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: model.Usage{TotalTokens: response.Usage.TotalTokens}, + } + + for _, item := range response.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + +func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Id: response.RequestId, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: response.Output.Choices, + Usage: model.Usage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + } + return &fullTextResponse +} + +func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if len(aliResponse.Output.Choices) == 0 { + return nil + } + aliChoice := aliResponse.Output.Choices[0] + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta = aliChoice.Message + if aliChoice.FinishReason != "null" { + finishReason := aliChoice.FinishReason + choice.FinishReason = &finishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: aliResponse.RequestId, + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "qwen", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := 
bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue + } + data = data[5:] + + var aliResponse ChatResponse + err := json.Unmarshal([]byte(data), &aliResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if aliResponse.Usage.OutputTokens != 0 { + usage.PromptTokens = aliResponse.Usage.InputTokens + usage.CompletionTokens = aliResponse.Usage.OutputTokens + usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + } + response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + continue + } + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := c.Request.Context() + var aliResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "response body: %s\n", responseBody) + err = json.Unmarshal(responseBody, &aliResponse) 
+ if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if aliResponse.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAli2OpenAI(&aliResponse) + fullTextResponse.Model = "qwen" + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go new file mode 100644 index 0000000..a680c7e --- /dev/null +++ b/relay/adaptor/ali/model.go @@ -0,0 +1,154 @@ +package ali + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type Input struct { + //Prompt string `json:"prompt"` + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type ImageRequest struct { + Model string `json:"model"` + 
Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +type TaskResponse struct { + StatusCode int `json:"status_code,omitempty"` + RequestId string `json:"request_id,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Results []struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"results,omitempty"` + TaskMetrics struct { + Total int `json:"TOTAL,omitempty"` + Succeeded int `json:"SUCCEEDED,omitempty"` + Failed int `json:"FAILED,omitempty"` + } `json:"task_metrics,omitempty"` + } `json:"output,omitempty"` + Usage Usage `json:"usage"` +} + +type Header struct { + Action string `json:"action,omitempty"` + Streaming string `json:"streaming,omitempty"` + TaskID string `json:"task_id,omitempty"` + Event string `json:"event,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Attributes any `json:"attributes,omitempty"` +} + +type Payload struct { + Model string `json:"model,omitempty"` + Task string `json:"task,omitempty"` + TaskGroup string `json:"task_group,omitempty"` + Function string `json:"function,omitempty"` + Parameters struct { + SampleRate int `json:"sample_rate,omitempty"` + Rate float64 `json:"rate,omitempty"` + Format string `json:"format,omitempty"` + } `json:"parameters,omitempty"` + Input 
struct { + Text string `json:"text,omitempty"` + } `json:"input,omitempty"` + Usage struct { + Characters int `json:"characters,omitempty"` + } `json:"usage,omitempty"` +} + +type WSSMessage struct { + Header Header `json:"header,omitempty"` + Payload Payload `json:"payload,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type EmbeddingResponse struct { + Output struct { + Embeddings []Embedding `json:"embeddings"` + } `json:"output"` + Usage Usage `json:"usage"` + Error +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Output struct { + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` +} + +type ChatResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + Error +} diff --git a/relay/adaptor/alibailian/constants.go b/relay/adaptor/alibailian/constants.go new file mode 100644 index 0000000..947946b --- /dev/null +++ b/relay/adaptor/alibailian/constants.go @@ -0,0 +1,20 @@ +package alibailian + +// https://help.aliyun.com/zh/model-studio/getting-started/models + +var ModelList = []string{ + "qwen-turbo", + "qwen-plus", + "qwen-long", + "qwen-max", + "qwen-coder-plus", + "qwen-coder-plus-latest", + "qwen-coder-turbo", + "qwen-coder-turbo-latest", + "qwen-mt-plus", + "qwen-mt-turbo", + "qwq-32b-preview", + + "deepseek-r1", + "deepseek-v3", +} diff --git a/relay/adaptor/alibailian/main.go b/relay/adaptor/alibailian/main.go new 
file mode 100644 index 0000000..7839e12 --- /dev/null +++ b/relay/adaptor/alibailian/main.go @@ -0,0 +1,19 @@ +package alibailian + +import ( + "fmt" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/compatible-mode/v1/chat/completions", meta.BaseURL), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/compatible-mode/v1/embeddings", meta.BaseURL), nil + default: + } + return "", fmt.Errorf("unsupported relay mode %d for ali bailian", meta.Mode) +} diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go new file mode 100644 index 0000000..bd0949b --- /dev/null +++ b/relay/adaptor/anthropic/adaptor.go @@ -0,0 +1,79 @@ +package anthropic + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-api-key", meta.APIKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("anthropic-beta", "messages-2023-12-15") + + // https://x.com/alexalbert__/status/1812921642143900036 + // claude-3-5-sonnet can support 8k context + if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") { + req.Header.Set("anthropic-beta", 
"max-tokens-3-5-sonnet-2024-07-15") + } + + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "anthropic" +} diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go new file mode 100644 index 0000000..9b515c1 --- /dev/null +++ b/relay/adaptor/anthropic/constants.go @@ -0,0 +1,13 @@ +package anthropic + +var ModelList = []string{ + "claude-instant-1.2", "claude-2.0", "claude-2.1", + "claude-3-haiku-20240307", + "claude-3-5-haiku-20241022", + "claude-3-5-haiku-latest", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", + "claude-3-5-sonnet-latest", +} diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go new file mode 100644 index 0000000..d3e306c --- /dev/null +++ b/relay/adaptor/anthropic/main.go @@ -0,0 +1,379 @@ +package anthropic + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + + 
// stopReasonClaude2OpenAI translates an Anthropic stop_reason into the
// equivalent OpenAI finish_reason. A nil reason maps to the empty string;
// values without a known translation pass through unchanged.
func stopReasonClaude2OpenAI(reason *string) string {
	if reason == nil {
		return ""
	}
	translations := map[string]string{
		"end_turn":      "stop",
		"stop_sequence": "stop",
		"max_tokens":    "length",
		"tool_use":      "tool_calls",
	}
	if translated, ok := translations[*reason]; ok {
		return translated
	}
	return *reason
}
claudeRequest.ToolChoice = claudeToolChoice + } + if claudeRequest.MaxTokens == 0 { + claudeRequest.MaxTokens = 4096 + } + // legacy model name mapping + if claudeRequest.Model == "claude-instant-1" { + claudeRequest.Model = "claude-instant-1.1" + } else if claudeRequest.Model == "claude-2" { + claudeRequest.Model = "claude-2.1" + } + for _, message := range textRequest.Messages { + if message.Role == "system" && claudeRequest.System == "" { + claudeRequest.System = message.StringContent() + continue + } + claudeMessage := Message{ + Role: message.Role, + } + var content Content + if message.IsStringContent() { + content.Type = "text" + content.Text = message.StringContent() + if message.Role == "tool" { + claudeMessage.Role = "user" + content.Type = "tool_result" + content.Content = content.Text + content.Text = "" + content.ToolUseId = message.ToolCallId + } + claudeMessage.Content = append(claudeMessage.Content, content) + for i := range message.ToolCalls { + inputParam := make(map[string]any) + _ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) + claudeMessage.Content = append(claudeMessage.Content, Content{ + Type: "tool_use", + Id: message.ToolCalls[i].Id, + Name: message.ToolCalls[i].Function.Name, + Input: inputParam, + }) + } + claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) + continue + } + var contents []Content + openaiContent := message.ParseContent() + for _, part := range openaiContent { + var content Content + if part.Type == model.ContentTypeText { + content.Type = "text" + content.Text = part.Text + } else if part.Type == model.ContentTypeImageURL { + content.Type = "image" + content.Source = &ImageSource{ + Type: "base64", + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + content.Source.MediaType = mimeType + content.Source.Data = data + } + contents = append(contents, content) + } + claudeMessage.Content = contents + claudeRequest.Messages = 
append(claudeRequest.Messages, claudeMessage) + } + return &claudeRequest +} + +// https://docs.anthropic.com/claude/reference/messages-streaming +func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var stopReason string + tools := make([]model.Tool, 0) + + switch claudeResponse.Type { + case "message_start": + return nil, claudeResponse.Message + case "content_block_start": + if claudeResponse.ContentBlock != nil { + responseText = claudeResponse.ContentBlock.Text + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, model.Tool{ + Id: claudeResponse.ContentBlock.Id, + Type: "function", + Function: model.Function{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } + } + case "content_block_delta": + if claudeResponse.Delta != nil { + responseText = claudeResponse.Delta.Text + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, model.Tool{ + Function: model.Function{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } + } + case "message_delta": + if claudeResponse.Usage != nil { + response = &Response{ + Usage: *claudeResponse.Usage, + } + } + if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { + stopReason = *claudeResponse.Delta.StopReason + } + } + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... 
+ choice.Delta.ToolCalls = tools + } + choice.Delta.Role = "assistant" + finishReason := stopReasonClaude2OpenAI(&stopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + tools := make([]model.Tool, 0) + for _, v := range claudeResponse.Content { + if v.Type == "tool_use" { + args, _ := json.Marshal(v.Input) + tools = append(tools, model.Tool{ + Id: v.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: model.Function{ + Name: v.Name, + Arguments: string(args), + }, + }) + } + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + ToolCalls: tools, + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), + Model: claudeResponse.Model, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + 
common.SetEventStreamHeaders(c) + + var usage model.Usage + var modelName string + var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + + var claudeResponse StreamResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + continue + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. 
+ lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } + } + if response == nil { + continue + } + + response.Id = id + response.Model = modelName + response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var claudeResponse Response + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if claudeResponse.Error.Type != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseClaude2OpenAI(&claudeResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: 
claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go new file mode 100644 index 0000000..47f193f --- /dev/null +++ b/relay/adaptor/anthropic/model.go @@ -0,0 +1,96 @@ +package anthropic + +// https://docs.anthropic.com/claude/reference/messages_post + +type Metadata struct { + UserId string `json:"user_id"` +} + +type ImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type Content struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ImageSource `json:"source,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` +} + +type Message struct { + Role string `json:"role"` + Content []Content `json:"content"` +} + +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + +type Request struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool 
`json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + //Metadata `json:"metadata,omitempty"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type Response struct { + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []Content `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage Usage `json:"usage"` + Error Error `json:"error"` +} + +type Delta struct { + Type string `json:"type"` + Text string `json:"text"` + PartialJson string `json:"partial_json,omitempty"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` +} + +type StreamResponse struct { + Type string `json:"type"` + Message *Response `json:"message"` + Index int `json:"index"` + ContentBlock *Content `json:"content_block"` + Delta *Delta `json:"delta"` + Usage *Usage `json:"usage"` +} diff --git a/relay/adaptor/aws/adaptor.go b/relay/adaptor/aws/adaptor.go new file mode 100644 index 0000000..6222134 --- /dev/null +++ b/relay/adaptor/aws/adaptor.go @@ -0,0 +1,84 @@ +package aws + +import ( + "errors" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +type Adaptor struct { + awsAdapter utils.AwsAdapter + + Meta *meta.Meta + 
AwsClient *bedrockruntime.Client +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + adaptor := GetAdaptor(request.Model) + if adaptor == nil { + return nil, errors.New("adaptor not found") + } + + a.awsAdapter = adaptor + return adaptor.ConvertRequest(c, relayMode, request) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if a.awsAdapter == nil { + return nil, utils.WrapErr(errors.New("awsAdapter is nil")) + } + return a.awsAdapter.DoResponse(c, a.AwsClient, meta) +} + +func (a *Adaptor) GetModelList() (models []string) { + for model := range adaptors { + models = append(models, model) + } + return +} + +func (a *Adaptor) GetChannelName() string { + return "aws" +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go new file mode 100644 index 0000000..eb3c9fb --- /dev/null +++ b/relay/adaptor/aws/claude/adapter.go @@ -0,0 +1,37 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + 
"github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, claudeReq) + return claudeReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go new file mode 100644 index 0000000..3fe3dfd --- /dev/null +++ b/relay/adaptor/aws/claude/main.go @@ -0,0 +1,207 @@ +// Package aws provides the AWS adaptor for the relay service. 
+package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var AwsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, 
"awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*anthropic.Request) + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return utils.WrapErr(errors.Wrap(err, "copy request")), nil + } + + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + claudeResponse := new(anthropic.Response) + err = json.Unmarshal(awsResp.Body, claudeResponse) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := 
c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*anthropic.Request) + + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return utils.WrapErr(errors.Wrap(err, "copy request")), nil + } + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + claudeResp := new(anthropic.StreamResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. 
+ id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } + } + if response == nil { + return true + } + response.Id = id + response.Model = c.GetString(ctxkey.OriginalModel) + response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go new file mode 100644 index 0000000..1062288 --- /dev/null +++ b/relay/adaptor/aws/claude/model.go @@ -0,0 +1,20 @@ +package aws + +import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + +// Request is the request to AWS Claude +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +type Request struct { + // AnthropicVersion should be "bedrock-2023-05-31" + AnthropicVersion string `json:"anthropic_version"` + Messages []anthropic.Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature *float64 
`json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Tools []anthropic.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} diff --git a/relay/adaptor/aws/llama3/adapter.go b/relay/adaptor/aws/llama3/adapter.go new file mode 100644 index 0000000..83edbc9 --- /dev/null +++ b/relay/adaptor/aws/llama3/adapter.go @@ -0,0 +1,37 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/songquanpeng/one-api/common/ctxkey" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + llamaReq := ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, llamaReq) + return llamaReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go new file mode 100644 index 0000000..e5fcd89 --- /dev/null +++ b/relay/adaptor/aws/llama3/main.go @@ -0,0 +1,231 @@ +// Package aws provides the AWS adaptor for the relay service. 
+package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "text/template" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// Only support llama-3-8b and llama-3-70b instruction models +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var AwsModelIDMap = map[string]string{ + "llama3-8b-8192": "meta.llama3-8b-instruct-v1:0", + "llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +// promptTemplate with range +const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> +` + +var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) + +func RenderPrompt(messages []relaymodel.Message) string { + var buf bytes.Buffer + err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) + if err != nil { + logger.SysError("error rendering prompt messages: " + err.Error()) + } + return buf.String() +} + +func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { + llamaRequest := Request{ + MaxGenLen: textRequest.MaxTokens, + Temperature: 
textRequest.Temperature, + TopP: textRequest.TopP, + } + if llamaRequest.MaxGenLen == 0 { + llamaRequest.MaxGenLen = 2048 + } + prompt := RenderPrompt(textRequest.Messages) + llamaRequest.Prompt = prompt + return &llamaRequest +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + var llamaResponse Response + err = json.Unmarshal(awsResp.Body, &llamaResponse) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := ResponseLlama2OpenAI(&llamaResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: llamaResponse.PromptTokenCount, + CompletionTokens: llamaResponse.GenerationTokenCount, + TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { + var responseText string + if len(llamaResponse.Generation) > 0 { + responseText = llamaResponse.Generation + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: relaymodel.Message{ + Role: "assistant", + Content: responseText, + Name: 
nil, + }, + FinishReason: llamaResponse.StopReason, + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + var llamaResp StreamResponse + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + if llamaResp.PromptTokenCount > 0 { + usage.PromptTokens = llamaResp.PromptTokenCount + } + if llamaResp.StopReason == "stop" { + usage.CompletionTokens = 
llamaResp.GenerationTokenCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + response := StreamResponseLlama2OpenAI(&llamaResp) + response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) + response.Model = c.GetString(ctxkey.OriginalModel) + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} + +func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = llamaResponse.Generation + choice.Delta.Role = "assistant" + finishReason := llamaResponse.StopReason + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse +} diff --git a/relay/adaptor/aws/llama3/main_test.go b/relay/adaptor/aws/llama3/main_test.go new file mode 100644 index 0000000..d539eee --- /dev/null +++ b/relay/adaptor/aws/llama3/main_test.go @@ -0,0 +1,45 @@ +package aws_test + +import ( + "testing" + + aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/stretchr/testify/assert" +) + +func TestRenderPrompt(t *testing.T) { + messages := []relaymodel.Message{ + { + Role: "user", + Content: "What's your name?", + }, + } + prompt := aws.RenderPrompt(messages) + expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your 
name?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) + + messages = []relaymodel.Message{ + { + Role: "system", + Content: "Your name is Kat. You are a detective.", + }, + { + Role: "user", + Content: "What's your name?", + }, + { + Role: "assistant", + Content: "Kat", + }, + { + Role: "user", + Content: "What's your job?", + }, + } + prompt = aws.RenderPrompt(messages) + expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) +} diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go new file mode 100644 index 0000000..6cb64cd --- /dev/null +++ b/relay/adaptor/aws/llama3/model.go @@ -0,0 +1,29 @@ +package aws + +// Request is the request to AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Request struct { + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` +} + +// Response is the response from AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Response struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + StopReason string `json:"stop_reason"` +} + +// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None} +type StreamResponse struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + 
StopReason string `json:"stop_reason"` +} diff --git a/relay/adaptor/aws/registry.go b/relay/adaptor/aws/registry.go new file mode 100644 index 0000000..5f65548 --- /dev/null +++ b/relay/adaptor/aws/registry.go @@ -0,0 +1,39 @@ +package aws + +import ( + claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude" + llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" +) + +type AwsModelType int + +const ( + AwsClaude AwsModelType = iota + 1 + AwsLlama3 +) + +var ( + adaptors = map[string]AwsModelType{} +) + +func init() { + for model := range claude.AwsModelIDMap { + adaptors[model] = AwsClaude + } + for model := range llama3.AwsModelIDMap { + adaptors[model] = AwsLlama3 + } +} + +func GetAdaptor(model string) utils.AwsAdapter { + adaptorType := adaptors[model] + switch adaptorType { + case AwsClaude: + return &claude.Adaptor{} + case AwsLlama3: + return &llama3.Adaptor{} + default: + return nil + } +} diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go new file mode 100644 index 0000000..4cb880f --- /dev/null +++ b/relay/adaptor/aws/utils/adaptor.go @@ -0,0 +1,51 @@ +package utils + +import ( + "errors" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type AwsAdapter interface { + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) +} + +type Adaptor struct { + Meta *meta.Meta + AwsClient *bedrockruntime.Client +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: 
meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} diff --git a/relay/adaptor/aws/utils/utils.go b/relay/adaptor/aws/utils/utils.go new file mode 100644 index 0000000..669dc62 --- /dev/null +++ b/relay/adaptor/aws/utils/utils.go @@ -0,0 +1,16 @@ +package utils + +import ( + "net/http" + + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func WrapErr(err error) *relaymodel.ErrorWithStatusCode { + return &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: err.Error(), + }, + } +} diff --git a/relay/adaptor/baichuan/constants.go b/relay/adaptor/baichuan/constants.go new file mode 100644 index 0000000..cb20a1f --- /dev/null +++ b/relay/adaptor/baichuan/constants.go @@ -0,0 +1,7 @@ +package baichuan + +var ModelList = []string{ + "Baichuan2-Turbo", + "Baichuan2-Turbo-192k", + "Baichuan-Text-Embedding", +} diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go new file mode 100644 index 0000000..15306b9 --- /dev/null +++ b/relay/adaptor/baidu/adaptor.go @@ -0,0 +1,143 @@ +package baidu + +import ( + "errors" + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + 
"github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix := "chat/" + if strings.HasPrefix(meta.ActualModelName, "Embedding") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "bge-large") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "tao-8k") { + suffix = "embeddings/" + } + switch meta.ActualModelName { + case "ERNIE-4.0": + suffix += "completions_pro" + case "ERNIE-Bot-4": + suffix += "completions_pro" + case "ERNIE-Bot": + suffix += "completions" + case "ERNIE-Bot-turbo": + suffix += "eb-instant" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-4.0-8K": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-3.5-8K-0205": + suffix += "ernie-3.5-8k-0205" + case "ERNIE-3.5-8K-1222": + suffix += "ernie-3.5-8k-1222" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-3.5-4K-0205": + suffix += "ernie-3.5-4k-0205" + case "ERNIE-Speed-8K": + suffix += "ernie_speed" + case "ERNIE-Speed-128K": + suffix += "ernie-speed-128k" + case "ERNIE-Lite-8K-0922": + suffix += "eb-instant" + case "ERNIE-Lite-8K-0308": + suffix += "ernie-lite-8k" + case "ERNIE-Tiny-8K": + suffix += "ernie-tiny-8k" + case "BLOOMZ-7B": + suffix += "bloomz_7b1" + case "Embedding-V1": + suffix += "embedding-v1" + case "bge-large-zh": + suffix += "bge_large_zh" + case "bge-large-en": + suffix += "bge_large_en" + case "tao-8k": + suffix += "tao_8k" + default: + suffix += strings.ToLower(meta.ActualModelName) + } + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) + var accessToken string + var err error + if accessToken, err = GetAccessToken(meta.APIKey); err != nil { + return "", err + } + fullRequestURL += "?access_token=" + accessToken 
+ return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := ConvertRequest(*request) + return baiduRequest, nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "baidu" +} diff --git a/relay/adaptor/baidu/constants.go b/relay/adaptor/baidu/constants.go new file mode 100644 index 0000000..f952adc --- /dev/null +++ b/relay/adaptor/baidu/constants.go @@ -0,0 +1,20 @@ +package baidu + +var ModelList = []string{ + "ERNIE-4.0-8K", + "ERNIE-3.5-8K", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", + "ERNIE-Bot-8K", + "ERNIE-3.5-4K-0205", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K-0922", + "ERNIE-Lite-8K-0308", + 
"ERNIE-Tiny-8K", + "BLOOMZ-7B", + "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", +} diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go new file mode 100644 index 0000000..ac8a562 --- /dev/null +++ b/relay/adaptor/baidu/main.go @@ -0,0 +1,312 @@ +package baidu + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 + +type TokenResponse struct { + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatRequest struct { + Messages []Message `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PenaltyScore *float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` +} + +type Error struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +var baiduTokenStore sync.Map + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + baiduRequest := ChatRequest{ + Messages: make([]Message, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + 
DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } + for _, message := range request.Messages { + if message.Role == "system" { + baiduRequest.System = message.StringContent() + } else { + baiduRequest.Messages = append(baiduRequest.Messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + } + return &baiduRequest +} + +func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: response.Result, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Id: response.Id, + Object: "chat.completion", + Created: response.Created, + Choices: []openai.TextResponseChoice{choice}, + Usage: response.Usage, + } + return &fullTextResponse +} + +func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = baiduResponse.Result + if baiduResponse.IsEnd { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: baiduResponse.Id, + Object: "chat.completion.chunk", + Created: baiduResponse.Created, + Model: "ernie-bot", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Input: request.ParseInput(), + } +} + +func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), + Model: "baidu-embedding", + Usage: response.Usage, + } + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, 
openai.EmbeddingResponseItem{
			Object:    item.Object,
			Index:     item.Index,
			Embedding: item.Embedding,
		})
	}
	return &openAIEmbeddingResponse
}

// StreamHandler relays a Baidu SSE stream to the client, converting each
// chunk to the OpenAI shape and accumulating usage from chunks that carry it.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage model.Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)

	for scanner.Scan() {
		data := scanner.Text()
		// Only SSE data lines carry payloads. The previous code sliced
		// data[6:] off every line of length >= 6, which mangled any other
		// SSE line (comments, event names) into junk JSON.
		if !strings.HasPrefix(data, "data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")

		var baiduResponse ChatStreamResponse
		err := json.Unmarshal([]byte(data), &baiduResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}
		// Baidu reports cumulative totals; completion tokens are derived.
		if baiduResponse.Usage.TotalTokens != 0 {
			usage.TotalTokens = baiduResponse.Usage.TotalTokens
			usage.PromptTokens = baiduResponse.Usage.PromptTokens
			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
		}
		response := streamResponseBaidu2OpenAI(&baiduResponse)
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}

// Handler converts a blocking Baidu chat response to the OpenAI shape and
// writes it to the client, surfacing Baidu error envelopes as relay errors.
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var baiduResponse ChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &baiduResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if baiduResponse.ErrorMsg != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: baiduResponse.ErrorMsg,
				Type:    "baidu_error",
				Param:   "",
				Code:    baiduResponse.ErrorCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
	fullTextResponse.Model = "ernie-bot"
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Headers are already sent; a write failure can only be logged.
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("error writing response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}

// EmbeddingHandler converts a Baidu embedding response to the OpenAI shape.
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var baiduResponse EmbeddingResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &baiduResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if baiduResponse.ErrorMsg != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: baiduResponse.ErrorMsg,
				Type:    "baidu_error",
				Param:   "",
				Code:    baiduResponse.ErrorCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("error writing response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}

// GetAccessToken returns a cached access token for the API key, refreshing it
// in the background when it is within an hour of expiry.
func GetAccessToken(apiKey string) (string, error) {
	if val, ok := baiduTokenStore.Load(apiKey); ok {
		var accessToken AccessToken
		if accessToken, ok = val.(AccessToken); ok {
			// soon this will expire
			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
				go func() {
					_, _ = getBaiduAccessTokenHelper(apiKey)
				}()
			}
			return accessToken.AccessToken, nil
		}
	}
	accessToken, err := getBaiduAccessTokenHelper(apiKey)
	if err != nil {
		return "", err
	}
	if accessToken == nil {
		return "", errors.New("GetAccessToken return a nil token")
	}
	return (*accessToken).AccessToken, nil
}

// getBaiduAccessTokenHelper exchanges the "client_id|client_secret" API key
// for an OAuth access token and caches it in baiduTokenStore.
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
	parts := strings.Split(apiKey, "|")
	if len(parts) != 2 {
		return nil, errors.New("invalid baidu apikey")
	}
	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
		parts[0], parts[1]), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Accept", "application/json")
	res, err := client.ImpatientHTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	var accessToken AccessToken
	err = json.NewDecoder(res.Body).Decode(&accessToken)
	if err != nil {
		return nil, err
	}
	if accessToken.Error != "" {
		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
	}
	if accessToken.AccessToken == "" {
		return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
	}
	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
	baiduTokenStore.Store(apiKey, accessToken)
	return &accessToken, nil
}
diff --git a/relay/adaptor/baidu/model.go b/relay/adaptor/baidu/model.go
new file mode 100644
index 0000000..cc1feb2
--- /dev/null
+++ 
b/relay/adaptor/baidu/model.go @@ -0,0 +1,50 @@ +package baidu + +import ( + "github.com/songquanpeng/one-api/relay/model" + "time" +) + +type ChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage model.Usage `json:"usage"` + Error +} + +type ChatStreamResponse struct { + ChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type EmbeddingRequest struct { + Input []string `json:"input"` +} + +type EmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type EmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []EmbeddingData `json:"data"` + Usage model.Usage `json:"usage"` + Error +} + +type AccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} diff --git a/relay/adaptor/baiduv2/constants.go b/relay/adaptor/baiduv2/constants.go new file mode 100644 index 0000000..aad9e58 --- /dev/null +++ b/relay/adaptor/baiduv2/constants.go @@ -0,0 +1,30 @@ +package baiduv2 + +// https://console.bce.baidu.com/support/?_=1692863460488×tamp=1739074632076#/api?product=QIANFAN&project=%E5%8D%83%E5%B8%86ModelBuilder&parent=%E5%AF%B9%E8%AF%9DChat%20V2&api=v2%2Fchat%2Fcompletions&method=post +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Fm2vrveyu#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8 + +var ModelList = []string{ + "ernie-4.0-8k-latest", + "ernie-4.0-8k-preview", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k-latest", + "ernie-4.0-turbo-8k-preview", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-128k", + "ernie-3.5-8k-preview", 
+ "ernie-3.5-8k", + "ernie-3.5-128k", + "ernie-speed-8k", + "ernie-speed-128k", + "ernie-speed-pro-128k", + "ernie-lite-8k", + "ernie-lite-pro-128k", + "ernie-tiny-8k", + "ernie-char-8k", + "ernie-char-fiction-8k", + "ernie-novel-8k", + "deepseek-v3", + "deepseek-r1", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-14b", +} diff --git a/relay/adaptor/baiduv2/main.go b/relay/adaptor/baiduv2/main.go new file mode 100644 index 0000000..d305e1d --- /dev/null +++ b/relay/adaptor/baiduv2/main.go @@ -0,0 +1,17 @@ +package baiduv2 + +import ( + "fmt" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/v2/chat/completions", meta.BaseURL), nil + default: + } + return "", fmt.Errorf("unsupported relay mode %d for baidu v2", meta.Mode) +} diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go new file mode 100644 index 0000000..97e3dbb --- /dev/null +++ b/relay/adaptor/cloudflare/adaptor.go @@ -0,0 +1,100 @@ +package cloudflare + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+ +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +// WorkerAI cannot be used across accounts with AIGateWay +// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints +// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai +func (a *Adaptor) isAIGateWay(baseURL string) bool { + return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai") +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + isAIGateWay := a.isAIGateWay(meta.BaseURL) + var urlPrefix string + if isAIGateWay { + urlPrefix = meta.BaseURL + } else { + urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID) + } + + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil + default: + if isAIGateWay { + return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil + } + return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Completions: + return ConvertCompletionsRequest(*request), nil + case relaymode.ChatCompletions, relaymode.Embeddings: + return request, nil + default: + return nil, errors.New("not implemented") + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) 
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "cloudflare" +} diff --git a/relay/adaptor/cloudflare/constant.go b/relay/adaptor/cloudflare/constant.go new file mode 100644 index 0000000..54052aa --- /dev/null +++ b/relay/adaptor/cloudflare/constant.go @@ -0,0 +1,37 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-3.1-8b-instruct", + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} 
diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go new file mode 100644 index 0000000..980a289 --- /dev/null +++ b/relay/adaptor/cloudflare/main.go @@ -0,0 +1,115 @@ +package cloudflare + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/render" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request { + p, _ := textRequest.Prompt.(string) + return &Request{ + Prompt: p, + MaxTokens: textRequest.MaxTokens, + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + responseModel := c.GetString(ctxkey.OriginalModel) + var responseText string + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response openai.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, v := range response.Choices { + v.Delta.Role = "assistant" + responseText += v.Delta.StringContent() + } + response.Id = id + response.Model = modelName + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err 
:= scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var response openai.TextResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = modelName + var responseText string + for _, v := range response.Choices { + responseText += v.Message.Content.(string) + } + usage := openai.ResponseText2Usage(responseText, modelName, promptTokens) + response.Usage = *usage + response.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(response) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go new file mode 100644 index 0000000..8e382ba --- /dev/null +++ b/relay/adaptor/cloudflare/model.go @@ -0,0 +1,13 @@ +package cloudflare + +import "github.com/songquanpeng/one-api/relay/model" + +type Request struct { + Messages []model.Message `json:"messages,omitempty"` 
+ Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` +} diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go new file mode 100644 index 0000000..6fdb1b0 --- /dev/null +++ b/relay/adaptor/cohere/adaptor.go @@ -0,0 +1,64 @@ +package cohere + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct{} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, 
resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "Cohere" +} diff --git a/relay/adaptor/cohere/constant.go b/relay/adaptor/cohere/constant.go new file mode 100644 index 0000000..9e70652 --- /dev/null +++ b/relay/adaptor/cohere/constant.go @@ -0,0 +1,14 @@ +package cohere + +var ModelList = []string{ + "command", "command-nightly", + "command-light", "command-light-nightly", + "command-r", "command-r-plus", +} + +func init() { + num := len(ModelList) + for i := 0; i < num; i++ { + ModelList = append(ModelList, ModelList[i]+"-internet") + } +} diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go new file mode 100644 index 0000000..736c5a8 --- /dev/null +++ b/relay/adaptor/cohere/main.go @@ -0,0 +1,228 @@ +package cohere + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +var ( + WebSearchConnector = Connector{ID: "web-search"} +) + +func stopReasonCohere2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "COMPLETE": + return "stop" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + cohereRequest := Request{ + Model: textRequest.Model, + Message: "", + MaxTokens: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + P: textRequest.TopP, + K: textRequest.TopK, + Stream: textRequest.Stream, + FrequencyPenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, + Seed: int(textRequest.Seed), + } + if cohereRequest.Model 
== "" { + cohereRequest.Model = "command-r" + } + if strings.HasSuffix(cohereRequest.Model, "-internet") { + cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet") + cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector) + } + for _, message := range textRequest.Messages { + if message.Role == "user" { + cohereRequest.Message = message.Content.(string) + } else { + var role string + if message.Role == "assistant" { + role = "CHATBOT" + } else if message.Role == "system" { + role = "SYSTEM" + } else { + role = "USER" + } + cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{ + Role: role, + Message: message.Content.(string), + }) + } + } + return &cohereRequest +} + +func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var responseText string + var finishReason string + + switch cohereResponse.EventType { + case "stream-start": + return nil, nil + case "text-generation": + responseText += cohereResponse.Text + case "stream-end": + usage := cohereResponse.Response.Meta.Tokens + response = &Response{ + Meta: Meta{ + Tokens: Usage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + }, + }, + } + finishReason = *cohereResponse.Response.FinishReason + default: + return nil, nil + } + + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = responseText + choice.Delta.Role = "assistant" + if finishReason != "" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse, response +} + +func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: 
cohereResponse.Text, + Name: nil, + }, + FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason), + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID), + Model: "model", + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + var usage model.Usage + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSuffix(data, "\r") + + var cohereResponse StreamResponse + err := json.Unmarshal([]byte(data), &cohereResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, meta := StreamResponseCohere2OpenAI(&cohereResponse) + if meta != nil { + usage.PromptTokens += meta.Meta.Tokens.InputTokens + usage.CompletionTokens += meta.Meta.Tokens.OutputTokens + continue + } + if response == nil { + continue + } + + response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) + response.Model = c.GetString("original_model") + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", 
http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cohereResponse Response + err = json.Unmarshal(responseBody, &cohereResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cohereResponse.ResponseID == "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: cohereResponse.Message, + Type: cohereResponse.Message, + Param: "", + Code: resp.StatusCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseCohere2OpenAI(&cohereResponse) + fullTextResponse.Model = modelName + usage := model.Usage{ + PromptTokens: cohereResponse.Meta.Tokens.InputTokens, + CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens, + TotalTokens: cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go new file mode 100644 index 0000000..3a8bc99 --- /dev/null +++ b/relay/adaptor/cohere/model.go @@ -0,0 +1,147 @@ +package cohere + +type Request struct { + Message string `json:"message" required:"true"` + Model string `json:"model,omitempty"` // 默认值为"command-r" + Stream bool `json:"stream,omitempty"` // 默认值为false + Preamble string `json:"preamble,omitempty"` + ChatHistory []ChatMessage `json:"chat_history,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` // 
默认值为"AUTO" + Connectors []Connector `json:"connectors,omitempty"` + Documents []Document `json:"documents,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 + MaxTokens int `json:"max_tokens,omitempty"` + MaxInputTokens int `json:"max_input_tokens,omitempty"` + K int `json:"k,omitempty"` // 默认值为0 + P *float64 `json:"p,omitempty"` // 默认值为0.75 + Seed int `json:"seed,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + Tools []Tool `json:"tools,omitempty"` + ToolResults []ToolResult `json:"tool_results,omitempty"` +} + +type ChatMessage struct { + Role string `json:"role" required:"true"` + Message string `json:"message" required:"true"` +} + +type Tool struct { + Name string `json:"name" required:"true"` + Description string `json:"description" required:"true"` + ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"` +} + +type ParameterSpec struct { + Description string `json:"description"` + Type string `json:"type" required:"true"` + Required bool `json:"required"` +} + +type ToolResult struct { + Call ToolCall `json:"call"` + Outputs []map[string]interface{} `json:"outputs"` +} + +type ToolCall struct { + Name string `json:"name" required:"true"` + Parameters map[string]interface{} `json:"parameters" required:"true"` +} + +type StreamResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + GenerationID string `json:"generation_id,omitempty"` + SearchQueries []*SearchQuery `json:"search_queries,omitempty"` + SearchResults []*SearchResult `json:"search_results,omitempty"` + Documents []*Document `json:"documents,omitempty"` + Text string `json:"text,omitempty"` + Citations []*Citation `json:"citations,omitempty"` + Response *Response `json:"response,omitempty"` + FinishReason string 
`json:"finish_reason,omitempty"` +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResult struct { + SearchQuery *SearchQuery `json:"search_query"` + DocumentIDs []string `json:"document_ids"` + Connector *Connector `json:"connector"` +} + +type Connector struct { + ID string `json:"id"` +} + +type Document struct { + ID string `json:"id"` + Snippet string `json:"snippet"` + Timestamp string `json:"timestamp"` + Title string `json:"title"` + URL string `json:"url"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentIDs []string `json:"document_ids"` +} + +type Response struct { + ResponseID string `json:"response_id"` + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ChatHistory []*Message `json:"chat_history"` + FinishReason *string `json:"finish_reason"` + Meta Meta `json:"meta"` + Citations []*Citation `json:"citations"` + Documents []*Document `json:"documents"` + SearchResults []*SearchResult `json:"search_results"` + SearchQueries []*SearchQuery `json:"search_queries"` + Message string `json:"message"` +} + +type Message struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Version struct { + Version string `json:"version"` +} + +type Units struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type ChatEntry struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type Meta struct { + APIVersion APIVersion `json:"api_version"` + BilledUnits BilledUnits `json:"billed_units"` + Tokens Usage `json:"tokens"` +} + +type APIVersion struct { + Version string `json:"version"` +} + +type BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git 
a/relay/adaptor/common.go b/relay/adaptor/common.go new file mode 100644 index 0000000..8953d7a --- /dev/null +++ b/relay/adaptor/common.go @@ -0,0 +1,52 @@ +package adaptor + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/relay/meta" + "io" + "net/http" +) + +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if meta.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } +} + +func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(meta) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, req, meta) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := DoRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { + resp, err := client.HTTPClient.Do(req) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("resp is nil") + } + _ = req.Body.Close() + _ = c.Request.Body.Close() + return resp, nil +} diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go new file mode 100644 index 0000000..44f560e --- /dev/null +++ b/relay/adaptor/coze/adaptor.go @@ -0,0 +1,75 @@ +package coze + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + 
"github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/open_api/v2/chat", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + request.User = a.meta.Config.UserID + return ConvertRequest(*request), nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + var responseText *string + if meta.IsStream { + err, responseText = StreamHandler(c, resp) + } else { + err, responseText = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + if responseText != nil { + usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens) + } else { + usage = &model.Usage{} + } + usage.PromptTokens = meta.PromptTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return 
"coze" +} diff --git a/relay/adaptor/coze/constant/contenttype/define.go b/relay/adaptor/coze/constant/contenttype/define.go new file mode 100644 index 0000000..69c876b --- /dev/null +++ b/relay/adaptor/coze/constant/contenttype/define.go @@ -0,0 +1,5 @@ +package contenttype + +const ( + Text = "text" +) diff --git a/relay/adaptor/coze/constant/event/define.go b/relay/adaptor/coze/constant/event/define.go new file mode 100644 index 0000000..c03e8c1 --- /dev/null +++ b/relay/adaptor/coze/constant/event/define.go @@ -0,0 +1,7 @@ +package event + +const ( + Message = "message" + Done = "done" + Error = "error" +) diff --git a/relay/adaptor/coze/constant/messagetype/define.go b/relay/adaptor/coze/constant/messagetype/define.go new file mode 100644 index 0000000..6c1c25d --- /dev/null +++ b/relay/adaptor/coze/constant/messagetype/define.go @@ -0,0 +1,6 @@ +package messagetype + +const ( + Answer = "answer" + FollowUp = "follow_up" +) diff --git a/relay/adaptor/coze/constants.go b/relay/adaptor/coze/constants.go new file mode 100644 index 0000000..d20fd87 --- /dev/null +++ b/relay/adaptor/coze/constants.go @@ -0,0 +1,3 @@ +package coze + +var ModelList = []string{} diff --git a/relay/adaptor/coze/helper.go b/relay/adaptor/coze/helper.go new file mode 100644 index 0000000..0396afc --- /dev/null +++ b/relay/adaptor/coze/helper.go @@ -0,0 +1,10 @@ +package coze + +import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event" + +func event2StopReason(e *string) string { + if e == nil || *e == event.Message { + return "" + } + return "stop" +} diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go new file mode 100644 index 0000000..d0402a7 --- /dev/null +++ b/relay/adaptor/coze/main.go @@ -0,0 +1,202 @@ +package coze + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + 
"github.com/songquanpeng/one-api/common/conv" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://www.coze.com/open + +func stopReasonCoze2OpenAI(reason *string) string { + if reason == nil { + return "" + } + switch *reason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return *reason + } +} + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + cozeRequest := Request{ + Stream: textRequest.Stream, + User: textRequest.User, + BotId: strings.TrimPrefix(textRequest.Model, "bot-"), + } + for i, message := range textRequest.Messages { + if i == len(textRequest.Messages)-1 { + cozeRequest.Query = message.StringContent() + continue + } + cozeMessage := Message{ + Role: message.Role, + Content: message.StringContent(), + } + cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage) + } + return &cozeRequest +} + +func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { + var response *Response + var stopReason string + var choice openai.ChatCompletionsStreamResponseChoice + + if cozeResponse.Message != nil { + if cozeResponse.Message.Type != messagetype.Answer { + return nil, nil + } + choice.Delta.Content = cozeResponse.Message.Content + } + choice.Delta.Role = "assistant" + finishReason := stopReasonCoze2OpenAI(&stopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + openaiResponse.Id = cozeResponse.ConversationId + return &openaiResponse, 
response +} + +func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse { + var responseText string + for _, message := range cozeResponse.Messages { + if message.Type == messagetype.Answer { + responseText = message.Content + break + } + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", cozeResponse.ConversationId), + Model: "coze-bot", + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { + var responseText string + createdTime := helper.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + var modelName string + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSuffix(data, "\r") + + var cozeResponse StreamResponse + err := json.Unmarshal([]byte(data), &cozeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, _ := StreamResponseCoze2OpenAI(&cozeResponse) + if response == nil { + continue + } + + for _, choice := range response.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + response.Model = modelName + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", 
http.StatusInternalServerError), nil + } + + return nil, &responseText +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cozeResponse Response + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cozeResponse.Code != 0 { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: cozeResponse.Msg, + Code: cozeResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := ResponseCoze2OpenAI(&cozeResponse) + fullTextResponse.Model = modelName + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + var responseText string + if len(fullTextResponse.Choices) > 0 { + responseText = fullTextResponse.Choices[0].Message.StringContent() + } + return nil, &responseText +} diff --git a/relay/adaptor/coze/model.go b/relay/adaptor/coze/model.go new file mode 100644 index 0000000..d0afecf --- /dev/null +++ b/relay/adaptor/coze/model.go @@ -0,0 +1,38 @@ +package coze + +type Message struct { + Role string `json:"role"` + Type string `json:"type"` + Content string `json:"content"` + ContentType string `json:"content_type"` +} + +type ErrorInformation struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +type Request struct { + ConversationId string 
`json:"conversation_id,omitempty"` + BotId string `json:"bot_id"` + User string `json:"user"` + Query string `json:"query"` + ChatHistory []Message `json:"chat_history,omitempty"` + Stream bool `json:"stream"` +} + +type Response struct { + ConversationId string `json:"conversation_id,omitempty"` + Messages []Message `json:"messages,omitempty"` + Code int `json:"code,omitempty"` + Msg string `json:"msg,omitempty"` +} + +type StreamResponse struct { + Event string `json:"event,omitempty"` + Message *Message `json:"message,omitempty"` + IsFinish bool `json:"is_finish,omitempty"` + Index int `json:"index,omitempty"` + ConversationId string `json:"conversation_id,omitempty"` + ErrorInformation *ErrorInformation `json:"error_information,omitempty"` +} diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go new file mode 100644 index 0000000..d018a09 --- /dev/null +++ b/relay/adaptor/deepl/adaptor.go @@ -0,0 +1,73 @@ +package deepl + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { + meta *meta.Meta + promptText string +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v2/translate", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + convertedRequest, text := ConvertRequest(*request) + a.promptText = text + return convertedRequest, nil +} + +func (a *Adaptor) 
ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err = StreamHandler(c, resp, meta.ActualModelName) + } else { + err = Handler(c, resp, meta.ActualModelName) + } + promptTokens := len(a.promptText) + usage = &model.Usage{ + PromptTokens: promptTokens, + TotalTokens: promptTokens, + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "deepl" +} diff --git a/relay/adaptor/deepl/constants.go b/relay/adaptor/deepl/constants.go new file mode 100644 index 0000000..6a4f254 --- /dev/null +++ b/relay/adaptor/deepl/constants.go @@ -0,0 +1,9 @@ +package deepl + +// https://developers.deepl.com/docs/api-reference/glossaries + +var ModelList = []string{ + "deepl-zh", + "deepl-en", + "deepl-ja", +} diff --git a/relay/adaptor/deepl/helper.go b/relay/adaptor/deepl/helper.go new file mode 100644 index 0000000..6d3a914 --- /dev/null +++ b/relay/adaptor/deepl/helper.go @@ -0,0 +1,11 @@ +package deepl + +import "strings" + +func parseLangFromModelName(modelName string) string { + parts := strings.Split(modelName, "-") + if len(parts) == 1 { + return "ZH" + } + return parts[1] +} diff --git a/relay/adaptor/deepl/main.go b/relay/adaptor/deepl/main.go new file mode 100644 index 0000000..f8bbae1 --- /dev/null +++ b/relay/adaptor/deepl/main.go @@ -0,0 +1,137 @@ +package deepl + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + 
"github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/constant/finishreason" + "github.com/songquanpeng/one-api/relay/constant/role" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +// https://developers.deepl.com/docs/getting-started/your-first-api-request + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) { + var text string + if len(textRequest.Messages) != 0 { + text = textRequest.Messages[len(textRequest.Messages)-1].StringContent() + } + deeplRequest := Request{ + TargetLang: parseLangFromModelName(textRequest.Model), + Text: []string{text}, + } + return &deeplRequest, text +} + +func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + if len(deeplResponse.Translations) != 0 { + choice.Delta.Content = deeplResponse.Translations[0].Text + } + choice.Delta.Role = role.Assistant + choice.FinishReason = &constant.StopFinishReason + openaiResponse := openai.ChatCompletionsStreamResponse{ + Object: constant.StreamObject, + Created: helper.GetTimestamp(), + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &openaiResponse +} + +func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse { + var responseText string + if len(deeplResponse.Translations) != 0 { + responseText = deeplResponse.Translations[0].Text + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: role.Assistant, + Content: responseText, + Name: nil, + }, + FinishReason: finishreason.Stop, + } + fullTextResponse := openai.TextResponse{ + Object: constant.NonStreamObject, + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode { 
+ responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var deeplResponse Response + err = json.Unmarshal(responseBody, &deeplResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse) + fullTextResponse.Model = modelName + fullTextResponse.Id = helper.GetResponseID(c) + jsonData, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + if jsonData != nil { + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)}) + jsonData = nil + return true + } + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + }) + _ = resp.Body.Close() + return nil +} + +func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var deeplResponse Response + err = json.Unmarshal(responseBody, &deeplResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + if deeplResponse.Message != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: deeplResponse.Message, + Code: "deepl_error", + }, + StatusCode: resp.StatusCode, + } + } + fullTextResponse := 
ResponseDeepL2OpenAI(&deeplResponse) + fullTextResponse.Model = modelName + fullTextResponse.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil +} diff --git a/relay/adaptor/deepl/model.go b/relay/adaptor/deepl/model.go new file mode 100644 index 0000000..3f823d2 --- /dev/null +++ b/relay/adaptor/deepl/model.go @@ -0,0 +1,16 @@ +package deepl + +type Request struct { + Text []string `json:"text"` + TargetLang string `json:"target_lang"` +} + +type Translation struct { + DetectedSourceLanguage string `json:"detected_source_language,omitempty"` + Text string `json:"text,omitempty"` +} + +type Response struct { + Translations []Translation `json:"translations,omitempty"` + Message string `json:"message,omitempty"` +} diff --git a/relay/adaptor/deepseek/constants.go b/relay/adaptor/deepseek/constants.go new file mode 100644 index 0000000..dc1a512 --- /dev/null +++ b/relay/adaptor/deepseek/constants.go @@ -0,0 +1,6 @@ +package deepseek + +var ModelList = []string{ + "deepseek-chat", + "deepseek-reasoner", +} diff --git a/relay/adaptor/doubao/constants.go b/relay/adaptor/doubao/constants.go new file mode 100644 index 0000000..dbe819d --- /dev/null +++ b/relay/adaptor/doubao/constants.go @@ -0,0 +1,13 @@ +package doubao + +// https://console.volcengine.com/ark/region:ark+cn-beijing/model + +var ModelList = []string{ + "Doubao-pro-128k", + "Doubao-pro-32k", + "Doubao-pro-4k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-lite-4k", + "Doubao-embedding", +} diff --git a/relay/adaptor/doubao/main.go b/relay/adaptor/doubao/main.go new file mode 100644 index 0000000..dd43d06 --- /dev/null +++ b/relay/adaptor/doubao/main.go @@ -0,0 +1,18 @@ +package doubao + +import ( + 
"fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil + default: + } + return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) +} diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go new file mode 100644 index 0000000..84083f6 --- /dev/null +++ b/relay/adaptor/gemini/adaptor.go @@ -0,0 +1,102 @@ +package gemini + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + defaultVersion := config.GeminiVersion + if strings.Contains(meta.ActualModelName, "gemini-2.0") || + strings.Contains(meta.ActualModelName, "gemini-1.5") { + defaultVersion = "v1beta" + } + + version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) + action := "" + switch meta.Mode { + case relaymode.Embeddings: + action = "batchEmbedContents" + default: + action = "generateContent" + } + + if meta.IsStream { + action = "streamGenerateContent?alt=sse" + } + + return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + 
channelhelper.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-goog-api-key", meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: + geminiEmbeddingRequest := ConvertEmbeddingRequest(*request) + return geminiEmbeddingRequest, nil + default: + geminiRequest := ConvertRequest(*request) + return geminiRequest, nil + } +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "google gemini" +} diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go new file mode 100644 index 0000000..d220b25 --- /dev/null +++ b/relay/adaptor/gemini/constants.go @@ -0,0 +1,35 @@ +package gemini + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/geminiv2" +) + +// https://ai.google.dev/models/gemini + +var ModelList = geminiv2.ModelList + +// ModelsSupportSystemInstruction is the list of 
models that support system instruction. +// +// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/system-instructions +var ModelsSupportSystemInstruction = []string{ + // "gemini-1.0-pro-002", + // "gemini-1.5-flash", "gemini-1.5-flash-001", "gemini-1.5-flash-002", + // "gemini-1.5-flash-8b", + // "gemini-1.5-pro", "gemini-1.5-pro-001", "gemini-1.5-pro-002", + // "gemini-1.5-pro-experimental", + "gemini-2.0-flash", "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp-01-21", +} + +// IsModelSupportSystemInstruction check if the model support system instruction. +// +// Because the main version of Go is 1.20, slice.Contains cannot be used +func IsModelSupportSystemInstruction(model string) bool { + for _, m := range ModelsSupportSystemInstruction { + if m == model { + return true + } + } + + return false +} diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go new file mode 100644 index 0000000..2963729 --- /dev/null +++ b/relay/adaptor/gemini/main.go @@ -0,0 +1,437 @@ +package gemini + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common/render" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + + "github.com/gin-gonic/gin" +) + +// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn + +const ( + VisionMaxImageNum = 16 +) + +var mimeTypeMap = map[string]string{ + "json_object": "application/json", + "text": "text/plain", +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func ConvertRequest(textRequest 
// ConvertRequest converts an OpenAI chat-completion request into a Gemini
// generateContent request: it relaxes all safety categories to the
// configured threshold, maps generation parameters and tools/functions,
// and massages the message list to satisfy Gemini's role rules
// (only "user"/"model" roles; system prompts become either a native
// system_instruction or a leading user message).
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
	geminiRequest := ChatRequest{
		Contents: make([]ChatContent, 0, len(textRequest.Messages)),
		// Every category is pinned to the configured (lowest allowed)
		// threshold since content filtering is handled upstream.
		SafetySettings: []ChatSafetySettings{
			{
				Category:  "HARM_CATEGORY_HARASSMENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_HATE_SPEECH",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
				Threshold: config.GeminiSafetySetting,
			},
		},
		GenerationConfig: ChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
		},
	}
	if textRequest.ResponseFormat != nil {
		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
		}
		// A JSON schema implies JSON output regardless of the declared type.
		if textRequest.ResponseFormat.JsonSchema != nil {
			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
		}
	}
	// Tools take precedence over the legacy top-level Functions field.
	if textRequest.Tools != nil {
		functions := make([]model.Function, 0, len(textRequest.Tools))
		for _, tool := range textRequest.Tools {
			functions = append(functions, tool.Function)
		}
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: functions,
			},
		}
	} else if textRequest.Functions != nil {
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}
	shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		content := ChatContent{
			Role: message.Role,
			Parts: []Part{
				{
					Text: message.StringContent(),
				},
			},
		}
		// Multi-part (text + image) content replaces the plain-text part
		// built above; images beyond VisionMaxImageNum are dropped.
		openaiContent := message.ParseContent()
		var parts []Part
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == model.ContentTypeText {
				parts = append(parts, Part{
					Text: part.Text,
				})
			} else if part.Type == model.ContentTypeImageURL {
				imageNum += 1
				if imageNum > VisionMaxImageNum {
					continue
				}
				// NOTE(review): fetch error is discarded; a failed download
				// yields an empty inline part — confirm this is intended.
				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
				parts = append(parts, Part{
					InlineData: &InlineData{
						MimeType: mimeType,
						Data:     data,
					},
				})
			}
		}
		content.Parts = parts

		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Converting system prompt to prompt from user for the same reason
		if content.Role == "system" {
			shouldAddDummyModelMessage = true
			if IsModelSupportSystemInstruction(textRequest.Model) {
				geminiRequest.SystemInstruction = &content
				geminiRequest.SystemInstruction.Role = ""
				continue
			} else {
				content.Role = "user"
			}
		}

		geminiRequest.Contents = append(geminiRequest.Contents, content)

		// If a system message is the last message, we need to add a dummy model message to make gemini happy
		if shouldAddDummyModelMessage {
			geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
				Role: "model",
				Parts: []Part{
					{
						Text: "Okay",
					},
				},
			})
			shouldAddDummyModelMessage = false
		}
	}

	return &geminiRequest
}
+ } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + +type ChatCandidate struct { + Content ChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []ChatSafetyRating `json:"safetyRatings"` +} + +type ChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type ChatPromptFeedback struct { + SafetyRatings []ChatSafetyRating `json:"safetyRatings"` +} + +func getToolCalls(candidate *ChatCandidate) []model.Tool { + var toolCalls []model.Tool + + item := candidate.Content.Parts[0] + if item.FunctionCall == nil { + return toolCalls + } + argsBytes, err := json.Marshal(item.FunctionCall.Arguments) + if err != nil { + logger.FatalLog("getToolCalls failed: " + err.Error()) + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: model.Function{ + Arguments: string(argsBytes), + Name: item.FunctionCall.FunctionName, + }, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + +func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := openai.TextResponseChoice{ + Index: i, + Message: model.Message{ + Role: "assistant", + }, + FinishReason: constant.StopFinishReason, + } + if len(candidate.Content.Parts) > 0 { + if candidate.Content.Parts[0].FunctionCall != nil { + choice.Message.ToolCalls = getToolCalls(&candidate) + } else { + var builder strings.Builder + for _, part := range candidate.Content.Parts { + if i > 0 { + builder.WriteString("\n") + } + 
builder.WriteString(part.Text) + } + choice.Message.Content = builder.String() + } + } else { + choice.Message.Content = "" + choice.FinishReason = candidate.FinishReason + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = geminiResponse.GetResponseText() + //choice.FinishReason = &constant.StopFinishReason + var response openai.ChatCompletionsStreamResponse + response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) + response.Created = helper.GetTimestamp() + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: "gemini-embedding", + Usage: model.Usage{TotalTokens: 0}, + } + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: 0, + Embedding: item.Values, + }) + } + return &openAIEmbeddingResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\"") + + var geminiResponse ChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err 
// Handler processes a non-streaming Gemini chat response: it decodes the
// upstream body, rejects candidate-less responses (e.g. blocked prompts)
// with a server-style error, converts the rest to the OpenAI format, and
// writes the JSON reply to the client, echoing the upstream status code.
// Returns the computed usage on success.
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var geminiResponse ChatResponse
	err = json.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// Gemini may return zero candidates (e.g. when the prompt is blocked);
	// surface that as an explicit error instead of an empty completion.
	if len(geminiResponse.Candidates) == 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fullTextResponse.Model = modelName
	// Completion tokens are recounted locally from the returned text since
	// this response path does not rely on upstream usage metadata.
	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
	usage := model.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}
`json:"safety_settings,omitempty"` + GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"` + Tools []ChatTools `json:"tools,omitempty"` + SystemInstruction *ChatContent `json:"system_instruction,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Content ChatContent `json:"content"` + TaskType string `json:"taskType,omitempty"` + Title string `json:"title,omitempty"` + OutputDimensionality int `json:"outputDimensionality,omitempty"` +} + +type BatchEmbeddingRequest struct { + Requests []EmbeddingRequest `json:"requests"` +} + +type EmbeddingData struct { + Values []float64 `json:"values"` +} + +type EmbeddingResponse struct { + Embeddings []EmbeddingData `json:"embeddings"` + Error *Error `json:"error,omitempty"` +} + +type Error struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Status string `json:"status,omitempty"` +} + +type InlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type FunctionCall struct { + FunctionName string `json:"name"` + Arguments any `json:"args"` +} + +type Part struct { + Text string `json:"text,omitempty"` + InlineData *InlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` +} + +type ChatContent struct { + Role string `json:"role,omitempty"` + Parts []Part `json:"parts"` +} + +type ChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type ChatTools struct { + FunctionDeclarations any `json:"function_declarations,omitempty"` +} + +type ChatGenerationConfig struct { + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int 
`json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} diff --git a/relay/adaptor/geminiv2/constants.go b/relay/adaptor/geminiv2/constants.go new file mode 100644 index 0000000..73e7ad7 --- /dev/null +++ b/relay/adaptor/geminiv2/constants.go @@ -0,0 +1,15 @@ +package geminiv2 + +// https://ai.google.dev/models/gemini + +var ModelList = []string{ + "gemini-pro", "gemini-1.0-pro", + // "gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it", + "gemini-1.5-flash", "gemini-1.5-flash-8b", + "gemini-1.5-pro", "gemini-1.5-pro-experimental", + "text-embedding-004", "aqa", + "gemini-2.0-flash", "gemini-2.0-flash-exp", + "gemini-2.0-flash-lite-preview-02-05", + "gemini-2.0-flash-thinking-exp-01-21", + "gemini-2.0-pro-exp-02-05", +} diff --git a/relay/adaptor/geminiv2/main.go b/relay/adaptor/geminiv2/main.go new file mode 100644 index 0000000..fed7f3b --- /dev/null +++ b/relay/adaptor/geminiv2/main.go @@ -0,0 +1,14 @@ +package geminiv2 + +import ( + "fmt" + "strings" + + "github.com/songquanpeng/one-api/relay/meta" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + baseURL := strings.TrimSuffix(meta.BaseURL, "/") + requestPath := strings.TrimPrefix(meta.RequestURLPath, "/v1") + return fmt.Sprintf("%s%s", baseURL, requestPath), nil +} diff --git a/relay/adaptor/groq/constants.go b/relay/adaptor/groq/constants.go new file mode 100644 index 0000000..2a26b28 --- /dev/null +++ b/relay/adaptor/groq/constants.go @@ -0,0 +1,27 @@ +package groq + +// https://console.groq.com/docs/models + +var ModelList = []string{ + "gemma2-9b-it", + "llama-3.1-70b-versatile", + "llama-3.1-8b-instant", + "llama-3.2-11b-text-preview", + "llama-3.2-11b-vision-preview", + "llama-3.2-1b-preview", + "llama-3.2-3b-preview", + "llama-3.2-90b-text-preview", + "llama-3.2-90b-vision-preview", + "llama-guard-3-8b", + "llama3-70b-8192", + "llama3-8b-8192", + "llama3-groq-70b-8192-tool-use-preview", + "llama3-groq-8b-8192-tool-use-preview", + 
"llava-v1.5-7b-4096-preview", + "mixtral-8x7b-32768", + "distil-whisper-large-v3-en", + "whisper-large-v3", + "whisper-large-v3-turbo", + "deepseek-r1-distill-llama-70b-specdec", + "deepseek-r1-distill-llama-70b", +} diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go new file mode 100644 index 0000000..01b2e2c --- /dev/null +++ b/relay/adaptor/interface.go @@ -0,0 +1,21 @@ +package adaptor + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor interface { + Init(meta *meta.Meta) + GetRequestURL(meta *meta.Meta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(request *model.ImageRequest) (any, error) + DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) + GetModelList() []string + GetChannelName() string +} diff --git a/relay/adaptor/lingyiwanwu/constants.go b/relay/adaptor/lingyiwanwu/constants.go new file mode 100644 index 0000000..30000e9 --- /dev/null +++ b/relay/adaptor/lingyiwanwu/constants.go @@ -0,0 +1,9 @@ +package lingyiwanwu + +// https://platform.lingyiwanwu.com/docs + +var ModelList = []string{ + "yi-34b-chat-0205", + "yi-34b-chat-200k", + "yi-vl-plus", +} diff --git a/relay/adaptor/minimax/constants.go b/relay/adaptor/minimax/constants.go new file mode 100644 index 0000000..65165dd --- /dev/null +++ b/relay/adaptor/minimax/constants.go @@ -0,0 +1,13 @@ +package minimax + +// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd + +var ModelList = []string{ + "abab6.5-chat", + "abab6.5s-chat", + "abab6-chat", + "abab5.5-chat", + "abab5.5s-chat", + "MiniMax-VL-01", + 
"MiniMax-Text-01", +} diff --git a/relay/adaptor/minimax/main.go b/relay/adaptor/minimax/main.go new file mode 100644 index 0000000..fc9b5d2 --- /dev/null +++ b/relay/adaptor/minimax/main.go @@ -0,0 +1,14 @@ +package minimax + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) +} diff --git a/relay/adaptor/mistral/constants.go b/relay/adaptor/mistral/constants.go new file mode 100644 index 0000000..cdb157f --- /dev/null +++ b/relay/adaptor/mistral/constants.go @@ -0,0 +1,10 @@ +package mistral + +var ModelList = []string{ + "open-mistral-7b", + "open-mixtral-8x7b", + "mistral-small-latest", + "mistral-medium-latest", + "mistral-large-latest", + "mistral-embed", +} diff --git a/relay/adaptor/moonshot/constants.go b/relay/adaptor/moonshot/constants.go new file mode 100644 index 0000000..1b86f0f --- /dev/null +++ b/relay/adaptor/moonshot/constants.go @@ -0,0 +1,7 @@ +package moonshot + +var ModelList = []string{ + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", +} diff --git a/relay/adaptor/novita/constants.go b/relay/adaptor/novita/constants.go new file mode 100644 index 0000000..c661830 --- /dev/null +++ b/relay/adaptor/novita/constants.go @@ -0,0 +1,19 @@ +package novita + +// https://novita.ai/llm-api + +var ModelList = []string{ + "meta-llama/llama-3-8b-instruct", + "meta-llama/llama-3-70b-instruct", + "nousresearch/hermes-2-pro-llama-3-8b", + "nousresearch/nous-hermes-llama2-13b", + "mistralai/mistral-7b-instruct", + "cognitivecomputations/dolphin-mixtral-8x22b", + "sao10k/l3-70b-euryale-v2.1", + "sophosympatheia/midnight-rose-70b", + "gryphe/mythomax-l2-13b", + "Nous-Hermes-2-Mixtral-8x7B-DPO", + "lzlv_70b", + 
"teknium/openhermes-2.5-mistral-7b", + "microsoft/wizardlm-2-8x22b", +} diff --git a/relay/adaptor/novita/main.go b/relay/adaptor/novita/main.go new file mode 100644 index 0000000..80efa41 --- /dev/null +++ b/relay/adaptor/novita/main.go @@ -0,0 +1,15 @@ +package novita + +import ( + "fmt" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode) +} diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go new file mode 100644 index 0000000..ad1f898 --- /dev/null +++ b/relay/adaptor/ollama/adaptor.go @@ -0,0 +1,82 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + if meta.Mode == relaymode.Embeddings { + fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL) + } + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case relaymode.Embeddings: 
// DoResponse relays the Ollama response back to the client and extracts
// token usage. Streaming takes priority over the relay mode: every streamed
// reply goes through the chat StreamHandler, while non-streaming replies
// are dispatched on meta.Mode (embeddings vs. plain chat).
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err, usage = StreamHandler(c, resp)
	} else {
		switch meta.Mode {
		case relaymode.Embeddings:
			err, usage = EmbeddingHandler(c, resp)
		default:
			err, usage = Handler(c, resp)
		}
	}
	return
}
// ConvertRequest maps an OpenAI-style chat request onto Ollama's /api/chat
// payload: sampling parameters move into the nested Options struct, and
// each message's multimodal content parts are flattened into one text
// string plus a slice of image payloads.
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
	ollamaRequest := ChatRequest{
		Model: request.Model,
		Options: &Options{
			// Ollama expects an integer seed; the OpenAI value is
			// truncated by this conversion.
			Seed:             int(request.Seed),
			Temperature:      request.Temperature,
			TopP:             request.TopP,
			FrequencyPenalty: request.FrequencyPenalty,
			PresencePenalty:  request.PresencePenalty,
			// OpenAI max_tokens maps to Ollama num_predict.
			NumPredict: request.MaxTokens,
			NumCtx:     request.NumCtx,
		},
		Stream: request.Stream,
	}
	for _, message := range request.Messages {
		openaiContent := message.ParseContent()
		var imageUrls []string
		var contentText string
		for _, part := range openaiContent {
			switch part.Type {
			case model.ContentTypeText:
				// NOTE(review): a later text part overwrites earlier ones,
				// so a message with several text parts keeps only the last —
				// confirm this is intended.
				contentText = part.Text
			case model.ContentTypeImageURL:
				// NOTE(review): fetch errors are silently discarded; a failed
				// download still appends an (empty) image entry — confirm.
				_, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
				imageUrls = append(imageUrls, data)
			}
		}
		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
			Role:    message.Role,
			Content: contentText,
			Images:  imageUrls,
		})
	}
	return &ollamaRequest
}
// StreamHandler relays Ollama's NDJSON chat stream to the client as OpenAI
// server-sent events, accumulating token usage from the final chunk that
// carries eval counts. Ollama emits one JSON object per line; the custom
// split function below tokenizes on the "}\n" object boundary instead of
// plain newlines.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage model.Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		// Cut at the first "}\n": consume both bytes but return the token
		// including the closing brace so it remains valid JSON.
		if i := strings.Index(string(data), "}\n"); i >= 0 {
			return i + 2, data[0 : i+1], nil
		}
		// At EOF, flush whatever remains as the last token.
		if atEOF {
			return len(data), data, nil
		}
		// Need more data.
		return 0, nil, nil
	})

	common.SetEventStreamHeaders(c)

	for scanner.Scan() {
		data := scanner.Text()
		// NOTE(review): this moves a stray leading '}' to the end of the
		// token; the exact malformed input it compensates for is unclear —
		// confirm against real Ollama stream output before changing.
		if strings.HasPrefix(data, "}") {
			data = strings.TrimPrefix(data, "}") + "}"
		}

		var ollamaResponse ChatResponse
		err := json.Unmarshal([]byte(data), &ollamaResponse)
		if err != nil {
			// Malformed chunks are logged and skipped rather than aborting
			// the whole stream.
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		// Only the final chunk carries non-zero eval counts; capture usage
		// when it appears.
		if ollamaResponse.EvalCount != 0 {
			usage.PromptTokens = ollamaResponse.PromptEvalCount
			usage.CompletionTokens = ollamaResponse.EvalCount
			usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
		}

		response := streamResponseOllama2OpenAI(&ollamaResponse)
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	return nil, &usage
}
// Handler relays a non-streaming Ollama chat response: it reads the whole
// body, surfaces any Ollama-reported error as an OpenAI-style error (with
// the upstream status code), otherwise converts the reply to the OpenAI
// chat-completion shape and writes it back, returning the token usage.
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	ctx := context.TODO()
	var ollamaResponse ChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	logger.Debugf(ctx, "ollama response: %s", string(responseBody))
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &ollamaResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// Ollama signals failures in-band via an "error" field rather than an
	// HTTP error body.
	if ollamaResponse.Error != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: ollamaResponse.Error,
				Type:    "ollama_error",
				Param:   "",
				Code:    "ollama_error",
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// NOTE(review): a write failure here is assigned to err but then
	// dropped — the function returns nil regardless; confirm whether it
	// should be logged.
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
// Message is a single Ollama chat turn: a role, flattened text content,
// and optional image payloads.
type Message struct {
	Role    string   `json:"role,omitempty"`
	Content string   `json:"content,omitempty"`
	Images  []string `json:"images,omitempty"`
}

// ChatRequest is the payload for Ollama's /api/chat endpoint.
type ChatRequest struct {
	Model    string    `json:"model,omitempty"`
	Messages []Message `json:"messages,omitempty"`
	// Stream intentionally has no omitempty so an explicit false is sent.
	Stream  bool     `json:"stream"`
	Options *Options `json:"options,omitempty"`
}

// ChatResponse is a chunk (streaming) or full reply (non-streaming) from
// /api/chat. Eval counts are only populated on the final chunk.
type ChatResponse struct {
	Model     string  `json:"model,omitempty"`
	CreatedAt string  `json:"created_at,omitempty"`
	Message   Message `json:"message,omitempty"`
	Response  string  `json:"response,omitempty"` // for stream response
	Done      bool    `json:"done,omitempty"`
	// NOTE(review): Ollama reports durations as large integers; int may be
	// narrow on 32-bit builds — confirm against the Ollama API docs.
	TotalDuration   int    `json:"total_duration,omitempty"`
	LoadDuration    int    `json:"load_duration,omitempty"`
	PromptEvalCount int    `json:"prompt_eval_count,omitempty"`
	EvalCount       int    `json:"eval_count,omitempty"`
	EvalDuration    int    `json:"eval_duration,omitempty"`
	Error           string `json:"error,omitempty"` // in-band error message
}

// EmbeddingRequest is the payload for Ollama's /api/embed endpoint.
type EmbeddingRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
	// Truncate bool `json:"truncate,omitempty"`
	Options *Options `json:"options,omitempty"`
	// KeepAlive string `json:"keep_alive,omitempty"`
}

// EmbeddingResponse is the /api/embed reply: one vector per input, plus an
// optional in-band error message.
type EmbeddingResponse struct {
	Error      string      `json:"error,omitempty"`
	Model      string      `json:"model"`
	Embeddings [][]float64 `json:"embeddings"`
}
// Adaptor implements the relay adaptor contract for OpenAI and the family
// of OpenAI-compatible channels (Azure, Minimax, Doubao, Novita, ...). The
// concrete channel is selected by ChannelType.
type Adaptor struct {
	ChannelType int
}

// Init captures the channel type from the request metadata.
func (a *Adaptor) Init(meta *meta.Meta) {
	a.ChannelType = meta.ChannelType
}

// GetRequestURL builds the upstream URL. Azure needs its deployment-style
// path rewritten; several compatible channels delegate to their own URL
// builders; everything else uses the OpenAI-style path as-is.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	switch meta.ChannelType {
	case channeltype.Azure:
		if meta.Mode == relaymode.ImagesGenerations {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
			return fullRequestURL, nil
		}

		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
		// Strip the query string, append Azure's api-version, and map the
		// /v1/<task> path onto /openai/deployments/<model>/<task>.
		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
		task := strings.TrimPrefix(requestURL, "/v1/")
		model_ := meta.ActualModelName
		// Azure deployment names cannot contain dots (e.g. "gpt-3.5" ->
		// "gpt-35").
		model_ = strings.Replace(model_, ".", "", -1)
		//https://github.com/songquanpeng/one-api/issues/1191
		// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
		return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
	case channeltype.Minimax:
		return minimax.GetRequestURL(meta)
	case channeltype.Doubao:
		return doubao.GetRequestURL(meta)
	case channeltype.Novita:
		return novita.GetRequestURL(meta)
	case channeltype.BaiduV2:
		return baiduv2.GetRequestURL(meta)
	case channeltype.AliBailian:
		return alibailian.GetRequestURL(meta)
	case channeltype.GeminiOpenAICompatible:
		return geminiv2.GetRequestURL(meta)
	default:
		return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
	}
}

// SetupRequestHeader sets authentication: Azure uses an "api-key" header,
// everyone else a Bearer token. OpenRouter additionally wants attribution
// headers.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	if meta.ChannelType == channeltype.Azure {
		req.Header.Set("api-key", meta.APIKey)
		return nil
	}
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	if meta.ChannelType == channeltype.OpenRouter {
		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
		req.Header.Set("X-Title", "One API")
	}
	return nil
}

// ConvertRequest passes the OpenAI-format request through unchanged, except
// that streamed requests are forced to report usage in the final chunk.
// Note this mutates the caller's request object.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	if request.Stream {
		// always return usage in stream mode
		if request.StreamOptions == nil {
			request.StreamOptions = &model.StreamOptions{}
		}
		request.StreamOptions.IncludeUsage = true
	}
	return request, nil
}

// ConvertImageRequest passes image requests through unchanged.
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

// DoRequest delegates to the shared request helper (which uses the URL and
// headers prepared above).
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

// DoResponse relays the upstream response and reconstructs usage when the
// channel did not report it: missing totals are recomputed from the
// response text, and missing prompt tokens are backfilled from metadata.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
		if usage == nil || usage.TotalTokens == 0 {
			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
		}
		if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens
			usage.PromptTokens = meta.PromptTokens
			usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
		}
	} else {
		switch meta.Mode {
		case relaymode.ImagesGenerations:
			// Image generation reports no token usage.
			err, _ = ImageHandler(c, resp)
		default:
			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
		}
	}
	return
}

// GetModelList returns the model list for the configured compatible channel.
func (a *Adaptor) GetModelList() []string {
	_, modelList := GetCompatibleChannelMeta(a.ChannelType)
	return modelList
}

// GetChannelName returns the canonical name for the configured channel.
func (a *Adaptor) GetChannelName() string {
	channelName, _ := GetCompatibleChannelMeta(a.ChannelType)
	return channelName
}
// CompatibleChannels lists the channel types that are routed through the
// OpenAI-compatible adaptor.
// NOTE(review): OpenRouter, AliBailian and GeminiOpenAICompatible are
// handled by GetCompatibleChannelMeta below but are absent from this list —
// confirm whether that omission is intentional.
var CompatibleChannels = []int{
	channeltype.Azure,
	channeltype.AI360,
	channeltype.Moonshot,
	channeltype.Baichuan,
	channeltype.Minimax,
	channeltype.Doubao,
	channeltype.Mistral,
	channeltype.Groq,
	channeltype.LingYiWanWu,
	channeltype.StepFun,
	channeltype.DeepSeek,
	channeltype.TogetherAI,
	channeltype.Novita,
	channeltype.SiliconFlow,
	channeltype.XAI,
	channeltype.BaiduV2,
	channeltype.XunfeiV2,
}

// GetCompatibleChannelMeta maps a channel type to its canonical name and
// supported model list. Unknown channel types fall back to plain OpenAI.
func GetCompatibleChannelMeta(channelType int) (string, []string) {
	switch channelType {
	case channeltype.Azure:
		return "azure", ModelList
	case channeltype.AI360:
		return "360", ai360.ModelList
	case channeltype.Moonshot:
		return "moonshot", moonshot.ModelList
	case channeltype.Baichuan:
		return "baichuan", baichuan.ModelList
	case channeltype.Minimax:
		return "minimax", minimax.ModelList
	case channeltype.Mistral:
		return "mistralai", mistral.ModelList
	case channeltype.Groq:
		return "groq", groq.ModelList
	case channeltype.LingYiWanWu:
		return "lingyiwanwu", lingyiwanwu.ModelList
	case channeltype.StepFun:
		return "stepfun", stepfun.ModelList
	case channeltype.DeepSeek:
		return "deepseek", deepseek.ModelList
	case channeltype.TogetherAI:
		return "together.ai", togetherai.ModelList
	case channeltype.Doubao:
		return "doubao", doubao.ModelList
	case channeltype.Novita:
		return "novita", novita.ModelList
	case channeltype.SiliconFlow:
		return "siliconflow", siliconflow.ModelList
	case channeltype.XAI:
		return "xai", xai.ModelList
	case channeltype.BaiduV2:
		return "baiduv2", baiduv2.ModelList
	case channeltype.XunfeiV2:
		return "xunfeiv2", xunfeiv2.ModelList
	case channeltype.OpenRouter:
		return "openrouter", openrouter.ModelList
	case channeltype.AliBailian:
		return "alibailian", alibailian.ModelList
	case channeltype.GeminiOpenAICompatible:
		return "geminiv2", geminiv2.ModelList
	default:
		return "openai", ModelList
	}
}
CountTokenText(responseText, modelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + if channelType == channeltype.OpenAICompatible { + return fmt.Sprintf("%s%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(requestURL, "/v1")) + } + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case channeltype.OpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case channeltype.Azure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go new file mode 100644 index 0000000..0f89618 --- /dev/null +++ b/relay/adaptor/openai/image.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + 
// SSE framing constants for OpenAI-style streams.
const (
	dataPrefix       = "data: "
	done             = "[DONE]"
	dataPrefixLength = len(dataPrefix)
)

// StreamHandler relays an OpenAI-format SSE stream to the client while
// accumulating the concatenated response text and any usage object the
// upstream reports. It returns (error, responseText, usage); usage may be
// nil when the channel never sends one, in which case the caller derives
// usage from responseText.
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
	responseText := ""
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)
	var usage *model.Usage

	common.SetEventStreamHeaders(c)

	doneRendered := false
	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < dataPrefixLength { // ignore blank line or wrong format
			continue
		}
		// Skip lines that are neither "data: ..." nor the DONE sentinel.
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
			continue
		}
		// Forward the terminal "[DONE]" marker verbatim and remember we did.
		if strings.HasPrefix(data[dataPrefixLength:], done) {
			render.StringData(c, data)
			doneRendered = true
			continue
		}
		switch relayMode {
		case relaymode.ChatCompletions:
			var streamResponse ChatCompletionsStreamResponse
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				render.StringData(c, data) // if error happened, pass the data to client
				continue                   // just ignore the error
			}
			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
				// but for empty choice and no usage, we should not pass it to client, this is for azure
				continue // just ignore empty choice
			}
			render.StringData(c, data)
			for _, choice := range streamResponse.Choices {
				responseText += conv.AsString(choice.Delta.Content)
			}
			if streamResponse.Usage != nil {
				usage = streamResponse.Usage
			}
		case relaymode.Completions:
			// Completions chunks are forwarded before parsing, unlike chat.
			render.StringData(c, data)
			var streamResponse CompletionsStreamResponse
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				continue
			}
			for _, choice := range streamResponse.Choices {
				responseText += choice.Text
			}
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	// Guarantee the client always receives a DONE event.
	if !doneRendered {
		render.Done(c)
	}

	err := resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
	}

	return nil, responseText, usage
}
io.NopCloser(bytes.NewBuffer(responseBody)) + + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the HTTPClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { + completionTokens := 0 + for _, choice := range textResponse.Choices { + completionTokens += CountTokenText(choice.Message.StringContent(), modelName) + } + textResponse.Usage = model.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + } + return nil, &textResponse.Usage +} diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go new file mode 100644 index 0000000..4c974de --- /dev/null +++ b/relay/adaptor/openai/model.go @@ -0,0 +1,145 @@ +package openai + +import "github.com/songquanpeng/one-api/relay/model" + +type TextContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` +} + +type ImageContent struct { + Type string `json:"type,omitempty"` + ImageURL *model.ImageURL `json:"image_url,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Messages []model.Message `json:"messages"` + MaxTokens int `json:"max_tokens"` +} + +type TextRequest struct { + Model string `json:"model"` + Messages 
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt" binding:"required"`
	N              int    `json:"n,omitempty"`
	Size           string `json:"size,omitempty"`
	Quality        string `json:"quality,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Style          string `json:"style,omitempty"`
	User           string `json:"user,omitempty"`
}

// WhisperJSONResponse is the plain `json` transcription response format.
type WhisperJSONResponse struct {
	Text string `json:"text,omitempty"`
}

// WhisperVerboseJSONResponse is the `verbose_json` transcription format,
// including per-segment timing details.
type WhisperVerboseJSONResponse struct {
	Task     string    `json:"task,omitempty"`
	Language string    `json:"language,omitempty"`
	Duration float64   `json:"duration,omitempty"`
	Text     string    `json:"text,omitempty"`
	Segments []Segment `json:"segments,omitempty"`
}

// Segment is one timed span of a verbose transcription.
type Segment struct {
	Id               int     `json:"id"`
	Seek             int     `json:"seek"`
	Start            float64 `json:"start"`
	End              float64 `json:"end"`
	Text             string  `json:"text"`
	Tokens           []int   `json:"tokens"`
	Temperature      float64 `json:"temperature"`
	AvgLogprob       float64 `json:"avg_logprob"`
	CompressionRatio float64 `json:"compression_ratio"`
	NoSpeechProb     float64 `json:"no_speech_prob"`
}

// TextToSpeechRequest is the /v1/audio/speech request body.
type TextToSpeechRequest struct {
	Model          string  `json:"model" binding:"required"`
	Input          string  `json:"input" binding:"required"`
	Voice          string  `json:"voice" binding:"required"`
	Speed          float64 `json:"speed"`
	ResponseFormat string  `json:"response_format"`
}

// UsageOrResponseText carries either reported usage or raw response text
// from which usage can be derived.
type UsageOrResponseText struct {
	*model.Usage
	ResponseText string
}

// SlimTextResponse is the minimal shape parsed from an upstream completion
// response: choices, usage, and an optional in-band error.
type SlimTextResponse struct {
	Choices     []TextResponseChoice `json:"choices"`
	model.Usage `json:"usage"`
	Error       model.Error `json:"error"`
}

// TextResponseChoice is one completion choice (message embedded inline).
type TextResponseChoice struct {
	Index         int `json:"index"`
	model.Message `json:"message"`
	FinishReason  string `json:"finish_reason"`
}

// TextResponse is a full OpenAI chat-completion response.
type TextResponse struct {
	Id          string               `json:"id"`
	Model       string               `json:"model,omitempty"`
	Object      string               `json:"object"`
	Created     int64                `json:"created"`
	Choices     []TextResponseChoice `json:"choices"`
	model.Usage `json:"usage"`
}

// EmbeddingResponseItem is one embedding vector in an embeddings response.
type EmbeddingResponseItem struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

// EmbeddingResponse is the OpenAI embeddings list response.
type EmbeddingResponse struct {
	Object      string                  `json:"object"`
	Data        []EmbeddingResponseItem `json:"data"`
	Model       string                  `json:"model"`
	model.Usage `json:"usage"`
}

// ImageData is one generated image: either a URL or inline base64 payload.
type ImageData struct {
	Url           string `json:"url,omitempty"`
	B64Json       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// ImageResponse is the image-generation response envelope.
type ImageResponse struct {
	Created int64       `json:"created"`
	Data    []ImageData `json:"data"`
	//model.Usage `json:"usage"`
}

// ChatCompletionsStreamResponseChoice is one delta choice in a streamed
// chat-completion chunk.
type ChatCompletionsStreamResponseChoice struct {
	Index        int           `json:"index"`
	Delta        model.Message `json:"delta"`
	FinishReason *string       `json:"finish_reason,omitempty"`
}

// ChatCompletionsStreamResponse is one SSE chunk of a streamed chat
// completion; Usage is only present on the final chunk when requested.
type ChatCompletionsStreamResponse struct {
	Id      string                                `json:"id"`
	Object  string                                `json:"object"`
	Created int64                                 `json:"created"`
	Model   string                                `json:"model"`
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
	Usage   *model.Usage                          `json:"usage,omitempty"`
}

// CompletionsStreamResponse is one SSE chunk of a legacy text-completion
// stream.
type CompletionsStreamResponse struct {
	Choices []struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
won't grow after initialization
+var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+// defaultTokenEncoder is the fallback (gpt-3.5-turbo) encoder used when a
+// model has no dedicated entry; it is assigned once in InitTokenEncoders.
+var defaultTokenEncoder *tiktoken.Tiktoken
+
+// InitTokenEncoders eagerly loads the tiktoken encoders for the gpt-3.5 /
+// gpt-4o / gpt-4 families and maps every model in billingratio.ModelRatio to
+// the matching encoder (nil means "resolve lazily in getTokenEncoder").
+// Load failures are fatal: token counting cannot work without the encoders.
+func InitTokenEncoders() {
+	logger.SysLog("initializing token encoders")
+	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if err != nil {
+		logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s, "+
+			"if you are using in offline environment, please set TIKTOKEN_CACHE_DIR to use exsited files, check this link for more information: https://stackoverflow.com/questions/76106366/how-to-use-tiktoken-in-offline-mode-computer ", err.Error()))
+	}
+	defaultTokenEncoder = gpt35TokenEncoder
+	gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
+	if err != nil {
+		logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
+	}
+	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
+	if err != nil {
+		logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+	}
+	for model := range billingratio.ModelRatio {
+		if strings.HasPrefix(model, "gpt-3.5") {
+			tokenEncoderMap[model] = gpt35TokenEncoder
+		} else if strings.HasPrefix(model, "gpt-4o") {
+			// gpt-4o must be matched before the broader gpt-4 prefix below.
+			tokenEncoderMap[model] = gpt4oTokenEncoder
+		} else if strings.HasPrefix(model, "gpt-4") {
+			tokenEncoderMap[model] = gpt4TokenEncoder
+		} else {
+			tokenEncoderMap[model] = nil
+		}
+	}
+	logger.SysLog("token encoders initialized")
+}
+
+// getTokenEncoder returns the encoder for model. Known models mapped to nil
+// at init time are resolved on first use and cached; unknown models fall
+// back to the gpt-3.5-turbo encoder.
+// NOTE(review): the lazy cache write below is an unsynchronized map write;
+// despite the "won't grow" comment above, concurrent first requests for the
+// same uninitialized model could race here — confirm callers are serialized.
+func getTokenEncoder(model string) *tiktoken.Tiktoken {
+	tokenEncoder, ok := tokenEncoderMap[model]
+	if ok && tokenEncoder != nil {
+		return tokenEncoder
+	}
+	if ok {
+		// Known model whose encoder was not preloaded: resolve it now and
+		// cache the result (falling back to the default on failure).
+		tokenEncoder, err := tiktoken.EncodingForModel(model)
+		if err != nil {
+			logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+			tokenEncoder = defaultTokenEncoder
+		}
+		tokenEncoderMap[model] = tokenEncoder
+		return tokenEncoder
+	}
+	return defaultTokenEncoder
+}
+
+// getTokenNum returns the token count of text: a character-based estimate
+// when config.ApproximateTokenEnabled is set, otherwise the exact tiktoken
+// encoding length.
+func getTokenNum(tokenEncoder *tiktoken.Tiktoken,
text string) int {
+	if config.ApproximateTokenEnabled {
+		// Approximate mode: ~0.38 tokens per byte, skipping the encoder.
+		return int(float64(len(text)) * 0.38)
+	}
+	return len(tokenEncoder.Encode(text, nil, nil))
+}
+
+// CountTokenMessages counts the prompt tokens of a chat request following
+// the official cookbook recipe: a fixed per-message overhead plus the
+// encoded role/name/content of each message (including vision image parts),
+// plus the 3-token reply priming.
+func CountTokenMessages(messages []model.Message, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if model == "gpt-3.5-turbo-0301" {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		switch v := message.Content.(type) {
+		case string:
+			tokenNum += getTokenNum(tokenEncoder, v)
+		case []any:
+			// Multi-part content (vision requests): text parts go through the
+			// encoder, image parts through the vision pricing rules.
+			for _, it := range v {
+				m := it.(map[string]any)
+				switch m["type"] {
+				case "text":
+					if textValue, ok := m["text"]; ok {
+						if textString, ok := textValue.(string); ok {
+							tokenNum += getTokenNum(tokenEncoder, textString)
+						}
+					}
+				case "image_url":
+					imageUrl, ok := m["image_url"].(map[string]any)
+					if ok {
+						url := imageUrl["url"].(string)
+						detail := ""
+						if imageUrl["detail"] != nil {
+							detail = imageUrl["detail"].(string)
+						}
+						imageTokens, err := countImageTokens(url, detail, model)
+						if err != nil {
+							// Best effort: log and keep counting the rest.
+							logger.SysError("error counting image tokens: " + err.Error())
+						} else {
+							tokenNum += imageTokens
+						}
+					}
+				}
+			}
+		}
+		tokenNum += getTokenNum(tokenEncoder, message.Role)
+		if message.Name != nil {
+			tokenNum += tokensPerName
+			tokenNum += getTokenNum(tokenEncoder, *message.Name)
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum
+}
+
+// Per-image token costs used by countImageTokens below; see
+// https://platform.openai.com/docs/guides/vision/calculating-costs
+const (
+	lowDetailCost         = 85
+	highDetailCostPerTile = 170
+	additionalCost        = 85
+	// gpt-4o-mini cost higher than other model
+	
gpt4oMiniLowDetailCost  = 2833
+	gpt4oMiniHighDetailCost = 5667
+	gpt4oMiniAdditionalCost = 2833
+)
+
+// countImageTokens estimates the token cost of one image part given its URL,
+// the requested detail level ("", "auto", "low" or "high") and the model.
+// https://platform.openai.com/docs/guides/vision/calculating-costs
+// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
+func countImageTokens(url string, detail string, model string) (_ int, err error) {
+	var fetchSize = true
+	var width, height int
+	// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
+	// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
+	// According to the official guide, "low" disable the high-res model,
+	// and only receive low-res 512px x 512px version of the image, indicating
+	// that image is treated as low-res when size is smaller than 512px x 512px,
+	// then we can assume that image size larger than 512px x 512px is treated
+	// as high-res. Then we have the following logic:
+	// if detail == "" || detail == "auto" {
+	// 	width, height, err = image.GetImageSize(url)
+	// 	if err != nil {
+	// 		return 0, err
+	// 	}
+	// 	fetchSize = false
+	// 	// not sure if this is correct
+	// 	if width > 512 || height > 512 {
+	// 		detail = "high"
+	// 	} else {
+	// 		detail = "low"
+	// 	}
+	// }
+
+	// However, in my test, it seems to be always the same as "high".
+	// The following image, which is 125x50, is still treated as high-res, taken
+	// 255 tokens in the response of non-stream chat completion api.
+	// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
+	if detail == "" || detail == "auto" {
+		// assume by test, not sure if this is correct
+		detail = "high"
+	}
+	switch detail {
+	case "low":
+		// Low detail is a flat per-image cost, independent of image size.
+		if strings.HasPrefix(model, "gpt-4o-mini") {
+			return gpt4oMiniLowDetailCost, nil
+		}
+		return lowDetailCost, nil
+	case "high":
+		if fetchSize {
+			width, height, err = image.GetImageSize(url)
+			if err != nil {
+				return 0, err
+			}
+		}
+		// Scale the image into a 2048x2048 box, then shrink the short side
+		// to 768px, and charge per 512px tile plus a fixed base cost.
+		if width > 2048 || height > 2048 { // max(width, height) > 2048
+			ratio := float64(2048) / math.Max(float64(width), float64(height))
+			width = int(float64(width) * ratio)
+			height = int(float64(height) * ratio)
+		}
+		if width > 768 && height > 768 { // min(width, height) > 768
+			ratio := float64(768) / math.Min(float64(width), float64(height))
+			width = int(float64(width) * ratio)
+			height = int(float64(height) * ratio)
+		}
+		numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
+		if strings.HasPrefix(model, "gpt-4o-mini") {
+			return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
+		}
+		result := numSquares*highDetailCostPerTile + additionalCost
+		return result, nil
+	default:
+		return 0, errors.New("invalid detail option")
+	}
+}
+
+// CountTokenInput counts tokens for a plain string or a slice of strings
+// (concatenated before counting); unsupported input types count as 0.
+func CountTokenInput(input any, model string) int {
+	switch v := input.(type) {
+	case string:
+		return CountTokenText(v, model)
+	case []string:
+		text := ""
+		for _, s := range v {
+			text += s
+		}
+		return CountTokenText(text, model)
+	}
+	return 0
+}
+
+// CountTokenText counts the tokens of text using the encoder for model.
+func CountTokenText(text string, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	return getTokenNum(tokenEncoder, text)
+}
+
+// CountToken counts tokens using the default gpt-3.5-turbo encoder.
+func CountToken(text string) int {
+	return CountTokenInput(text, "gpt-3.5-turbo")
+}
diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go
new file mode 100644
index 0000000..83beadb
--- /dev/null
+++ b/relay/adaptor/openai/util.go
@@ -0,0 +1,23 @@
+package openai
+
+import (
+	"context"
+	"fmt"
+
+	
"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+// ErrorWrapper logs err tagged with code and wraps it into a relay-model
+// error carrying the given HTTP status code, for returning to the client.
+func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
+
+	Error := model.Error{
+		Message: err.Error(),
+		Type:    "one_api_error",
+		Code:    code,
+	}
+	return &model.ErrorWithStatusCode{
+		Error:      Error,
+		StatusCode: statusCode,
+	}
+}
diff --git a/relay/adaptor/openrouter/constants.go b/relay/adaptor/openrouter/constants.go
new file mode 100644
index 0000000..b12fa65
--- /dev/null
+++ b/relay/adaptor/openrouter/constants.go
@@ -0,0 +1,235 @@
+package openrouter
+
+// ModelList is the set of model identifiers advertised for the OpenRouter
+// channel (list continues on the following lines).
+var ModelList = []string{
+	"01-ai/yi-large",
+	"aetherwiing/mn-starcannon-12b",
+	"ai21/jamba-1-5-large",
+	"ai21/jamba-1-5-mini",
+	"ai21/jamba-instruct",
+	"aion-labs/aion-1.0",
+	"aion-labs/aion-1.0-mini",
+	"aion-labs/aion-rp-llama-3.1-8b",
+	"allenai/llama-3.1-tulu-3-405b",
+	"alpindale/goliath-120b",
+	"alpindale/magnum-72b",
+	"amazon/nova-lite-v1",
+	"amazon/nova-micro-v1",
+	"amazon/nova-pro-v1",
+	"anthracite-org/magnum-v2-72b",
+	"anthracite-org/magnum-v4-72b",
+	"anthropic/claude-2",
+	"anthropic/claude-2.0",
+	"anthropic/claude-2.0:beta",
+	"anthropic/claude-2.1",
+	"anthropic/claude-2.1:beta",
+	"anthropic/claude-2:beta",
+	"anthropic/claude-3-haiku",
+	"anthropic/claude-3-haiku:beta",
+	"anthropic/claude-3-opus",
+	"anthropic/claude-3-opus:beta",
+	"anthropic/claude-3-sonnet",
+	"anthropic/claude-3-sonnet:beta",
+	"anthropic/claude-3.5-haiku",
+	"anthropic/claude-3.5-haiku-20241022",
+	"anthropic/claude-3.5-haiku-20241022:beta",
+	"anthropic/claude-3.5-haiku:beta",
+	"anthropic/claude-3.5-sonnet",
+	"anthropic/claude-3.5-sonnet-20240620",
+	"anthropic/claude-3.5-sonnet-20240620:beta",
+	"anthropic/claude-3.5-sonnet:beta",
+	"cognitivecomputations/dolphin-mixtral-8x22b",
+	"cognitivecomputations/dolphin-mixtral-8x7b",
+	"cohere/command",
+	"cohere/command-r",
+	"cohere/command-r-03-2024",
+	
"cohere/command-r-08-2024", + "cohere/command-r-plus", + "cohere/command-r-plus-04-2024", + "cohere/command-r-plus-08-2024", + "cohere/command-r7b-12-2024", + "databricks/dbrx-instruct", + "deepseek/deepseek-chat", + "deepseek/deepseek-chat-v2.5", + "deepseek/deepseek-chat:free", + "deepseek/deepseek-r1", + "deepseek/deepseek-r1-distill-llama-70b", + "deepseek/deepseek-r1-distill-llama-70b:free", + "deepseek/deepseek-r1-distill-llama-8b", + "deepseek/deepseek-r1-distill-qwen-1.5b", + "deepseek/deepseek-r1-distill-qwen-14b", + "deepseek/deepseek-r1-distill-qwen-32b", + "deepseek/deepseek-r1:free", + "eva-unit-01/eva-llama-3.33-70b", + "eva-unit-01/eva-qwen-2.5-32b", + "eva-unit-01/eva-qwen-2.5-72b", + "google/gemini-2.0-flash-001", + "google/gemini-2.0-flash-exp:free", + "google/gemini-2.0-flash-lite-preview-02-05:free", + "google/gemini-2.0-flash-thinking-exp-1219:free", + "google/gemini-2.0-flash-thinking-exp:free", + "google/gemini-2.0-pro-exp-02-05:free", + "google/gemini-exp-1206:free", + "google/gemini-flash-1.5", + "google/gemini-flash-1.5-8b", + "google/gemini-flash-1.5-8b-exp", + "google/gemini-pro", + "google/gemini-pro-1.5", + "google/gemini-pro-vision", + "google/gemma-2-27b-it", + "google/gemma-2-9b-it", + "google/gemma-2-9b-it:free", + "google/gemma-7b-it", + "google/learnlm-1.5-pro-experimental:free", + "google/palm-2-chat-bison", + "google/palm-2-chat-bison-32k", + "google/palm-2-codechat-bison", + "google/palm-2-codechat-bison-32k", + "gryphe/mythomax-l2-13b", + "gryphe/mythomax-l2-13b:free", + "huggingfaceh4/zephyr-7b-beta:free", + "infermatic/mn-inferor-12b", + "inflection/inflection-3-pi", + "inflection/inflection-3-productivity", + "jondurbin/airoboros-l2-70b", + "liquid/lfm-3b", + "liquid/lfm-40b", + "liquid/lfm-7b", + "mancer/weaver", + "meta-llama/llama-2-13b-chat", + "meta-llama/llama-2-70b-chat", + "meta-llama/llama-3-70b-instruct", + "meta-llama/llama-3-8b-instruct", + "meta-llama/llama-3-8b-instruct:free", + "meta-llama/llama-3.1-405b", + 
"meta-llama/llama-3.1-405b-instruct", + "meta-llama/llama-3.1-70b-instruct", + "meta-llama/llama-3.1-8b-instruct", + "meta-llama/llama-3.2-11b-vision-instruct", + "meta-llama/llama-3.2-11b-vision-instruct:free", + "meta-llama/llama-3.2-1b-instruct", + "meta-llama/llama-3.2-3b-instruct", + "meta-llama/llama-3.2-90b-vision-instruct", + "meta-llama/llama-3.3-70b-instruct", + "meta-llama/llama-3.3-70b-instruct:free", + "meta-llama/llama-guard-2-8b", + "microsoft/phi-3-medium-128k-instruct", + "microsoft/phi-3-medium-128k-instruct:free", + "microsoft/phi-3-mini-128k-instruct", + "microsoft/phi-3-mini-128k-instruct:free", + "microsoft/phi-3.5-mini-128k-instruct", + "microsoft/phi-4", + "microsoft/wizardlm-2-7b", + "microsoft/wizardlm-2-8x22b", + "minimax/minimax-01", + "mistralai/codestral-2501", + "mistralai/codestral-mamba", + "mistralai/ministral-3b", + "mistralai/ministral-8b", + "mistralai/mistral-7b-instruct", + "mistralai/mistral-7b-instruct-v0.1", + "mistralai/mistral-7b-instruct-v0.3", + "mistralai/mistral-7b-instruct:free", + "mistralai/mistral-large", + "mistralai/mistral-large-2407", + "mistralai/mistral-large-2411", + "mistralai/mistral-medium", + "mistralai/mistral-nemo", + "mistralai/mistral-nemo:free", + "mistralai/mistral-small", + "mistralai/mistral-small-24b-instruct-2501", + "mistralai/mistral-small-24b-instruct-2501:free", + "mistralai/mistral-tiny", + "mistralai/mixtral-8x22b-instruct", + "mistralai/mixtral-8x7b", + "mistralai/mixtral-8x7b-instruct", + "mistralai/pixtral-12b", + "mistralai/pixtral-large-2411", + "neversleep/llama-3-lumimaid-70b", + "neversleep/llama-3-lumimaid-8b", + "neversleep/llama-3-lumimaid-8b:extended", + "neversleep/llama-3.1-lumimaid-70b", + "neversleep/llama-3.1-lumimaid-8b", + "neversleep/noromaid-20b", + "nothingiisreal/mn-celeste-12b", + "nousresearch/hermes-2-pro-llama-3-8b", + "nousresearch/hermes-3-llama-3.1-405b", + "nousresearch/hermes-3-llama-3.1-70b", + "nousresearch/nous-hermes-2-mixtral-8x7b-dpo", + 
"nousresearch/nous-hermes-llama2-13b", + "nvidia/llama-3.1-nemotron-70b-instruct", + "nvidia/llama-3.1-nemotron-70b-instruct:free", + "openai/chatgpt-4o-latest", + "openai/gpt-3.5-turbo", + "openai/gpt-3.5-turbo-0125", + "openai/gpt-3.5-turbo-0613", + "openai/gpt-3.5-turbo-1106", + "openai/gpt-3.5-turbo-16k", + "openai/gpt-3.5-turbo-instruct", + "openai/gpt-4", + "openai/gpt-4-0314", + "openai/gpt-4-1106-preview", + "openai/gpt-4-32k", + "openai/gpt-4-32k-0314", + "openai/gpt-4-turbo", + "openai/gpt-4-turbo-preview", + "openai/gpt-4o", + "openai/gpt-4o-2024-05-13", + "openai/gpt-4o-2024-08-06", + "openai/gpt-4o-2024-11-20", + "openai/gpt-4o-mini", + "openai/gpt-4o-mini-2024-07-18", + "openai/gpt-4o:extended", + "openai/o1", + "openai/o1-mini", + "openai/o1-mini-2024-09-12", + "openai/o1-preview", + "openai/o1-preview-2024-09-12", + "openai/o3-mini", + "openai/o3-mini-high", + "openchat/openchat-7b", + "openchat/openchat-7b:free", + "openrouter/auto", + "perplexity/llama-3.1-sonar-huge-128k-online", + "perplexity/llama-3.1-sonar-large-128k-chat", + "perplexity/llama-3.1-sonar-large-128k-online", + "perplexity/llama-3.1-sonar-small-128k-chat", + "perplexity/llama-3.1-sonar-small-128k-online", + "perplexity/sonar", + "perplexity/sonar-reasoning", + "pygmalionai/mythalion-13b", + "qwen/qvq-72b-preview", + "qwen/qwen-2-72b-instruct", + "qwen/qwen-2-7b-instruct", + "qwen/qwen-2-7b-instruct:free", + "qwen/qwen-2-vl-72b-instruct", + "qwen/qwen-2-vl-7b-instruct", + "qwen/qwen-2.5-72b-instruct", + "qwen/qwen-2.5-7b-instruct", + "qwen/qwen-2.5-coder-32b-instruct", + "qwen/qwen-max", + "qwen/qwen-plus", + "qwen/qwen-turbo", + "qwen/qwen-vl-plus:free", + "qwen/qwen2.5-vl-72b-instruct:free", + "qwen/qwq-32b-preview", + "raifle/sorcererlm-8x22b", + "sao10k/fimbulvetr-11b-v2", + "sao10k/l3-euryale-70b", + "sao10k/l3-lunaris-8b", + "sao10k/l3.1-70b-hanami-x1", + "sao10k/l3.1-euryale-70b", + "sao10k/l3.3-euryale-70b", + "sophosympatheia/midnight-rose-70b", + 
"sophosympatheia/rogue-rose-103b-v0.2:free", + "teknium/openhermes-2.5-mistral-7b", + "thedrummer/rocinante-12b", + "thedrummer/unslopnemo-12b", + "undi95/remm-slerp-l2-13b", + "undi95/toppy-m-7b", + "undi95/toppy-m-7b:free", + "x-ai/grok-2-1212", + "x-ai/grok-2-vision-1212", + "x-ai/grok-beta", + "x-ai/grok-vision-beta", + "xwin-lm/xwin-lm-70b", +} diff --git a/relay/adaptor/palm/adaptor.go b/relay/adaptor/palm/adaptor.go new file mode 100644 index 0000000..98aa3e1 --- /dev/null +++ b/relay/adaptor/palm/adaptor.go @@ -0,0 +1,67 @@ +package palm + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-goog-api-key", meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err 
*model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "google palm" +} diff --git a/relay/adaptor/palm/constants.go b/relay/adaptor/palm/constants.go new file mode 100644 index 0000000..a834936 --- /dev/null +++ b/relay/adaptor/palm/constants.go @@ -0,0 +1,5 @@ +package palm + +var ModelList = []string{ + "PaLM-2", +} diff --git a/relay/adaptor/palm/model.go b/relay/adaptor/palm/model.go new file mode 100644 index 0000000..2bdd8f2 --- /dev/null +++ b/relay/adaptor/palm/model.go @@ -0,0 +1,40 @@ +package palm + +import ( + "github.com/songquanpeng/one-api/relay/model" +) + +type ChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type Filter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type Prompt struct { + Messages []ChatMessage `json:"messages"` +} + +type ChatRequest struct { + Prompt Prompt `json:"prompt"` + Temperature *float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type ChatResponse struct { + Candidates []ChatMessage `json:"candidates"` + Messages []model.Message `json:"messages"` + Filters []Filter `json:"filters"` + Error Error `json:"error"` +} diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go new file mode 100644 index 0000000..d31784e --- /dev/null +++ b/relay/adaptor/palm/palm.go @@ -0,0 +1,172 @@ +package palm + +import ( + "encoding/json" 
+	"fmt"
+	"github.com/songquanpeng/one-api/common/render"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+
+// ConvertRequest converts an OpenAI-style chat request into the PaLM
+// generateMessage request shape. User messages get author "0", all other
+// roles get author "1".
+// NOTE(review): TopK is populated from MaxTokens below; the generateMessage
+// endpoint has no max-token field, but mapping it onto TopK looks
+// accidental — confirm against the PaLM request schema before changing.
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+	palmRequest := ChatRequest{
+		Prompt: Prompt{
+			Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
+		},
+		Temperature:    textRequest.Temperature,
+		CandidateCount: textRequest.N,
+		TopP:           textRequest.TopP,
+		TopK:           textRequest.MaxTokens,
+	}
+	for _, message := range textRequest.Messages {
+		palmMessage := ChatMessage{
+			Content: message.StringContent(),
+		}
+		if message.Role == "user" {
+			palmMessage.Author = "0"
+		} else {
+			palmMessage.Author = "1"
+		}
+		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
+	}
+	return &palmRequest
+}
+
+// responsePaLM2OpenAI maps each PaLM candidate onto an OpenAI-style
+// assistant choice with finish reason "stop".
+func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
+	fullTextResponse := openai.TextResponse{
+		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
+	}
+	for i, candidate := range response.Candidates {
+		choice := openai.TextResponseChoice{
+			Index: i,
+			Message: model.Message{
+				Role:    "assistant",
+				Content: candidate.Content,
+			},
+			FinishReason: "stop",
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+	return &fullTextResponse
+}
+
+// streamResponsePaLM2OpenAI wraps the first PaLM candidate (if any) into a
+// single OpenAI chat.completion.chunk with finish reason set.
+func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+	var choice openai.ChatCompletionsStreamResponseChoice
+	if len(palmResponse.Candidates) > 0 {
+		choice.Delta.Content = palmResponse.Candidates[0].Content
+	}
+	choice.FinishReason = &constant.StopFinishReason
+	var response openai.ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "palm2"
+	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+// StreamHandler adapts the PaLM endpoint (which is not actually streaming)
+// to the OpenAI SSE protocol: it reads the whole upstream body, converts it,
+// and emits it as a single chunk followed by [DONE]. Returns the response
+// text for usage accounting.
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
+	createdTime := helper.GetTimestamp()
+
+	common.SetEventStreamHeaders(c)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		logger.SysError("error reading stream response: " + err.Error())
+		// Close with a distinct variable: the original code shadowed err with
+		// the (usually nil) Close result, so the final return wrapped a nil
+		// error — losing the read error and panicking in ErrorWrapper.
+		if closeErr := resp.Body.Close(); closeErr != nil {
+			return openai.ErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), ""
+		}
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+
+	var palmResponse ChatResponse
+	err = json.Unmarshal(responseBody, &palmResponse)
+	if err != nil {
+		logger.SysError("error unmarshalling stream response: " + err.Error())
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
+	}
+
+	fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+	fullTextResponse.Id = responseId
+	fullTextResponse.Created = createdTime
+	if len(palmResponse.Candidates) > 0 {
+		responseText = palmResponse.Candidates[0].Content
+	}
+
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		logger.SysError("error marshalling stream response: " + err.Error())
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
+	}
+
+	err = 
render.ObjectData(c, string(jsonResponse)) + if err != nil { + logger.SysError(err.Error()) + } + + render.Done(c) + + return nil, responseText +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var palmResponse ChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: palmResponse.Error.Message, + Type: palmResponse.Error.Status, + Param: "", + Code: palmResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responsePaLM2OpenAI(&palmResponse) + fullTextResponse.Model = modelName + completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName) + usage := model.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go new file mode 100644 index 0000000..670c762 --- /dev/null +++ b/relay/adaptor/proxy/adaptor.go @@ -0,0 +1,89 
@@ +package proxy + +import ( + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +const channelName = "proxy" + +type Adaptor struct{} + +func (a *Adaptor) Init(meta *meta.Meta) { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + return nil, errors.New("notimplement") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + for k, v := range resp.Header { + for _, vv := range v { + c.Writer.Header().Set(k, vv) + } + } + + c.Writer.WriteHeader(resp.StatusCode) + if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil { + return nil, &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: gerr.Error(), + }, + } + } + + return nil, nil +} + +func (a *Adaptor) GetModelList() (models []string) { + return nil +} + +func (a *Adaptor) GetChannelName() string { + return channelName +} + +// GetRequestURL remove static prefix, and return the real request url to the upstream service +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId) + return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil + +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + for k, v := range c.Request.Header { + req.Header.Set(k, v[0]) + } + + // remove unnecessary headers + req.Header.Del("Host") + req.Header.Del("Content-Length") + req.Header.Del("Accept-Encoding") + 
req.Header.Del("Connection") + + // set authorization header + req.Header.Set("Authorization", meta.APIKey) + + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.Errorf("not implement") +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go new file mode 100644 index 0000000..a60a7de --- /dev/null +++ b/relay/adaptor/replicate/adaptor.go @@ -0,0 +1,136 @@ +package replicate + +import ( + "fmt" + "io" + "net/http" + "slices" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return DrawImageRequest{ + Input: ImageInput{ + Steps: 25, + Prompt: request.Prompt, + Guidance: 3, + Seed: int(time.Now().UnixNano()), + SafetyTolerance: 5, + NImages: 1, // replicate will always return 1 image + Width: 1440, + Height: 1440, + AspectRatio: "1:1", + }, + }, nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if !request.Stream { + // TODO: support non-stream mode + return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") + } + + // Build the prompt from OpenAI messages + var promptBuilder strings.Builder + for _, message := range request.Messages { + switch msgCnt := message.Content.(type) { + case string: + promptBuilder.WriteString(message.Role) + promptBuilder.WriteString(": ") + promptBuilder.WriteString(msgCnt) + promptBuilder.WriteString("\n") + default: + } + } + + replicateRequest := ReplicateChatRequest{ + Input: ChatInput{ + Prompt: promptBuilder.String(), + MaxTokens: request.MaxTokens, + Temperature: 1.0, + TopP: 1.0, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + }, + } + + // Map optional fields + if request.Temperature != nil { + replicateRequest.Input.Temperature = *request.Temperature + } + if request.TopP != nil { + replicateRequest.Input.TopP = *request.TopP + } + if request.PresencePenalty != nil { + replicateRequest.Input.PresencePenalty = *request.PresencePenalty + } + if request.FrequencyPenalty != nil { + replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty + } + if request.MaxTokens > 0 { + replicateRequest.Input.MaxTokens = request.MaxTokens + } else if request.MaxTokens == 0 { + replicateRequest.Input.MaxTokens = 500 + } + + return replicateRequest, nil +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + if 
!slices.Contains(ModelList, meta.OriginModelName) { + return "", errors.Errorf("model %s not supported", meta.OriginModelName) + } + + return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send request to replicate") + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + case relaymode.ChatCompletions: + err, usage = ChatHandler(c, resp) + default: + err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) + } + + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "replicate" +} diff --git a/relay/adaptor/replicate/chat.go b/relay/adaptor/replicate/chat.go new file mode 100644 index 0000000..4051f85 --- /dev/null +++ b/relay/adaptor/replicate/chat.go @@ -0,0 +1,191 @@ +package replicate + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func ChatHandler(c *gin.Context, resp *http.Response) ( + srvErr *model.ErrorWithStatusCode, usage *model.Usage) { + if 
resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ChatResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ChatResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + if taskData.URLs.Stream == "" { + return errors.New("stream url is empty") + } + + // request stream url + responseText, err := chatStreamHandler(c, taskData.URLs.Stream) + if err != nil { + return errors.Wrap(err, "chat stream handler") + } + + ctxMeta := 
meta.GetByContext(c) + usage = openai.ResponseText2Usage(responseText, + ctxMeta.ActualModelName, ctxMeta.PromptTokens) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, usage +} + +const ( + eventPrefix = "event: " + dataPrefix = "data: " + done = "[DONE]" +) + +func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { + // request stream endpoint + streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) + if err != nil { + return "", errors.Wrap(err, "new request to stream") + } + + streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-store") + + resp, err := http.DefaultClient.Do(streamReq) + if err != nil { + return "", errors.Wrap(err, "do request to stream") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(resp.Body) + return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + doneRendered := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Handle comments starting with ':' + if strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE fields + if strings.HasPrefix(line, eventPrefix) { + event := strings.TrimSpace(line[len(eventPrefix):]) + var data string + // Read the following lines to get data and id + for scanner.Scan() { + nextLine := scanner.Text() + if nextLine == "" { + break + } + if strings.HasPrefix(nextLine, dataPrefix) { + data = nextLine[len(dataPrefix):] + } else if strings.HasPrefix(nextLine, "id:") { + // id = 
strings.TrimSpace(nextLine[len("id:"):]) + } + } + + if event == "output" { + render.StringData(c, data) + responseText += data + } else if event == "done" { + render.Done(c) + doneRendered = true + break + } + } + } + + if err := scanner.Err(); err != nil { + return "", errors.Wrap(err, "scan stream") + } + + if !doneRendered { + render.Done(c) + } + + return responseText, nil +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go new file mode 100644 index 0000000..989142c --- /dev/null +++ b/relay/adaptor/replicate/constant.go @@ -0,0 +1,58 @@ +package replicate + +// ModelList is a list of models that can be used with Replicate. +// +// https://replicate.com/pricing +var ModelList = []string{ + // ------------------------------------- + // image model + // ------------------------------------- + "black-forest-labs/flux-1.1-pro", + "black-forest-labs/flux-1.1-pro-ultra", + "black-forest-labs/flux-canny-dev", + "black-forest-labs/flux-canny-pro", + "black-forest-labs/flux-depth-dev", + "black-forest-labs/flux-depth-pro", + "black-forest-labs/flux-dev", + "black-forest-labs/flux-dev-lora", + "black-forest-labs/flux-fill-dev", + "black-forest-labs/flux-fill-pro", + "black-forest-labs/flux-pro", + "black-forest-labs/flux-redux-dev", + "black-forest-labs/flux-redux-schnell", + "black-forest-labs/flux-schnell", + "black-forest-labs/flux-schnell-lora", + "ideogram-ai/ideogram-v2", + "ideogram-ai/ideogram-v2-turbo", + "recraft-ai/recraft-v3", + "recraft-ai/recraft-v3-svg", + "stability-ai/stable-diffusion-3", + "stability-ai/stable-diffusion-3.5-large", + "stability-ai/stable-diffusion-3.5-large-turbo", + "stability-ai/stable-diffusion-3.5-medium", + // ------------------------------------- + // language model + // ------------------------------------- + "ibm-granite/granite-20b-code-instruct-8k", + "ibm-granite/granite-3.0-2b-instruct", + "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k", + 
"meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3.1-405b-instruct", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1", + "mistralai/mixtral-8x7b-instruct-v0.1", + // ------------------------------------- + // video model + // ------------------------------------- + // "minimax/video-01", // TODO: implement the adaptor +} diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go new file mode 100644 index 0000000..3687249 --- /dev/null +++ b/relay/adaptor/replicate/image.go @@ -0,0 +1,222 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "image" + "image/png" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "golang.org/x/image/webp" + "golang.org/x/sync/errgroup" +) + +// ImagesEditsHandler just copy response body to client +// +// https://replicate.com/black-forest-labs/flux-fill-pro +// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// c.Writer.WriteHeader(resp.StatusCode) +// for k, v := range resp.Header { +// c.Writer.Header().Set(k, v[0]) +// } + +// if _, err := io.Copy(c.Writer, resp.Body); err != nil { +// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// } +// defer resp.Body.Close() + +// return nil, nil +// } + +var errNextLoop = errors.New("next_loop") + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + if resp.StatusCode != 
http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ImageResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ImageResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed: %s", taskData.Status) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + output, err := taskData.GetOutput() + if err != nil { + return errors.Wrap(err, "get output") + } + if len(output) == 0 { + return errors.New("response output is empty") + } + + var mu sync.Mutex + var pool errgroup.Group + respBody := &openai.ImageResponse{ + Created: taskData.CompletedAt.Unix(), + Data: 
[]openai.ImageData{}, + } + + for _, imgOut := range output { + imgOut := imgOut + pool.Go(func() error { + // download image + downloadReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, imgOut, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + imgResp, err := http.DefaultClient.Do(downloadReq) + if err != nil { + return errors.Wrap(err, "download image") + } + defer imgResp.Body.Close() + + if imgResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(imgResp.Body) + return errors.Errorf("bad status code [%d]%s", + imgResp.StatusCode, string(payload)) + } + + imgData, err := io.ReadAll(imgResp.Body) + if err != nil { + return errors.Wrap(err, "read image") + } + + imgData, err = ConvertImageToPNG(imgData) + if err != nil { + return errors.Wrap(err, "convert image") + } + + mu.Lock() + respBody.Data = append(respBody.Data, openai.ImageData{ + B64Json: fmt.Sprintf("data:image/png;base64,%s", + base64.StdEncoding.EncodeToString(imgData)), + }) + mu.Unlock() + + return nil + }) + } + + if err := pool.Wait(); err != nil { + if len(respBody.Data) == 0 { + return errors.WithStack(err) + } + + logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) + } + + c.JSON(http.StatusOK, respBody) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, nil +} + +// ConvertImageToPNG converts a WebP image to PNG format +func ConvertImageToPNG(webpData []byte) ([]byte, error) { + // bypass if it's already a PNG image + if bytes.HasPrefix(webpData, []byte("\x89PNG")) { + return webpData, nil + } + + // check if is jpeg, convert to png + if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { + img, _, err := image.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode jpeg") + } + + var pngBuffer bytes.Buffer + if err := 
png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil + } + + // Decode the WebP image + img, err := webp.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode webp") + } + + // Encode the image as PNG + var pngBuffer bytes.Buffer + if err := png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil +} diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go new file mode 100644 index 0000000..dba277e --- /dev/null +++ b/relay/adaptor/replicate/model.go @@ -0,0 +1,159 @@ +package replicate + +import ( + "time" + + "github.com/pkg/errors" +) + +// DrawImageRequest draw image by fluxpro +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type DrawImageRequest struct { + Input ImageInput `json:"input"` +} + +// ImageInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema +type ImageInput struct { + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + ImagePrompt string `json:"image_prompt"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + Interval int `json:"interval" binding:"required,min=1,max=4"` + AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + Seed int `json:"seed"` + NImages int `json:"n_images" binding:"required,min=1,max=8"` + Width int `json:"width" binding:"required,min=256,max=1440"` + Height int `json:"height" binding:"required,min=256,max=1440"` +} + +// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type InpaintingImageByFlusReplicateRequest 
struct { + Input FluxInpaintingInput `json:"input"` +} + +// FluxInpaintingInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type FluxInpaintingInput struct { + Mask string `json:"mask" binding:"required"` + Image string `json:"image" binding:"required"` + Seed int `json:"seed"` + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + OutputFormat string `json:"output_format"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + PromptUnsampling bool `json:"prompt_unsampling"` +} + +// ImageResponse is response of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type ImageResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input DrawImageRequest `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output any `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs FluxURLs `json:"urls"` + Version string `json:"version"` +} + +func (r *ImageResponse) GetOutput() ([]string, error) { + switch v := r.Output.(type) { + case string: + return []string{v}, nil + case []string: + return v, nil + case nil: + return nil, nil + case []interface{}: + // convert []interface{} to []string + ret := make([]string, len(v)) + for idx, vv := range v { + if vvv, ok := vv.(string); ok { + ret[idx] = vvv + } else { + return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) + } + } + + return ret, nil + default: + return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) + } +} + +// FluxMetrics is metrics 
of ImageResponse +type FluxMetrics struct { + ImageCount int `json:"image_count"` + PredictTime float64 `json:"predict_time"` + TotalTime float64 `json:"total_time"` +} + +// FluxURLs is urls of ImageResponse +type FluxURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} + +type ReplicateChatRequest struct { + Input ChatInput `json:"input" form:"input" binding:"required"` +} + +// ChatInput is input of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema +type ChatInput struct { + TopK int `json:"top_k"` + TopP float64 `json:"top_p"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + MinTokens int `json:"min_tokens"` + Temperature float64 `json:"temperature"` + SystemPrompt string `json:"system_prompt"` + StopSequences string `json:"stop_sequences"` + PromptTemplate string `json:"prompt_template"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` +} + +// ChatResponse is response of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json +type ChatResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input ChatInput `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output []string `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs ChatResponseUrl `json:"urls"` + Version string `json:"version"` +} + +// ChatResponseUrl is task urls of ChatResponse +type ChatResponseUrl struct { + Stream string `json:"stream"` + Get string `json:"get"` + Cancel string `json:"cancel"` +} diff --git a/relay/adaptor/siliconflow/constants.go b/relay/adaptor/siliconflow/constants.go new file mode 100644 index 
// ModelList enumerates the SiliconFlow models this channel exposes.
//
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
	"deepseek-ai/deepseek-llm-67b-chat",
	"Qwen/Qwen1.5-14B-Chat",
	"Qwen/Qwen1.5-7B-Chat",
	"Qwen/Qwen1.5-110B-Chat",
	"Qwen/Qwen1.5-32B-Chat",
	"01-ai/Yi-1.5-6B-Chat",
	"01-ai/Yi-1.5-9B-Chat-16K",
	"01-ai/Yi-1.5-34B-Chat-16K",
	"THUDM/chatglm3-6b",
	"deepseek-ai/DeepSeek-V2-Chat",
	"THUDM/glm-4-9b-chat",
	"Qwen/Qwen2-72B-Instruct",
	"Qwen/Qwen2-7B-Instruct",
	"Qwen/Qwen2-57B-A14B-Instruct",
	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
	"Qwen/Qwen2-1.5B-Instruct",
	"internlm/internlm2_5-7b-chat",
	"BAAI/bge-large-en-v1.5",
	"BAAI/bge-large-zh-v1.5",
	"Pro/Qwen/Qwen2-7B-Instruct",
	"Pro/Qwen/Qwen2-1.5B-Instruct",
	"Pro/Qwen/Qwen1.5-7B-Chat",
	"Pro/THUDM/glm-4-9b-chat",
	"Pro/THUDM/chatglm3-6b",
	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
	"Pro/01-ai/Yi-1.5-6B-Chat",
	"Pro/google/gemma-2-9b-it",
	"Pro/internlm/internlm2_5-7b-chat",
	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}
"github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +// https://cloud.tencent.com/document/api/1729/101837 + +type Adaptor struct { + Sign string + Action string + Version string + Timestamp int64 +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.Action = "ChatCompletions" + a.Version = "2023-09-01" + a.Timestamp = helper.GetTimestamp() +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return meta.BaseURL + "/", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", a.Sign) + req.Header.Set("X-TC-Action", a.Action) + req.Header.Set("X-TC-Version", a.Version) + req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + _, secretId, secretKey, err := ParseConfig(apiKey) + if err != nil { + return nil, err + } + var convertedRequest any + switch relayMode { + case relaymode.Embeddings: + a.Action = "GetEmbedding" + convertedRequest = ConvertEmbeddingRequest(*request) + default: + a.Action = "ChatCompletions" + convertedRequest = ConvertRequest(*request) + } + // we have to calculate the sign here + a.Sign = GetSign(convertedRequest, a, secretId, secretKey) + return convertedRequest, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody 
io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "tencent" +} diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go new file mode 100644 index 0000000..7997bfd --- /dev/null +++ b/relay/adaptor/tencent/constants.go @@ -0,0 +1,10 @@ +package tencent + +var ModelList = []string{ + "hunyuan-lite", + "hunyuan-standard", + "hunyuan-standard-256K", + "hunyuan-pro", + "hunyuan-vision", + "hunyuan-embedding", +} diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go new file mode 100644 index 0000000..8bf8e46 --- /dev/null +++ b/relay/adaptor/tencent/main.go @@ -0,0 +1,307 @@ +package tencent + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" 
+) + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + messages := make([]*Message, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + messages = append(messages, &Message{ + Content: message.StringContent(), + Role: message.Role, + }) + } + return &ChatRequest{ + Model: &request.Model, + Stream: &request.Stream, + Messages: messages, + TopP: request.TopP, + Temperature: request.Temperature, + } +} + +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + InputList: request.ParseInput(), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var tencentResponseP EmbeddingResponseP + err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + tencentResponse := tencentResponseP.Response + if tencentResponse.Error.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: tencentResponse.Error.Message, + Code: tencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + requestModel := c.GetString(ctxkey.RequestModel) + fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) + fullTextResponse.Model = requestModel + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) 
*openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), + Model: "hunyuan-embedding", + Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, + } + + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + +func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Id: response.ReqID, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Usage: model.Usage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + if len(response.Choices) > 0 { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: response.Choices[0].Messages.Content, + }, + FinishReason: response.Choices[0].FinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "tencent-hunyuan", + } + if len(TencentResponse.Choices) > 0 { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = TencentResponse.Choices[0].Delta.Content + if TencentResponse.Choices[0].FinishReason == "stop" { + choice.FinishReason = &constant.StopFinishReason + } + response.Choices = append(response.Choices, choice) + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) 
(*model.ErrorWithStatusCode, string) { + var responseText string + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + + var tencentResponse ChatResponse + err := json.Unmarshal([]byte(data), &tencentResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseTencent2OpenAI(&tencentResponse) + if len(response.Choices) != 0 { + responseText += conv.AsString(response.Choices[0].Delta.Content) + } + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + + return nil, responseText +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var TencentResponse ChatResponse + var responseP ChatResponseP + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &responseP) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + TencentResponse = responseP.Response + if TencentResponse.Error.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: TencentResponse.Error.Message, + Code: TencentResponse.Error.Code, + }, + 
StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseTencent2OpenAI(&TencentResponse) + fullTextResponse.Model = "hunyuan" + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &fullTextResponse.Usage +} + +func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) { + parts := strings.Split(config, "|") + if len(parts) != 3 { + err = errors.New("invalid tencent config") + return + } + appId, err = strconv.ParseInt(parts[0], 10, 64) + secretId = parts[1] + secretKey = parts[2] + return +} + +func sha256hex(s string) string { + b := sha256.Sum256([]byte(s)) + return hex.EncodeToString(b[:]) +} + +func hmacSha256(s, key string) string { + hashed := hmac.New(sha256.New, []byte(key)) + hashed.Write([]byte(s)) + return string(hashed.Sum(nil)) +} + +func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { + // build canonical request string + host := "hunyuan.tencentcloudapi.com" + httpRequestMethod := "POST" + canonicalURI := "/" + canonicalQueryString := "" + canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n", + "application/json", host, strings.ToLower(adaptor.Action)) + signedHeaders := "content-type;host;x-tc-action" + payload, _ := json.Marshal(req) + hashedRequestPayload := sha256hex(string(payload)) + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + httpRequestMethod, + canonicalURI, + canonicalQueryString, + canonicalHeaders, + signedHeaders, + hashedRequestPayload) + // build string to sign + algorithm := "TC3-HMAC-SHA256" + requestTimestamp := 
strconv.FormatInt(adaptor.Timestamp, 10) + timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64) + t := time.Unix(timestamp, 0).UTC() + // must be the format 2006-01-02, ref to package time for more info + date := t.Format("2006-01-02") + credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan") + hashedCanonicalRequest := sha256hex(canonicalRequest) + string2sign := fmt.Sprintf("%s\n%s\n%s\n%s", + algorithm, + requestTimestamp, + credentialScope, + hashedCanonicalRequest) + + // sign string + secretDate := hmacSha256(date, "TC3"+secKey) + secretService := hmacSha256("hunyuan", secretDate) + secretKey := hmacSha256("tc3_request", secretService) + signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey))) + + // build authorization + authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", + algorithm, + secId, + credentialScope, + signedHeaders, + signature) + return authorization +} diff --git a/relay/adaptor/tencent/model.go b/relay/adaptor/tencent/model.go new file mode 100644 index 0000000..fda6c6c --- /dev/null +++ b/relay/adaptor/tencent/model.go @@ -0,0 +1,101 @@ +package tencent + +type Message struct { + Role string `json:"Role"` + Content string `json:"Content"` +} + +type ChatRequest struct { + // 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。 + // 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。 + // + // 注意: + // 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。 + Model *string `json:"Model"` + // 聊天上下文信息。 + // 说明: + // 1. 长度最多为 40,按对话时间从旧到新在数组中排列。 + // 2. Message.Role 可选值:system、user、assistant。 + // 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。 + // 3. 
Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。 + Messages []*Message `json:"Messages"` + // 流式调用开关。 + // 说明: + // 1. 未传值时默认为非流式调用(false)。 + // 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。 + // 3. 非流式调用时: + // 调用方式与普通 HTTP 请求无异。 + // 接口响应耗时较长,**如需更低时延建议设置为 true**。 + // 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。 + // + // 注意: + // 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。 + Stream *bool `json:"Stream"` + // 说明: + // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 + // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 + // 3. 非必要不建议使用,不合理的取值会影响效果。 + TopP *float64 `json:"TopP,omitempty"` + // 说明: + // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 + // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 + // 3. 非必要不建议使用,不合理的取值会影响效果。 + Temperature *float64 `json:"Temperature,omitempty"` +} + +type Error struct { + Code string `json:"Code"` + Message string `json:"Message"` +} + +type Usage struct { + PromptTokens int `json:"PromptTokens"` + CompletionTokens int `json:"CompletionTokens"` + TotalTokens int `json:"TotalTokens"` +} + +type ResponseChoices struct { + FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages Message `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta Message `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 +} + +type ChatResponse struct { + Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 + Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 + Id string `json:"Id,omitempty"` // 会话 id + Usage Usage `json:"Usage,omitempty"` // token 数量 + Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"Note,omitempty"` // 注释 + ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 +} + +type ChatResponseP struct { + Response ChatResponse `json:"Response,omitempty"` +} + +type EmbeddingRequest 
struct { + InputList []string `json:"InputList"` +} + +type EmbeddingData struct { + Embedding []float64 `json:"Embedding"` + Index int `json:"Index"` + Object string `json:"Object"` +} + +type EmbeddingUsage struct { + PromptTokens int `json:"PromptTokens"` + TotalTokens int `json:"TotalTokens"` +} + +type EmbeddingResponse struct { + Data []EmbeddingData `json:"Data"` + EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"` + RequestId string `json:"RequestId,omitempty"` + Error Error `json:"Error,omitempty"` +} + +type EmbeddingResponseP struct { + Response EmbeddingResponse `json:"Response,omitempty"` +} diff --git a/relay/adaptor/togetherai/constants.go b/relay/adaptor/togetherai/constants.go new file mode 100644 index 0000000..0a79fbd --- /dev/null +++ b/relay/adaptor/togetherai/constants.go @@ -0,0 +1,10 @@ +package togetherai + +// https://docs.together.ai/docs/inference-models + +var ModelList = []string{ + "meta-llama/Llama-3-70b-chat-hf", + "deepseek-ai/deepseek-coder-33b-instruct", + "mistralai/Mixtral-8x22B-Instruct-v0.1", + "Qwen/Qwen1.5-72B-Chat", +} diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go new file mode 100644 index 0000000..3fab4a4 --- /dev/null +++ b/relay/adaptor/vertexai/adaptor.go @@ -0,0 +1,117 @@ +package vertexai + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +const channelName = "vertexai" + +type Adaptor struct{} + +func (a *Adaptor) Init(meta *meta.Meta) { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } 
+ + adaptor := GetAdaptor(request.Model) + if adaptor == nil { + return nil, errors.New("adaptor not found") + } + + return adaptor.ConvertRequest(c, relayMode, request) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + adaptor := GetAdaptor(meta.ActualModelName) + if adaptor == nil { + return nil, &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: "adaptor not found", + }, + } + } + return adaptor.DoResponse(c, resp, meta) +} + +func (a *Adaptor) GetModelList() (models []string) { + models = modelList + return +} + +func (a *Adaptor) GetChannelName() string { + return channelName +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + suffix := "" + if strings.HasPrefix(meta.ActualModelName, "gemini") { + if meta.IsStream { + suffix = "streamGenerateContent?alt=sse" + } else { + suffix = "generateContent" + } + } else { + if meta.IsStream { + suffix = "streamRawPredict?alt=sse" + } else { + suffix = "rawPredict" + } + } + + if meta.BaseURL != "" { + return fmt.Sprintf( + "%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + meta.BaseURL, + meta.Config.VertexAIProjectID, + meta.Config.Region, + meta.ActualModelName, + suffix, + ), nil + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + meta.Config.Region, + meta.Config.VertexAIProjectID, + meta.Config.Region, + meta.ActualModelName, + suffix, + ), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + token, err := getToken(c, meta.ChannelId, meta.Config.VertexAIADC) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token) + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) 
{ + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go new file mode 100644 index 0000000..cb911cf --- /dev/null +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -0,0 +1,60 @@ +package vertexai + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var ModelList = []string{ + "claude-3-haiku@20240307", + "claude-3-sonnet@20240229", + "claude-3-opus@20240229", + "claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-v2@20241022", + "claude-3-5-haiku@20241022", +} + +const anthropicVersion = "vertex-2023-10-16" + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + req := Request{ + AnthropicVersion: anthropicVersion, + // Model: claudeReq.Model, + Messages: claudeReq.Messages, + System: claudeReq.System, + MaxTokens: claudeReq.MaxTokens, + Temperature: claudeReq.Temperature, + TopP: claudeReq.TopP, + TopK: claudeReq.TopK, + Stream: claudeReq.Stream, + Tools: claudeReq.Tools, + } + + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, req) + return req, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = anthropic.StreamHandler(c, resp) + } else { + err, usage 
= anthropic.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go new file mode 100644 index 0000000..c08ba46 --- /dev/null +++ b/relay/adaptor/vertexai/claude/model.go @@ -0,0 +1,19 @@ +package vertexai + +import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + +type Request struct { + // AnthropicVersion must be "vertex-2023-10-16" + AnthropicVersion string `json:"anthropic_version"` + // Model string `json:"model"` + Messages []anthropic.Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []anthropic.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go new file mode 100644 index 0000000..f5b245d --- /dev/null +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -0,0 +1,55 @@ +package vertexai + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/relaymode" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var ModelList = []string{ + "gemini-pro", "gemini-pro-vision", + "gemini-exp-1206", + "gemini-1.5-pro-001", "gemini-1.5-pro-002", + "gemini-1.5-flash-001", "gemini-1.5-flash-002", + "gemini-2.0-flash-exp", "gemini-2.0-flash-001", + "gemini-2.0-flash-lite-preview-02-05", + "gemini-2.0-flash-thinking-exp-01-21", +} + +type Adaptor struct { +} + 
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + geminiRequest := gemini.ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, geminiRequest) + return geminiRequest, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = gemini.StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + switch meta.Mode { + case relaymode.Embeddings: + err, usage = gemini.EmbeddingHandler(c, resp) + default: + err, usage = gemini.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + } + return +} diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go new file mode 100644 index 0000000..41099f0 --- /dev/null +++ b/relay/adaptor/vertexai/registry.go @@ -0,0 +1,50 @@ +package vertexai + +import ( + "net/http" + + "github.com/gin-gonic/gin" + claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude" + gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type VertexAIModelType int + +const ( + VerterAIClaude VertexAIModelType = iota + 1 + VerterAIGemini +) + +var modelMapping = map[string]VertexAIModelType{} +var modelList = []string{} + +func init() { + modelList = append(modelList, claude.ModelList...) + for _, model := range claude.ModelList { + modelMapping[model] = VerterAIClaude + } + + modelList = append(modelList, gemini.ModelList...) 
+ for _, model := range gemini.ModelList { + modelMapping[model] = VerterAIGemini + } +} + +type innerAIAdapter interface { + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) +} + +func GetAdaptor(model string) innerAIAdapter { + adaptorType := modelMapping[model] + switch adaptorType { + case VerterAIClaude: + return &claude.Adaptor{} + case VerterAIGemini: + return &gemini.Adaptor{} + default: + return nil + } +} diff --git a/relay/adaptor/vertexai/token.go b/relay/adaptor/vertexai/token.go new file mode 100644 index 0000000..0a5e0aa --- /dev/null +++ b/relay/adaptor/vertexai/token.go @@ -0,0 +1,62 @@ +package vertexai + +import ( + "context" + "encoding/json" + "fmt" + "time" + + credentials "cloud.google.com/go/iam/credentials/apiv1" + "cloud.google.com/go/iam/credentials/apiv1/credentialspb" + "github.com/patrickmn/go-cache" + "google.golang.org/api/option" +) + +type ApplicationDefaultCredentials struct { + Type string `json:"type"` + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + ClientID string `json:"client_id"` + AuthURI string `json:"auth_uri"` + TokenURI string `json:"token_uri"` + AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"` + ClientX509CertURL string `json:"client_x509_cert_url"` + UniverseDomain string `json:"universe_domain"` +} + +var Cache = cache.New(50*time.Minute, 55*time.Minute) + +const defaultScope = "https://www.googleapis.com/auth/cloud-platform" + +func getToken(ctx context.Context, channelId int, adcJson string) (string, error) { + cacheKey := fmt.Sprintf("vertexai-token-%d", channelId) + if token, found := Cache.Get(cacheKey); found { + return token.(string), nil + } + adc := &ApplicationDefaultCredentials{} + if err := 
json.Unmarshal([]byte(adcJson), adc); err != nil { + return "", fmt.Errorf("Failed to decode credentials file: %w", err) + } + + c, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(adcJson))) + if err != nil { + return "", fmt.Errorf("Failed to create client: %w", err) + } + defer c.Close() + + req := &credentialspb.GenerateAccessTokenRequest{ + // See https://pkg.go.dev/cloud.google.com/go/iam/credentials/apiv1/credentialspb#GenerateAccessTokenRequest. + Name: fmt.Sprintf("projects/-/serviceAccounts/%s", adc.ClientEmail), + Scope: []string{defaultScope}, + } + resp, err := c.GenerateAccessToken(ctx, req) + if err != nil { + return "", fmt.Errorf("Failed to generate access token: %w", err) + } + _ = resp + + Cache.Set(cacheKey, resp.AccessToken, cache.DefaultExpiration) + return resp.AccessToken, nil +} diff --git a/relay/adaptor/xai/constants.go b/relay/adaptor/xai/constants.go new file mode 100644 index 0000000..3955b62 --- /dev/null +++ b/relay/adaptor/xai/constants.go @@ -0,0 +1,14 @@ +package xai + +//https://console.x.ai/ + +var ModelList = []string{ + "grok-2", + "grok-vision-beta", + "grok-2-vision-1212", + "grok-2-vision", + "grok-2-vision-latest", + "grok-2-1212", + "grok-2-latest", + "grok-beta", +} diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go new file mode 100644 index 0000000..b5967f2 --- /dev/null +++ b/relay/adaptor/xunfei/adaptor.go @@ -0,0 +1,86 @@ +package xunfei + +import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +type Adaptor struct { + request *model.GeneralOpenAIRequest + meta *meta.Meta +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + 
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + // check DoResponse for auth part + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + a.request = request + return nil, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + // xunfei's request is not http request, so we don't need to do anything here + dummyResp := &http.Response{} + dummyResp.StatusCode = http.StatusOK + return dummyResp, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + splits := strings.Split(meta.APIKey, "|") + if len(splits) != 3 { + return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + if a.request == nil { + return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) + } + version := parseAPIVersionByModelName(meta.ActualModelName) + if version == "" { + version = a.meta.Config.APIVersion + } + if version == "" { + version = "v1.1" + } + a.meta.Config.APIVersion = version + if meta.IsStream { + err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2]) + } else { + err, usage = Handler(c, meta, *a.request, splits[0], splits[1], splits[2]) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "xunfei" +} diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go new file mode 
100644 index 0000000..e19c184 --- /dev/null +++ b/relay/adaptor/xunfei/constants.go @@ -0,0 +1,10 @@ +package xunfei + +var ModelList = []string{ + "Spark-Lite", + "Spark-Pro", + "Spark-Pro-128K", + "Spark-Max", + "Spark-Max-32K", + "Spark-4.0-Ultra", +} diff --git a/relay/adaptor/xunfei/domain.go b/relay/adaptor/xunfei/domain.go new file mode 100644 index 0000000..fd961ba --- /dev/null +++ b/relay/adaptor/xunfei/domain.go @@ -0,0 +1,97 @@ +package xunfei + +import ( + "fmt" + "strings" +) + +// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E + +//Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra: +// +//wss://spark-api.xf-yun.com/v4.0/chat +//Spark Max-32K请求地址,对应的domain参数为max-32k +// +//wss://spark-api.xf-yun.com/chat/max-32k +//Spark Max请求地址,对应的domain参数为generalv3.5 +// +//wss://spark-api.xf-yun.com/v3.5/chat +//Spark Pro-128K请求地址,对应的domain参数为pro-128k: +// +// wss://spark-api.xf-yun.com/chat/pro-128k +//Spark Pro请求地址,对应的domain参数为generalv3: +// +//wss://spark-api.xf-yun.com/v3.1/chat +//Spark Lite请求地址,对应的domain参数为lite: +// +//wss://spark-api.xf-yun.com/v1.1/chat + +// Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra + +func parseAPIVersionByModelName(modelName string) string { + apiVersion := modelName2APIVersion(modelName) + if apiVersion != "" { + return apiVersion + } + + index := strings.IndexAny(modelName, "-") + if index != -1 { + return modelName[index+1:] + } + return "" +} + +func modelName2APIVersion(modelName string) string { + switch modelName { + case "Spark-Lite": + return "v1.1" + case "Spark-Pro": + return "v3.1" + case "Spark-Pro-128K": + return "v3.1-128K" + case "Spark-Max": + return "v3.5" + case "Spark-Max-32K": + return "v3.5-32K" + case "Spark-4.0-Ultra": + return "v4.0" + } + return "" +} + +// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E +func apiVersion2domain(apiVersion string) string { + switch apiVersion { + case "v1.1": + return "lite" + case "v2.1": + return "generalv2" + case "v3.1": + 
return "generalv3" + case "v3.1-128K": + return "pro-128k" + case "v3.5": + return "generalv3.5" + case "v3.5-32K": + return "max-32k" + case "v4.0": + return "4.0Ultra" + } + return "general" + apiVersion +} + +func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { + var authUrl string + domain := apiVersion2domain(apiVersion) + switch apiVersion { + case "v3.1-128K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret) + break + case "v3.5-32K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) + break + default: + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + } + return domain, authUrl +} diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go new file mode 100644 index 0000000..9a8aef1 --- /dev/null +++ b/relay/adaptor/xunfei/main.go @@ -0,0 +1,273 @@ +package xunfei + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://console.xfyun.cn/services/cbm +// https://www.xfyun.cn/doc/spark/Web.html + +func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) + for _, message := range request.Messages { + messages = append(messages, Message{ + Role: message.Role, 
+ Content: message.StringContent(), + }) + } + xunfeiRequest := ChatRequest{} + xunfeiRequest.Header.AppId = xunfeiAppId + xunfeiRequest.Parameter.Chat.Domain = domain + xunfeiRequest.Parameter.Chat.Temperature = request.Temperature + xunfeiRequest.Parameter.Chat.TopK = request.N + xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Payload.Message.Text = messages + + if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { + functions := make([]model.Function, len(request.Tools)) + for i, tool := range request.Tools { + functions[i] = tool.Function + } + xunfeiRequest.Payload.Functions = &Functions{ + Text: functions, + } + } + + return &xunfeiRequest +} + +func getToolCalls(response *ChatResponse) []model.Tool { + var toolCalls []model.Tool + if len(response.Payload.Choices.Text) == 0 { + return toolCalls + } + item := response.Payload.Choices.Text[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: *item.FunctionCall, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + +func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { + if len(response.Payload.Choices.Text) == 0 { + response.Payload.Choices.Text = []ChatResponseTextItem{ + { + Content: "", + }, + } + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + ToolCalls: getToolCalls(response), + }, + FinishReason: constant.StopFinishReason, + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: response.Payload.Usage.Text, + } + return &fullTextResponse +} + +func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if 
len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ + { + Content: "", + }, + } + } + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) + if xunfeiResponse.Payload.Choices.Status == 2 { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "SparkDesk", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { + HmacWithShaToBase64 := func(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) + } + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + sign := strings.Join(signString, "\n") + sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + callUrl := hostUrl + "?" 
+ v.Encode() + return callUrl +} + +func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil + } + common.SetEventStreamHeaders(c) + var usage model.Usage + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil + } + var usage model.Usage + var content string + var xunfeiResponse ChatResponse + stop := false + for !stop { + select { + case xunfeiResponse = <-dataChan: + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + continue + } + content += 
// xunfeiMakeRequest dials the Xunfei Spark websocket endpoint, writes the
// converted chat request, and returns two channels: dataChan delivers each
// parsed ChatResponse frame and stopChan receives true when the stream ends
// (after a frame with Choices.Status == 2, or on any read/parse error).
//
// Both channels are unbuffered, so the reader goroutine blocks on each send
// until the caller consumes it; if a caller stopped draining dataChan before
// stopChan fired, the goroutine and connection would leak.
// NOTE(review): confirm all callers always drain until stopChan fires.
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
	d := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}
	// err != nil short-circuits before resp.StatusCode, so a nil resp on a
	// failed dial is never dereferenced. 101 is the websocket upgrade status.
	conn, resp, err := d.Dial(authUrl, nil)
	if err != nil || resp.StatusCode != 101 {
		return nil, nil, err
	}
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
	err = conn.WriteJSON(data)
	if err != nil {
		return nil, nil, err
	}
	// Read the first frame synchronously so early failures surface to the
	// caller as an error instead of only being logged by the goroutine below.
	_, msg, err := conn.ReadMessage()
	if err != nil {
		return nil, nil, err
	}

	dataChan := make(chan ChatResponse)
	stopChan := make(chan bool)
	go func() {
		for {
			// msg still holds the first frame on the first iteration; it is
			// cleared after each successful unmarshal so later iterations
			// read fresh frames. NOTE(review): the goroutine keeps writing
			// the captured outer msg/err variables after the function
			// returns — harmless today since the caller no longer reads
			// them, but locals would be cleaner.
			if msg == nil {
				_, msg, err = conn.ReadMessage()
				if err != nil {
					logger.SysError("error reading stream response: " + err.Error())
					break
				}
			}
			var response ChatResponse
			err = json.Unmarshal(msg, &response)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				break
			}
			msg = nil
			dataChan <- response
			// Status == 2 is treated as the final frame of a Spark reply:
			// close the connection and stop reading.
			if response.Payload.Choices.Status == 2 {
				err := conn.Close()
				if err != nil {
					logger.SysError("error closing websocket connection: " + err.Error())
				}
				break
			}
		}
		stopChan <- true
	}()

	return dataChan, stopChan, nil
}
// GetRequestURL builds the upstream Zhipu endpoint for the current request.
//
// Image generation and embeddings always go to the v4 REST endpoints. For
// chat, the API version is derived from the model name (glm-* => v4,
// otherwise the legacy v3 model-api), and v3 switches between "invoke" and
// "sse-invoke" depending on whether the client requested streaming.
//
// NOTE(review): this mutates a.APIVersion as a side effect (via
// SetVersionByModeName, whose name looks like a typo for "ModelName");
// DoResponse later reads that same field to pick a response handler.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	switch meta.Mode {
	case relaymode.ImagesGenerations:
		return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
	case relaymode.Embeddings:
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
	}
	a.SetVersionByModeName(meta.ActualModelName)
	if a.APIVersion == "v4" {
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
	}
	// Legacy v3 API: streaming uses the dedicated sse-invoke method.
	method := "invoke"
	if meta.IsStream {
		method = "sse-invoke"
	}
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
}
// ConvertRequest adapts an OpenAI-style chat/embedding request into the
// payload for the selected Zhipu endpoint.
//
// Embeddings are converted via ConvertEmbeddingRequest. For chat, TopP and
// Temperature are clamped into [0.0, 1.0] (assumes Float64PtrMax caps the
// value at its second argument and Float64PtrMin floors it — TODO confirm
// helper semantics), then glm-* models (v4) are passed through unchanged
// while legacy models are converted to the v3 Request shape.
//
// NOTE(review): the clamping mutates the caller's request in place, and the
// variable name "baiduEmbeddingRequest" looks copy-pasted from the Baidu
// adaptor.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	switch relayMode {
	case relaymode.Embeddings:
		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
		return baiduEmbeddingRequest, err
	default:
		// TopP [0.0, 1.0]
		request.TopP = helper.Float64PtrMax(request.TopP, 1)
		request.TopP = helper.Float64PtrMin(request.TopP, 0)

		// Temperature [0.0, 1.0]
		request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
		request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
		a.SetVersionByModeName(request.Model)
		if a.APIVersion == "v4" {
			// glm-* models speak the OpenAI-compatible v4 API directly.
			return request, nil
		}
		return ConvertRequest(*request), nil
	}
}
relaymode.Embeddings: + err, usage = EmbeddingsHandler(c, resp) + return + case relaymode.ImagesGenerations: + err, usage = openai.ImageHandler(c, resp) + return + } + if a.APIVersion == "v4" { + return a.DoResponseV4(c, resp, meta) + } + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + if meta.Mode == relaymode.Embeddings { + err, usage = EmbeddingsHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + } + return +} + +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) (*EmbeddingRequest, error) { + inputs := request.ParseInput() + if len(inputs) != 1 { + return nil, errors.New("invalid input length, zhipu only support one input") + } + return &EmbeddingRequest{ + Model: request.Model, + Input: inputs[0], + }, nil +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "zhipu" +} diff --git a/relay/adaptor/zhipu/constants.go b/relay/adaptor/zhipu/constants.go new file mode 100644 index 0000000..a86ffc4 --- /dev/null +++ b/relay/adaptor/zhipu/constants.go @@ -0,0 +1,14 @@ +package zhipu + +// https://open.bigmodel.cn/pricing + +var ModelList = []string{ + "glm-zero-preview", "glm-4-plus", "glm-4-0520", "glm-4-airx", + "glm-4-air", "glm-4-long", "glm-4-flashx", "glm-4-flash", + "glm-4", "glm-3-turbo", + "glm-4v-plus", "glm-4v", "glm-4v-flash", + "cogview-3-plus", "cogview-3", "cogview-3-flash", + "cogviewx", "cogviewx-flash", + "charglm-4", "emohaa", "codegeex-4", + "embedding-2", "embedding-3", +} diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go new file mode 100644 index 0000000..ab3a567 --- /dev/null +++ b/relay/adaptor/zhipu/main.go @@ -0,0 +1,294 @@ +package zhipu + +import ( + "bufio" + "encoding/json" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "github.com/songquanpeng/one-api/common" + 
"github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" +) + +// https://open.bigmodel.cn/doc/api#chatglm_std +// chatglm_std, chatglm_lite +// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke +// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke + +var zhipuTokens sync.Map +var expSeconds int64 = 24 * 3600 + +func GetToken(apikey string) string { + data, ok := zhipuTokens.Load(apikey) + if ok { + tokenData := data.(tokenData) + if time.Now().Before(tokenData.ExpiryTime) { + return tokenData.Token + } + } + + split := strings.Split(apikey, ".") + if len(split) != 2 { + logger.SysError("invalid zhipu key: " + apikey) + return "" + } + + id := split[0] + secret := split[1] + + expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 + expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) + + timestamp := time.Now().UnixNano() / 1e6 + + payload := jwt.MapClaims{ + "api_key": id, + "exp": expMillis, + "timestamp": timestamp, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "" + } + + zhipuTokens.Store(apikey, tokenData{ + Token: tokenString, + ExpiryTime: expiryTime, + }) + + return tokenString +} + +func ConvertRequest(request model.GeneralOpenAIRequest) *Request { + messages := make([]Message, 0, len(request.Messages)) + for _, message := range request.Messages { + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + return &Request{ + Prompt: messages, + Temperature: request.Temperature, + TopP: request.TopP, + Incremental: false, + } +} + +func 
responseZhipu2OpenAI(response *Response) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Id: response.Data.TaskId, + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), + Usage: response.Data.Usage, + } + for i, choice := range response.Data.Choices { + openaiChoice := openai.TextResponseChoice{ + Index: i, + Message: model.Message{ + Role: choice.Role, + Content: strings.Trim(choice.Content, "\""), + }, + FinishReason: "", + } + if i == len(response.Data.Choices)-1 { + openaiChoice.FinishReason = "stop" + } + fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) + } + return &fullTextResponse +} + +func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = zhipuResponse + response := openai.ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "chatglm", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = "" + choice.FinishReason = &constant.StopFinishReason + response := openai.ChatCompletionsStreamResponse{ + Id: zhipuResponse.RequestId, + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "chatglm", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response, &zhipuResponse.Usage +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage *model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + 
return 0, nil, nil + } + if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + lines := strings.Split(data, "\n") + for i, line := range lines { + if len(line) < 5 { + continue + } + if strings.HasPrefix(line, "data:") { + dataSegment := line[5:] + if i != len(lines)-1 { + dataSegment += "\n" + } + response := streamResponseZhipu2OpenAI(dataSegment) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + } + } else if strings.HasPrefix(line, "meta:") { + metaSegment := line[5:] + var zhipuResponse StreamMetaResponse + err := json.Unmarshal([]byte(metaSegment), &zhipuResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + } + usage = zhipuUsage + } + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse Response + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = 
json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if !zhipuResponse.Success { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: zhipuResponse.Msg, + Type: "zhipu_error", + Param: "", + Code: zhipuResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) + fullTextResponse.Model = "chatglm" + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) 
*openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: response.Model, + Usage: model.Usage{ + PromptTokens: response.PromptTokens, + CompletionTokens: response.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} diff --git a/relay/adaptor/zhipu/model.go b/relay/adaptor/zhipu/model.go new file mode 100644 index 0000000..06e22dc --- /dev/null +++ b/relay/adaptor/zhipu/model.go @@ -0,0 +1,70 @@ +package zhipu + +import ( + "github.com/songquanpeng/one-api/relay/model" + "time" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Request struct { + Prompt []Message `json:"prompt"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []Message `json:"choices"` + model.Usage `json:"usage"` +} + +type Response struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ResponseData `json:"data"` +} + +type StreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + model.Usage `json:"usage"` +} + +type tokenData struct { + Token string + ExpiryTime time.Time +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input string `json:"input"` +} + +type EmbeddingResponse struct { + Model string 
// TestGetAdaptor verifies that every registered API type in
// [0, apitype.Dummy) resolves to a non-nil adaptor, guarding against a new
// apitype constant being added without a matching case in GetAdaptor.
func TestGetAdaptor(t *testing.T) {
	Convey("get adaptor", t, func() {
		for i := 0; i < apitype.Dummy; i++ {
			a := GetAdaptor(i)
			So(a, ShouldNotBeNil)
		}
	})
}
} +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, userId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, &model.Log{ + UserId: userId, + ChannelId: channelId, + PromptTokens: int(totalQuota), + CompletionTokens: 0, + ModelName: modelName, + TokenName: tokenName, + Quota: int(totalQuota), + Content: logContent, + }) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} diff --git a/relay/billing/ratio/group.go b/relay/billing/ratio/group.go new file mode 100644 index 0000000..b7f62e6 --- /dev/null +++ b/relay/billing/ratio/group.go @@ -0,0 +1,40 @@ +package ratio + +import ( + "encoding/json" + "github.com/songquanpeng/one-api/common/logger" + "sync" +) + +var groupRatioLock sync.RWMutex +var GroupRatio = map[string]float64{ + "default": 1, + "vip": 1, + "svip": 1, +} + +func GroupRatio2JSONString() string { + jsonBytes, err := json.Marshal(GroupRatio) + if err != nil { + logger.SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateGroupRatioByJSONString(jsonStr string) error { + groupRatioLock.Lock() + defer groupRatioLock.Unlock() + GroupRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &GroupRatio) +} + +func 
GetGroupRatio(name string) float64 { + groupRatioLock.RLock() + defer groupRatioLock.RUnlock() + ratio, ok := GroupRatio[name] + if !ok { + logger.SysError("group ratio not found: " + name) + return 1 + } + return ratio +} diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go new file mode 100644 index 0000000..c8c42a1 --- /dev/null +++ b/relay/billing/ratio/image.go @@ -0,0 +1,66 @@ +package ratio + +var ImageSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, + "ali-stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "ali-stable-diffusion-v1.5": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "wanx-v1": { + "1024x1024": 1, + "720x1280": 1, + "1280x720": 1, + }, + "step-1x-medium": { + "256x256": 1, + "512x512": 1, + "768x768": 1, + "1024x1024": 1, + "1280x800": 1, + "800x1280": 1, + }, +} + +var ImageGenerationAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. 
+ "ali-stable-diffusion-xl": {1, 4}, // Ali + "ali-stable-diffusion-v1.5": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali + "cogview-3": {1, 1}, + "step-1x-medium": {1, 1}, +} + +var ImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, + "ali-stable-diffusion-xl": 4000, + "ali-stable-diffusion-v1.5": 4000, + "wanx-v1": 4000, + "cogview-3": 833, + "step-1x-medium": 4000, +} + +var ImageOriginModelName = map[string]string{ + "ali-stable-diffusion-xl": "stable-diffusion-xl", + "ali-stable-diffusion-v1.5": "stable-diffusion-v1.5", +} diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go new file mode 100644 index 0000000..e8b3b61 --- /dev/null +++ b/relay/billing/ratio/model.go @@ -0,0 +1,835 @@ +package ratio + +import ( + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/songquanpeng/one-api/common/logger" +) + +const ( + USD2RMB = 7 + USD = 500 // $0.002 = 1 -> $1 = 500 + MILLI_USD = 1.0 / 1000 * USD + RMB = USD / USD2RMB +) + +var modelRatioLock sync.RWMutex + +// ModelRatio +// https://platform.openai.com/docs/models/model-endpoint-compatibility +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf +// https://openai.com/pricing +// 1 === $0.002 / 1K tokens +// 1 === ¥0.014 / 1k tokens +var ModelRatio = map[string]float64{ + // https://openai.com/pricing + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4o": 2.5, // $0.005 / 1K tokens + "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens + "gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens + "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens + 
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens + "o1": 7.5, // $15.00 / 1M input tokens + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, // $15.00 / 1M input tokens + "o1-preview-2024-09-12": 7.5, + "o1-mini": 1.5, // $3.00 / 1M input tokens + "o1-mini-2024-09-12": 1.5, + "o3-mini": 1.5, // $3.00 / 1M input tokens + "o3-mini-2025-01-31": 1.5, + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image + // https://docs.anthropic.com/en/docs/about-claude/models + "claude-instant-1.2": 0.8 / 1000 * USD, + "claude-2.0": 8.0 / 1000 * USD, + "claude-2.1": 8.0 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-5-haiku-20241022": 1.0 / 1000 * USD, + "claude-3-5-haiku-latest": 1.0 / 1000 * USD, + "claude-3-sonnet-20240229": 3.0 / 1000 * USD, + 
"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, + "claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, + "claude-3-5-sonnet-latest": 3.0 / 1000 * USD, + "claude-3-opus-20240229": 15.0 / 1000 * USD, + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, + // https://ai.google.dev/pricing + // https://cloud.google.com/vertex-ai/generative-ai/pricing + // "gemma-2-2b-it": 0, + // "gemma-2-9b-it": 0, + // "gemma-2-27b-it": 0, + "gemini-pro": 0.25 * MILLI_USD, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro": 0.125 * MILLI_USD, + "gemini-1.5-pro": 1.25 * MILLI_USD, + "gemini-1.5-pro-001": 1.25 * MILLI_USD, + "gemini-1.5-pro-experimental": 1.25 * MILLI_USD, + "gemini-1.5-flash": 0.075 * MILLI_USD, + "gemini-1.5-flash-001": 0.075 * MILLI_USD, + "gemini-1.5-flash-8b": 0.0375 * MILLI_USD, + "gemini-2.0-flash-exp": 0.075 * MILLI_USD, + "gemini-2.0-flash": 0.15 * MILLI_USD, + "gemini-2.0-flash-001": 0.15 * MILLI_USD, + "gemini-2.0-flash-lite-preview-02-05": 0.075 * MILLI_USD, + "gemini-2.0-flash-thinking-exp-01-21": 0.075 * MILLI_USD, + "gemini-2.0-pro-exp-02-05": 1.25 * MILLI_USD, + "aqa": 1, + // https://open.bigmodel.cn/pricing + "glm-zero-preview": 0.01 * RMB, + "glm-4-plus": 0.05 * RMB, + "glm-4-0520": 0.1 * RMB, + "glm-4-airx": 0.01 * RMB, + "glm-4-air": 0.0005 * RMB, + "glm-4-long": 0.001 * RMB, + "glm-4-flashx": 0.0001 * RMB, + "glm-4-flash": 0, + "glm-4": 0.1 * RMB, // deprecated model, available until 2025/06 + "glm-3-turbo": 0.001 * 
RMB, // deprecated model, available until 2025/06 + "glm-4v-plus": 0.004 * RMB, + "glm-4v": 0.05 * RMB, + "glm-4v-flash": 0, + "cogview-3-plus": 0.06 * RMB, + "cogview-3": 0.1 * RMB, + "cogview-3-flash": 0, + "cogviewx": 0.5 * RMB, + "cogviewx-flash": 0, + "charglm-4": 0.001 * RMB, + "emohaa": 0.015 * RMB, + "codegeex-4": 0.0001 * RMB, + "embedding-2": 0.0005 * RMB, + "embedding-3": 0.0005 * RMB, + // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-turbo": 0.0003 * RMB, + "qwen-turbo-latest": 0.0003 * RMB, + "qwen-plus": 0.0008 * RMB, + "qwen-plus-latest": 0.0008 * RMB, + "qwen-max": 0.0024 * RMB, + "qwen-max-latest": 0.0024 * RMB, + "qwen-max-longcontext": 0.0005 * RMB, + "qwen-vl-max": 0.003 * RMB, + "qwen-vl-max-latest": 0.003 * RMB, + "qwen-vl-plus": 0.0015 * RMB, + "qwen-vl-plus-latest": 0.0015 * RMB, + "qwen-vl-ocr": 0.005 * RMB, + "qwen-vl-ocr-latest": 0.005 * RMB, + "qwen-audio-turbo": 1.4286, + "qwen-math-plus": 0.004 * RMB, + "qwen-math-plus-latest": 0.004 * RMB, + "qwen-math-turbo": 0.002 * RMB, + "qwen-math-turbo-latest": 0.002 * RMB, + "qwen-coder-plus": 0.0035 * RMB, + "qwen-coder-plus-latest": 0.0035 * RMB, + "qwen-coder-turbo": 0.002 * RMB, + "qwen-coder-turbo-latest": 0.002 * RMB, + "qwen-mt-plus": 0.015 * RMB, + "qwen-mt-turbo": 0.001 * RMB, + "qwq-32b-preview": 0.002 * RMB, + "qwen2.5-72b-instruct": 0.004 * RMB, + "qwen2.5-32b-instruct": 0.03 * RMB, + "qwen2.5-14b-instruct": 0.001 * RMB, + "qwen2.5-7b-instruct": 0.0005 * RMB, + "qwen2.5-3b-instruct": 0.006 * RMB, + "qwen2.5-1.5b-instruct": 0.0003 * RMB, + "qwen2.5-0.5b-instruct": 0.0003 * RMB, + "qwen2-72b-instruct": 0.004 * RMB, + "qwen2-57b-a14b-instruct": 0.0035 * RMB, + "qwen2-7b-instruct": 0.001 * RMB, + "qwen2-1.5b-instruct": 0.001 * RMB, + "qwen2-0.5b-instruct": 0.001 * RMB, + "qwen1.5-110b-chat": 0.007 * RMB, + "qwen1.5-72b-chat": 0.005 * RMB, + "qwen1.5-32b-chat": 0.0035 * RMB, + "qwen1.5-14b-chat": 0.002 * RMB, + 
"qwen1.5-7b-chat": 0.001 * RMB, + "qwen1.5-1.8b-chat": 0.001 * RMB, + "qwen1.5-0.5b-chat": 0.001 * RMB, + "qwen-72b-chat": 0.02 * RMB, + "qwen-14b-chat": 0.008 * RMB, + "qwen-7b-chat": 0.006 * RMB, + "qwen-1.8b-chat": 0.006 * RMB, + "qwen-1.8b-longcontext-chat": 0.006 * RMB, + "qvq-72b-preview": 0.012 * RMB, + "qwen2.5-vl-72b-instruct": 0.016 * RMB, + "qwen2.5-vl-7b-instruct": 0.002 * RMB, + "qwen2.5-vl-3b-instruct": 0.0012 * RMB, + "qwen2-vl-7b-instruct": 0.016 * RMB, + "qwen2-vl-2b-instruct": 0.002 * RMB, + "qwen-vl-v1": 0.002 * RMB, + "qwen-vl-chat-v1": 0.002 * RMB, + "qwen2-audio-instruct": 0.002 * RMB, + "qwen-audio-chat": 0.002 * RMB, + "qwen2.5-math-72b-instruct": 0.004 * RMB, + "qwen2.5-math-7b-instruct": 0.001 * RMB, + "qwen2.5-math-1.5b-instruct": 0.001 * RMB, + "qwen2-math-72b-instruct": 0.004 * RMB, + "qwen2-math-7b-instruct": 0.001 * RMB, + "qwen2-math-1.5b-instruct": 0.001 * RMB, + "qwen2.5-coder-32b-instruct": 0.002 * RMB, + "qwen2.5-coder-14b-instruct": 0.002 * RMB, + "qwen2.5-coder-7b-instruct": 0.001 * RMB, + "qwen2.5-coder-3b-instruct": 0.001 * RMB, + "qwen2.5-coder-1.5b-instruct": 0.001 * RMB, + "qwen2.5-coder-0.5b-instruct": 0.001 * RMB, + "text-embedding-v1": 0.0007 * RMB, // ¥0.0007 / 1k tokens + "text-embedding-v3": 0.0007 * RMB, + "text-embedding-v2": 0.0007 * RMB, + "text-embedding-async-v2": 0.0007 * RMB, + "text-embedding-async-v1": 0.0007 * RMB, + "ali-stable-diffusion-xl": 8.00, + "ali-stable-diffusion-v1.5": 8.00, + "wanx-v1": 8.00, + "deepseek-r1": 0.002 * RMB, + "deepseek-v3": 0.001 * RMB, + "deepseek-r1-distill-qwen-1.5b": 0.001 * RMB, + "deepseek-r1-distill-qwen-7b": 0.0005 * RMB, + "deepseek-r1-distill-qwen-14b": 0.001 * RMB, + "deepseek-r1-distill-qwen-32b": 0.002 * RMB, + "deepseek-r1-distill-llama-8b": 0.0005 * RMB, + "deepseek-r1-distill-llama-70b": 0.004 * RMB, + "SparkDesk": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1": 
1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens + "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens + "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "hunyuan-turbo": 0.015 * RMB, + "hunyuan-large": 0.004 * RMB, + "hunyuan-large-longcontext": 0.006 * RMB, + "hunyuan-standard": 0.0008 * RMB, + "hunyuan-standard-256K": 0.0005 * RMB, + "hunyuan-translation-lite": 0.005 * RMB, + "hunyuan-role": 0.004 * RMB, + "hunyuan-functioncall": 0.004 * RMB, + "hunyuan-code": 0.004 * RMB, + "hunyuan-turbo-vision": 0.08 * RMB, + "hunyuan-vision": 0.018 * RMB, + "hunyuan-embedding": 0.0007 * RMB, + // https://platform.moonshot.cn/pricing + "moonshot-v1-8k": 0.012 * RMB, + "moonshot-v1-32k": 0.024 * RMB, + "moonshot-v1-128k": 0.06 * RMB, + // https://platform.baichuan-ai.com/price + "Baichuan2-Turbo": 0.008 * RMB, + "Baichuan2-Turbo-192k": 0.016 * RMB, + "Baichuan2-53B": 0.02 * RMB, + // https://api.minimax.chat/document/price + "abab6.5-chat": 0.03 * RMB, + "abab6.5s-chat": 0.01 * RMB, + "abab6-chat": 0.1 * RMB, + "abab5.5-chat": 0.015 * RMB, + "abab5.5s-chat": 0.005 * RMB, + // https://docs.mistral.ai/platform/pricing/ + "open-mistral-7b": 0.25 / 1000 * USD, + "open-mixtral-8x7b": 0.7 / 1000 * USD, + "mistral-small-latest": 2.0 / 1000 * USD, + "mistral-medium-latest": 2.7 / 1000 * USD, + "mistral-large-latest": 8.0 / 1000 * USD, + "mistral-embed": 0.1 / 1000 * USD, + // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed + "gemma-7b-it": 0.07 / 1000000 * USD, + "gemma2-9b-it": 0.20 / 1000000 * USD, + "llama-3.1-70b-versatile": 0.59 / 1000000 * USD, + 
"llama-3.1-8b-instant": 0.05 / 1000000 * USD, + "llama-3.2-11b-text-preview": 0.05 / 1000000 * USD, + "llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD, + "llama-3.2-1b-preview": 0.05 / 1000000 * USD, + "llama-3.2-3b-preview": 0.05 / 1000000 * USD, + "llama-3.2-90b-text-preview": 0.59 / 1000000 * USD, + "llama-guard-3-8b": 0.05 / 1000000 * USD, + "llama3-70b-8192": 0.59 / 1000000 * USD, + "llama3-8b-8192": 0.05 / 1000000 * USD, + "llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, + "llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD, + "mixtral-8x7b-32768": 0.24 / 1000000 * USD, + + // https://platform.lingyiwanwu.com/docs#-计费单元 + "yi-34b-chat-0205": 2.5 / 1000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000 * RMB, + "yi-vl-plus": 6.0 / 1000 * RMB, + // https://platform.stepfun.com/docs/pricing/details + "step-1-8k": 0.005 / 1000 * RMB, + "step-1-32k": 0.015 / 1000 * RMB, + "step-1-128k": 0.040 / 1000 * RMB, + "step-1-256k": 0.095 / 1000 * RMB, + "step-1-flash": 0.001 / 1000 * RMB, + "step-2-16k": 0.038 / 1000 * RMB, + "step-1v-8k": 0.005 / 1000 * RMB, + "step-1v-32k": 0.015 / 1000 * RMB, + // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ + "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens + "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens + // https://cohere.com/pricing + "command": 0.5, + "command-nightly": 0.5, + "command-light": 0.5, + "command-light-nightly": 0.5, + "command-r": 0.5 / 1000 * USD, + "command-r-plus": 3.0 / 1000 * USD, + // https://platform.deepseek.com/api-docs/pricing/ + "deepseek-chat": 0.14 * MILLI_USD, + "deepseek-reasoner": 0.55 * MILLI_USD, + // https://www.deepl.com/pro?cta=header-prices + "deepl-zh": 25.0 / 1000 * USD, + "deepl-en": 25.0 / 1000 * USD, + "deepl-ja": 25.0 / 1000 * USD, + // https://console.x.ai/ + "grok-beta": 5.0 / 1000 * USD, + // replicate charges based on the number of generated images + // https://replicate.com/pricing + "black-forest-labs/flux-1.1-pro": 0.04 * USD, 
+ "black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD, + "black-forest-labs/flux-canny-dev": 0.025 * USD, + "black-forest-labs/flux-canny-pro": 0.05 * USD, + "black-forest-labs/flux-depth-dev": 0.025 * USD, + "black-forest-labs/flux-depth-pro": 0.05 * USD, + "black-forest-labs/flux-dev": 0.025 * USD, + "black-forest-labs/flux-dev-lora": 0.032 * USD, + "black-forest-labs/flux-fill-dev": 0.04 * USD, + "black-forest-labs/flux-fill-pro": 0.05 * USD, + "black-forest-labs/flux-pro": 0.055 * USD, + "black-forest-labs/flux-redux-dev": 0.025 * USD, + "black-forest-labs/flux-redux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell-lora": 0.02 * USD, + "ideogram-ai/ideogram-v2": 0.08 * USD, + "ideogram-ai/ideogram-v2-turbo": 0.05 * USD, + "recraft-ai/recraft-v3": 0.04 * USD, + "recraft-ai/recraft-v3-svg": 0.08 * USD, + "stability-ai/stable-diffusion-3": 0.035 * USD, + "stability-ai/stable-diffusion-3.5-large": 0.065 * USD, + "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, + "stability-ai/stable-diffusion-3.5-medium": 0.035 * USD, + // replicate chat models + "ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD, + "ibm-granite/granite-3.0-2b-instruct": 0.030 * USD, + "ibm-granite/granite-3.0-8b-instruct": 0.050 * USD, + "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, + "meta/llama-2-13b": 0.100 * USD, + "meta/llama-2-13b-chat": 0.100 * USD, + "meta/llama-2-70b": 0.650 * USD, + "meta/llama-2-70b-chat": 0.650 * USD, + "meta/llama-2-7b": 0.050 * USD, + "meta/llama-2-7b-chat": 0.050 * USD, + "meta/meta-llama-3.1-405b-instruct": 9.500 * USD, + "meta/meta-llama-3-70b": 0.650 * USD, + "meta/meta-llama-3-70b-instruct": 0.650 * USD, + "meta/meta-llama-3-8b": 0.050 * USD, + "meta/meta-llama-3-8b-instruct": 0.050 * USD, + "mistralai/mistral-7b-instruct-v0.2": 0.050 * USD, + "mistralai/mistral-7b-v0.1": 0.050 * USD, + "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, + //https://openrouter.ai/models + 
"01-ai/yi-large": 1.5, + "aetherwiing/mn-starcannon-12b": 0.6, + "ai21/jamba-1-5-large": 4.0, + "ai21/jamba-1-5-mini": 0.2, + "ai21/jamba-instruct": 0.35, + "aion-labs/aion-1.0": 6.0, + "aion-labs/aion-1.0-mini": 1.2, + "aion-labs/aion-rp-llama-3.1-8b": 0.1, + "allenai/llama-3.1-tulu-3-405b": 5.0, + "alpindale/goliath-120b": 4.6875, + "alpindale/magnum-72b": 1.125, + "amazon/nova-lite-v1": 0.12, + "amazon/nova-micro-v1": 0.07, + "amazon/nova-pro-v1": 1.6, + "anthracite-org/magnum-v2-72b": 1.5, + "anthracite-org/magnum-v4-72b": 1.125, + "anthropic/claude-2": 12.0, + "anthropic/claude-2.0": 12.0, + "anthropic/claude-2.0:beta": 12.0, + "anthropic/claude-2.1": 12.0, + "anthropic/claude-2.1:beta": 12.0, + "anthropic/claude-2:beta": 12.0, + "anthropic/claude-3-haiku": 0.625, + "anthropic/claude-3-haiku:beta": 0.625, + "anthropic/claude-3-opus": 37.5, + "anthropic/claude-3-opus:beta": 37.5, + "anthropic/claude-3-sonnet": 7.5, + "anthropic/claude-3-sonnet:beta": 7.5, + "anthropic/claude-3.5-haiku": 2.0, + "anthropic/claude-3.5-haiku-20241022": 2.0, + "anthropic/claude-3.5-haiku-20241022:beta": 2.0, + "anthropic/claude-3.5-haiku:beta": 2.0, + "anthropic/claude-3.5-sonnet": 7.5, + "anthropic/claude-3.5-sonnet-20240620": 7.5, + "anthropic/claude-3.5-sonnet-20240620:beta": 7.5, + "anthropic/claude-3.5-sonnet:beta": 7.5, + "cognitivecomputations/dolphin-mixtral-8x22b": 0.45, + "cognitivecomputations/dolphin-mixtral-8x7b": 0.25, + "cohere/command": 0.95, + "cohere/command-r": 0.7125, + "cohere/command-r-03-2024": 0.7125, + "cohere/command-r-08-2024": 0.285, + "cohere/command-r-plus": 7.125, + "cohere/command-r-plus-04-2024": 7.125, + "cohere/command-r-plus-08-2024": 4.75, + "cohere/command-r7b-12-2024": 0.075, + "databricks/dbrx-instruct": 0.6, + "deepseek/deepseek-chat": 0.445, + "deepseek/deepseek-chat-v2.5": 1.0, + "deepseek/deepseek-chat:free": 0.0, + "deepseek/deepseek-r1": 1.2, + "deepseek/deepseek-r1-distill-llama-70b": 0.345, + 
"deepseek/deepseek-r1-distill-llama-70b:free": 0.0, + "deepseek/deepseek-r1-distill-llama-8b": 0.02, + "deepseek/deepseek-r1-distill-qwen-1.5b": 0.09, + "deepseek/deepseek-r1-distill-qwen-14b": 0.075, + "deepseek/deepseek-r1-distill-qwen-32b": 0.09, + "deepseek/deepseek-r1:free": 0.0, + "eva-unit-01/eva-llama-3.33-70b": 3.0, + "eva-unit-01/eva-qwen-2.5-32b": 1.7, + "eva-unit-01/eva-qwen-2.5-72b": 3.0, + "google/gemini-2.0-flash-001": 0.2, + "google/gemini-2.0-flash-exp:free": 0.0, + "google/gemini-2.0-flash-lite-preview-02-05:free": 0.0, + "google/gemini-2.0-flash-thinking-exp-1219:free": 0.0, + "google/gemini-2.0-flash-thinking-exp:free": 0.0, + "google/gemini-2.0-pro-exp-02-05:free": 0.0, + "google/gemini-exp-1206:free": 0.0, + "google/gemini-flash-1.5": 0.15, + "google/gemini-flash-1.5-8b": 0.075, + "google/gemini-flash-1.5-8b-exp": 0.0, + "google/gemini-pro": 0.75, + "google/gemini-pro-1.5": 2.5, + "google/gemini-pro-vision": 0.75, + "google/gemma-2-27b-it": 0.135, + "google/gemma-2-9b-it": 0.03, + "google/gemma-2-9b-it:free": 0.0, + "google/gemma-7b-it": 0.075, + "google/learnlm-1.5-pro-experimental:free": 0.0, + "google/palm-2-chat-bison": 1.0, + "google/palm-2-chat-bison-32k": 1.0, + "google/palm-2-codechat-bison": 1.0, + "google/palm-2-codechat-bison-32k": 1.0, + "gryphe/mythomax-l2-13b": 0.0325, + "gryphe/mythomax-l2-13b:free": 0.0, + "huggingfaceh4/zephyr-7b-beta:free": 0.0, + "infermatic/mn-inferor-12b": 0.6, + "inflection/inflection-3-pi": 5.0, + "inflection/inflection-3-productivity": 5.0, + "jondurbin/airoboros-l2-70b": 0.25, + "liquid/lfm-3b": 0.01, + "liquid/lfm-40b": 0.075, + "liquid/lfm-7b": 0.005, + "mancer/weaver": 1.125, + "meta-llama/llama-2-13b-chat": 0.11, + "meta-llama/llama-2-70b-chat": 0.45, + "meta-llama/llama-3-70b-instruct": 0.2, + "meta-llama/llama-3-8b-instruct": 0.03, + "meta-llama/llama-3-8b-instruct:free": 0.0, + "meta-llama/llama-3.1-405b": 1.0, + "meta-llama/llama-3.1-405b-instruct": 0.4, + "meta-llama/llama-3.1-70b-instruct": 
0.15, + "meta-llama/llama-3.1-8b-instruct": 0.025, + "meta-llama/llama-3.2-11b-vision-instruct": 0.0275, + "meta-llama/llama-3.2-11b-vision-instruct:free": 0.0, + "meta-llama/llama-3.2-1b-instruct": 0.005, + "meta-llama/llama-3.2-3b-instruct": 0.0125, + "meta-llama/llama-3.2-90b-vision-instruct": 0.8, + "meta-llama/llama-3.3-70b-instruct": 0.15, + "meta-llama/llama-3.3-70b-instruct:free": 0.0, + "meta-llama/llama-guard-2-8b": 0.1, + "microsoft/phi-3-medium-128k-instruct": 0.5, + "microsoft/phi-3-medium-128k-instruct:free": 0.0, + "microsoft/phi-3-mini-128k-instruct": 0.05, + "microsoft/phi-3-mini-128k-instruct:free": 0.0, + "microsoft/phi-3.5-mini-128k-instruct": 0.05, + "microsoft/phi-4": 0.07, + "microsoft/wizardlm-2-7b": 0.035, + "microsoft/wizardlm-2-8x22b": 0.25, + "minimax/minimax-01": 0.55, + "mistralai/codestral-2501": 0.45, + "mistralai/codestral-mamba": 0.125, + "mistralai/ministral-3b": 0.02, + "mistralai/ministral-8b": 0.05, + "mistralai/mistral-7b-instruct": 0.0275, + "mistralai/mistral-7b-instruct-v0.1": 0.1, + "mistralai/mistral-7b-instruct-v0.3": 0.0275, + "mistralai/mistral-7b-instruct:free": 0.0, + "mistralai/mistral-large": 3.0, + "mistralai/mistral-large-2407": 3.0, + "mistralai/mistral-large-2411": 3.0, + "mistralai/mistral-medium": 4.05, + "mistralai/mistral-nemo": 0.04, + "mistralai/mistral-nemo:free": 0.0, + "mistralai/mistral-small": 0.3, + "mistralai/mistral-small-24b-instruct-2501": 0.07, + "mistralai/mistral-small-24b-instruct-2501:free": 0.0, + "mistralai/mistral-tiny": 0.125, + "mistralai/mixtral-8x22b-instruct": 0.45, + "mistralai/mixtral-8x7b": 0.3, + "mistralai/mixtral-8x7b-instruct": 0.12, + "mistralai/pixtral-12b": 0.05, + "mistralai/pixtral-large-2411": 3.0, + "neversleep/llama-3-lumimaid-70b": 2.25, + "neversleep/llama-3-lumimaid-8b": 0.5625, + "neversleep/llama-3-lumimaid-8b:extended": 0.5625, + "neversleep/llama-3.1-lumimaid-70b": 2.25, + "neversleep/llama-3.1-lumimaid-8b": 0.5625, + "neversleep/noromaid-20b": 1.125, + 
"nothingiisreal/mn-celeste-12b": 0.6, + "nousresearch/hermes-2-pro-llama-3-8b": 0.02, + "nousresearch/hermes-3-llama-3.1-405b": 0.4, + "nousresearch/hermes-3-llama-3.1-70b": 0.15, + "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": 0.3, + "nousresearch/nous-hermes-llama2-13b": 0.085, + "nvidia/llama-3.1-nemotron-70b-instruct": 0.15, + "nvidia/llama-3.1-nemotron-70b-instruct:free": 0.0, + "openai/chatgpt-4o-latest": 7.5, + "openai/gpt-3.5-turbo": 0.75, + "openai/gpt-3.5-turbo-0125": 0.75, + "openai/gpt-3.5-turbo-0613": 1.0, + "openai/gpt-3.5-turbo-1106": 1.0, + "openai/gpt-3.5-turbo-16k": 2.0, + "openai/gpt-3.5-turbo-instruct": 1.0, + "openai/gpt-4": 30.0, + "openai/gpt-4-0314": 30.0, + "openai/gpt-4-1106-preview": 15.0, + "openai/gpt-4-32k": 60.0, + "openai/gpt-4-32k-0314": 60.0, + "openai/gpt-4-turbo": 15.0, + "openai/gpt-4-turbo-preview": 15.0, + "openai/gpt-4o": 5.0, + "openai/gpt-4o-2024-05-13": 7.5, + "openai/gpt-4o-2024-08-06": 5.0, + "openai/gpt-4o-2024-11-20": 5.0, + "openai/gpt-4o-mini": 0.3, + "openai/gpt-4o-mini-2024-07-18": 0.3, + "openai/gpt-4o:extended": 9.0, + "openai/o1": 30.0, + "openai/o1-mini": 2.2, + "openai/o1-mini-2024-09-12": 2.2, + "openai/o1-preview": 30.0, + "openai/o1-preview-2024-09-12": 30.0, + "openai/o3-mini": 2.2, + "openai/o3-mini-high": 2.2, + "openchat/openchat-7b": 0.0275, + "openchat/openchat-7b:free": 0.0, + "openrouter/auto": -500000.0, + "perplexity/llama-3.1-sonar-huge-128k-online": 2.5, + "perplexity/llama-3.1-sonar-large-128k-chat": 0.5, + "perplexity/llama-3.1-sonar-large-128k-online": 0.5, + "perplexity/llama-3.1-sonar-small-128k-chat": 0.1, + "perplexity/llama-3.1-sonar-small-128k-online": 0.1, + "perplexity/sonar": 0.5, + "perplexity/sonar-reasoning": 2.5, + "pygmalionai/mythalion-13b": 0.6, + "qwen/qvq-72b-preview": 0.25, + "qwen/qwen-2-72b-instruct": 0.45, + "qwen/qwen-2-7b-instruct": 0.027, + "qwen/qwen-2-7b-instruct:free": 0.0, + "qwen/qwen-2-vl-72b-instruct": 0.2, + "qwen/qwen-2-vl-7b-instruct": 0.05, + 
"qwen/qwen-2.5-72b-instruct": 0.2, + "qwen/qwen-2.5-7b-instruct": 0.025, + "qwen/qwen-2.5-coder-32b-instruct": 0.08, + "qwen/qwen-max": 3.2, + "qwen/qwen-plus": 0.6, + "qwen/qwen-turbo": 0.1, + "qwen/qwen-vl-plus:free": 0.0, + "qwen/qwen2.5-vl-72b-instruct:free": 0.0, + "qwen/qwq-32b-preview": 0.09, + "raifle/sorcererlm-8x22b": 2.25, + "sao10k/fimbulvetr-11b-v2": 0.6, + "sao10k/l3-euryale-70b": 0.4, + "sao10k/l3-lunaris-8b": 0.03, + "sao10k/l3.1-70b-hanami-x1": 1.5, + "sao10k/l3.1-euryale-70b": 0.4, + "sao10k/l3.3-euryale-70b": 0.4, + "sophosympatheia/midnight-rose-70b": 0.4, + "sophosympatheia/rogue-rose-103b-v0.2:free": 0.0, + "teknium/openhermes-2.5-mistral-7b": 0.085, + "thedrummer/rocinante-12b": 0.25, + "thedrummer/unslopnemo-12b": 0.25, + "undi95/remm-slerp-l2-13b": 0.6, + "undi95/toppy-m-7b": 0.035, + "undi95/toppy-m-7b:free": 0.0, + "x-ai/grok-2-1212": 5.0, + "x-ai/grok-2-vision-1212": 5.0, + "x-ai/grok-beta": 7.5, + "x-ai/grok-vision-beta": 7.5, + "xwin-lm/xwin-lm-70b": 1.875, +} + +var CompletionRatio = map[string]float64{ + // aws llama3 + "llama3-8b-8192(33)": 0.0006 / 0.0003, + "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens + // deepseek + "deepseek-chat": 0.28 / 0.14, + "deepseek-reasoner": 2.19 / 0.55, +} + +var ( + DefaultModelRatio map[string]float64 + DefaultCompletionRatio map[string]float64 +) + +func init() { + DefaultModelRatio = make(map[string]float64) + for k, v := range ModelRatio { + DefaultModelRatio[k] = v + } + DefaultCompletionRatio = make(map[string]float64) + for k, v := range CompletionRatio { + DefaultCompletionRatio[k] = v + } +} + +func AddNewMissingRatio(oldRatio string) string { + newRatio := make(map[string]float64) + err := json.Unmarshal([]byte(oldRatio), &newRatio) + if err != nil { + logger.SysError("error unmarshalling old ratio: " + err.Error()) + return oldRatio + } + for k, v := range DefaultModelRatio { + if _, ok := newRatio[k]; !ok { + newRatio[k] = v + } + } + 
jsonBytes, err := json.Marshal(newRatio) + if err != nil { + logger.SysError("error marshalling new ratio: " + err.Error()) + return oldRatio + } + return string(jsonBytes) +} + +func ModelRatio2JSONString() string { + jsonBytes, err := json.Marshal(ModelRatio) + if err != nil { + logger.SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelRatioByJSONString(jsonStr string) error { + modelRatioLock.Lock() + defer modelRatioLock.Unlock() + ModelRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &ModelRatio) +} + +func GetModelRatio(name string, channelType int) float64 { + modelRatioLock.RLock() + defer modelRatioLock.RUnlock() + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := ModelRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[model]; ok { + return ratio + } + if ratio, ok := ModelRatio[name]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[name]; ok { + return ratio + } + logger.SysError("model ratio not found: " + name) + return 30 +} + +func CompletionRatio2JSONString() string { + jsonBytes, err := json.Marshal(CompletionRatio) + if err != nil { + logger.SysError("error marshalling completion ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateCompletionRatioByJSONString(jsonStr string) error { + CompletionRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &CompletionRatio) +} + +func GetCompletionRatio(name string, channelType int) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + model := fmt.Sprintf("%s(%d)", name, channelType) + 
if ratio, ok := CompletionRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultCompletionRatio[model]; ok { + return ratio + } + if ratio, ok := CompletionRatio[name]; ok { + return ratio + } + if ratio, ok := DefaultCompletionRatio[name]; ok { + return ratio + } + if strings.HasPrefix(name, "gpt-3.5") { + if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { + // https://openai.com/blog/new-embedding-models-and-api-updates + // Updated GPT-3.5 Turbo model and lower pricing + return 3 + } + if strings.HasSuffix(name, "1106") { + return 2 + } + return 4.0 / 3.0 + } + if strings.HasPrefix(name, "gpt-4") { + if strings.HasPrefix(name, "gpt-4o") { + if name == "gpt-4o-2024-05-13" { + return 3 + } + return 4 + } + if strings.HasPrefix(name, "gpt-4-turbo") || + strings.HasSuffix(name, "preview") { + return 3 + } + return 2 + } + // including o1, o1-preview, o1-mini + if strings.HasPrefix(name, "o1") { + return 4 + } + if name == "chatgpt-4o-latest" { + return 3 + } + if strings.HasPrefix(name, "claude-3") { + return 5 + } + if strings.HasPrefix(name, "claude-") { + return 3 + } + if strings.HasPrefix(name, "mistral-") { + return 3 + } + if strings.HasPrefix(name, "gemini-") { + return 3 + } + if strings.HasPrefix(name, "deepseek-") { + return 2 + } + + switch name { + case "llama2-70b-4096": + return 0.8 / 0.64 + case "llama3-8b-8192": + return 2 + case "llama3-70b-8192": + return 0.79 / 0.59 + case "command", "command-light", "command-nightly", "command-light-nightly": + return 2 + case "command-r": + return 3 + case "command-r-plus": + return 5 + case "grok-beta": + return 3 + // Replicate Models + // https://replicate.com/pricing + case "ibm-granite/granite-20b-code-instruct-8k": + return 5 + case "ibm-granite/granite-3.0-2b-instruct": + return 8.333333333333334 + case "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k": + return 5 + case "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-7b", + 
"meta/llama-2-7b-chat", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct": + return 5 + case "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct": + return 2.750 / 0.650 // ≈4.230769 + case "meta/meta-llama-3.1-405b-instruct": + return 1 + case "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1": + return 5 + case "mistralai/mixtral-8x7b-instruct-v0.1": + return 1.000 / 0.300 // ≈3.333333 + } + + return 1 +} diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go new file mode 100644 index 0000000..f557d6c --- /dev/null +++ b/relay/channeltype/define.go @@ -0,0 +1,57 @@ +package channeltype + +const ( + Unknown = iota + OpenAI + API2D + Azure + CloseAI + OpenAISB + OpenAIMax + OhMyGPT + Custom + Ails + AIProxy + PaLM + API2GPT + AIGC2D + Anthropic + Baidu + Zhipu + Ali + Xunfei + AI360 + OpenRouter + AIProxyLibrary + FastGPT + Tencent + Gemini + Moonshot + Baichuan + Minimax + Mistral + Groq + Ollama + LingYiWanWu + StepFun + AwsClaude + Coze + Cohere + DeepSeek + Cloudflare + DeepL + TogetherAI + Doubao + Novita + VertextAI + Proxy + SiliconFlow + XAI + Replicate + BaiduV2 + XunfeiV2 + AliBailian + OpenAICompatible + GeminiOpenAICompatible + Dummy +) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go new file mode 100644 index 0000000..8839b30 --- /dev/null +++ b/relay/channeltype/helper.go @@ -0,0 +1,47 @@ +package channeltype + +import "github.com/songquanpeng/one-api/relay/apitype" + +func ToAPIType(channelType int) int { + apiType := apitype.OpenAI + switch channelType { + case Anthropic: + apiType = apitype.Anthropic + case Baidu: + apiType = apitype.Baidu + case PaLM: + apiType = apitype.PaLM + case Zhipu: + apiType = apitype.Zhipu + case Ali: + apiType = apitype.Ali + case Xunfei: + apiType = apitype.Xunfei + case AIProxyLibrary: + apiType = apitype.AIProxyLibrary + case Tencent: + apiType = apitype.Tencent + case Gemini: + apiType = 
apitype.Gemini + case Ollama: + apiType = apitype.Ollama + case AwsClaude: + apiType = apitype.AwsClaude + case Coze: + apiType = apitype.Coze + case Cohere: + apiType = apitype.Cohere + case Cloudflare: + apiType = apitype.Cloudflare + case DeepL: + apiType = apitype.DeepL + case VertextAI: + apiType = apitype.VertexAI + case Replicate: + apiType = apitype.Replicate + case Proxy: + apiType = apitype.Proxy + } + + return apiType +} diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go new file mode 100644 index 0000000..5a47a64 --- /dev/null +++ b/relay/channeltype/url.go @@ -0,0 +1,63 @@ +package channeltype + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.tencentcloudapi.com", // 23 + "https://generativelanguage.googleapis.com", // 24 + "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 + "", // 33 + "https://api.coze.com", // 34 + "https://api.cohere.ai", // 35 + "https://api.deepseek.com", // 36 + "https://api.cloudflare.com", // 37 + "https://api-free.deepl.com", // 38 + 
"https://api.together.xyz", // 39 + "https://ark.cn-beijing.volces.com", // 40 + "https://api.novita.ai/v3/openai", // 41 + "", // 42 + "", // 43 + "https://api.siliconflow.cn", // 44 + "https://api.x.ai", // 45 + "https://api.replicate.com/v1/models/", // 46 + "https://qianfan.baidubce.com", // 47 + "https://spark-api-open.xf-yun.com", // 48 + "https://dashscope.aliyuncs.com", // 49 + "", // 50 + + "https://generativelanguage.googleapis.com/v1beta/openai/", // 51 +} + +func init() { + if len(ChannelBaseURLs) != Dummy { + panic("channel base urls length not match") + } +} diff --git a/relay/channeltype/url_test.go b/relay/channeltype/url_test.go new file mode 100644 index 0000000..2aada27 --- /dev/null +++ b/relay/channeltype/url_test.go @@ -0,0 +1,12 @@ +package channeltype + +import ( + . "github.com/smartystreets/goconvey/convey" + "testing" +) + +func TestChannelBaseURLs(t *testing.T) { + Convey("channel base urls", t, func() { + So(len(ChannelBaseURLs), ShouldEqual, Dummy) + }) +} diff --git a/relay/constant/common.go b/relay/constant/common.go new file mode 100644 index 0000000..f31477c --- /dev/null +++ b/relay/constant/common.go @@ -0,0 +1,5 @@ +package constant + +var StopFinishReason = "stop" +var StreamObject = "chat.completion.chunk" +var NonStreamObject = "chat.completion" diff --git a/relay/constant/finishreason/define.go b/relay/constant/finishreason/define.go new file mode 100644 index 0000000..1ed9c42 --- /dev/null +++ b/relay/constant/finishreason/define.go @@ -0,0 +1,5 @@ +package finishreason + +const ( + Stop = "stop" +) diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go new file mode 100644 index 0000000..5097c97 --- /dev/null +++ b/relay/constant/role/define.go @@ -0,0 +1,6 @@ +package role + +const ( + System = "system" + Assistant = "assistant" +) diff --git a/relay/controller/audio.go b/relay/controller/audio.go new file mode 100644 index 0000000..e3d57b1 --- /dev/null +++ b/relay/controller/audio.go @@ -0,0 +1,281 
@@ +package controller + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + audioModel := "whisper-1" + + tokenId := c.GetInt(ctxkey.TokenId) + channelType := c.GetInt(ctxkey.Channel) + channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt(ctxkey.Id) + group := c.GetString(ctxkey.Group) + tokenName := c.GetString(ctxkey.TokenName) + + var ttsRequest openai.TextToSpeechRequest + if relayMode == relaymode.AudioSpeech { + // Read JSON + err := common.UnmarshalBodyReusable(c, &ttsRequest) + // Check if JSON is valid + if err != nil { + return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) + } + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + } + + modelRatio := billingratio.GetModelRatio(audioModel, channelType) + groupRatio := billingratio.GetGroupRatio(group) + ratio := modelRatio * groupRatio + var quota int64 + var preConsumedQuota int64 + switch relayMode { 
+ case relaymode.AudioSpeech: + preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) + quota = preConsumedQuota + default: + preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + } + userQuota, err := model.CacheGetUserQuota(ctx, userId) + if err != nil { + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + + // Check if user quota is enough + if userQuota-preConsumedQuota < 0 { + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + }() + + // map model name + modelMapping := c.GetStringMapString(ctxkey.ModelMapping) + if modelMapping != nil && modelMapping[audioModel] != "" { + audioModel = modelMapping[audioModel] + } + + baseURL := channeltype.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + if c.GetString(ctxkey.BaseURL) != "" { + baseURL = c.GetString(ctxkey.BaseURL) + } + + fullRequestURL := openai.GetFullRequestURL(baseURL, 
requestURL, channelType) + if channelType == channeltype.Azure { + apiVersion := meta.Config.APIVersion + if relayMode == relaymode.AudioTranscription { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } else if relayMode == relaymode.AudioSpeech { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) + } + } + + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + req.Header.Set("api-key", apiKey) + req.ContentLength = c.Request.ContentLength + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + + resp, err := client.HTTPClient.Do(req) + if err != nil { + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + 
err = req.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + + if relayMode != relaymode.AudioSpeech { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + + var openAIErr openai.SlimTextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + } + } + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + default: + return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + } + if err != nil { + return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = int64(openai.CountTokenText(text, audioModel)) + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } + if resp.StatusCode != http.StatusOK { + return RelayErrorHandler(resp) + } + succeed = true + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, 
audioModel, tokenName)
	}(c.Request.Context())

	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}

// getTextFromVTT extracts the spoken text from a WebVTT transcription body.
// WebVTT cues use the same "start --> end" timing lines as SRT, so the SRT
// parser is reused for both formats.
func getTextFromVTT(body []byte) (string, error) {
	return getTextFromSRT(body)
}

// getTextFromVerboseJSON extracts the transcription text from a
// verbose_json whisper response body.
func getTextFromVerboseJSON(body []byte) (string, error) {
	var whisperResponse openai.WhisperVerboseJSONResponse
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
	}
	return whisperResponse.Text, nil
}

// getTextFromSRT extracts all subtitle text from an SRT (or WebVTT) body.
// The returned string is fed to CountTokenText for billing.
//
// Fix: the previous implementation kept only the FIRST text line of each cue
// (multi-line subtitles were silently dropped) and concatenated cues with no
// separator, which skewed the token count used for billing. A cue's text now
// runs from the line after the "-->" timing line until the next blank line,
// and cue lines are joined with single spaces.
func getTextFromSRT(body []byte) (string, error) {
	scanner := bufio.NewScanner(strings.NewReader(string(body)))
	var builder strings.Builder
	inCueText := false
	for scanner.Scan() {
		line := scanner.Text()
		if strings.Contains(line, "-->") {
			// timing line: the following non-empty lines are cue text
			inCueText = true
			continue
		}
		if !inCueText {
			// sequence numbers / headers before the timing line are skipped
			continue
		}
		if strings.TrimSpace(line) == "" {
			// blank line terminates the current cue
			inCueText = false
			continue
		}
		if builder.Len() > 0 {
			builder.WriteString(" ")
		}
		builder.WriteString(line)
	}
	if err := scanner.Err(); err != nil {
		return "", err
	}
	return builder.String(), nil
}

// getTextFromText returns a plain-text transcription body, trimming the
// single trailing newline the upstream API appends.
func getTextFromText(body []byte) (string, error) {
	return strings.TrimSuffix(string(body), "\n"), nil
}

// getTextFromJSON extracts the transcription text from a "json"
// response_format whisper response body.
func getTextFromJSON(body []byte) (string, error) {
	var whisperResponse openai.WhisperJSONResponse
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
	}
	return whisperResponse.Text, nil
}
diff --git a/relay/controller/error.go b/relay/controller/error.go new file mode 100644 index 0000000..29d4f12 --- /dev/null +++ b/relay/controller/error.go @@ -0,0 +1,101 @@ +package controller + +import ( + "encoding/json" + "fmt" +
"github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strconv" +) + +type GeneralErrorResponse struct { + Error model.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) { + if resp == nil { + return &model.ErrorWithStatusCode{ + StatusCode: 500, + Error: model.Error{ + Message: "resp is nil", + Type: "upstream_error", + Code: "bad_response", + }, + } + } + ErrorWithStatusCode = &model.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: model.Error{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the 
default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go new file mode 100644 index 0000000..5b6f023 --- /dev/null +++ b/relay/controller/helper.go @@ -0,0 +1,198 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/relay/constant/role" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller/validator" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { + textRequest := &relaymodel.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + if relayMode == relaymode.Moderations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == relaymode.Embeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + err = validator.ValidateTextRequest(textRequest, relayMode) + if err != nil { + return nil, err + } + return textRequest, nil +} + +func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode 
int) int { + switch relayMode { + case relaymode.ChatCompletions: + return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + case relaymode.Completions: + return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) + case relaymode.Moderations: + return openai.CountTokenInput(textRequest.Input, textRequest.Model) + } + return 0 +} + +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { + preConsumedTokens := config.PreConsumedQuota + int64(promptTokens) + if textRequest.MaxTokens != 0 { + preConsumedTokens += int64(textRequest.MaxTokens) + } + return int64(float64(preConsumedTokens) * ratio) +} + +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) { + preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) + + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + if userQuota-preConsumedQuota < 0 { + return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota)) + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, 
"pre_consume_token_quota_failed", http.StatusForbidden) + } + } + return preConsumedQuota, nil +} + +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { + if usage == nil { + logger.Error(ctx, "usage is nil, which is unexpected") + return + } + var quota int64 + completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta) + if err != nil { + logger.Error(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, meta.UserId) + if err != nil { + logger.Error(ctx, "error update user quota cache: "+err.Error()) + } + logContent := fmt.Sprintf("倍率:%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio) + model.RecordConsumeLog(ctx, &model.Log{ + UserId: meta.UserId, + ChannelId: meta.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: textRequest.Model, + TokenName: meta.TokenName, + Quota: int(quota), + Content: logContent, + IsStream: meta.IsStream, + ElapsedTime: helper.CalcElapsedTime(meta.StartTime), + SystemPromptReset: systemPromptReset, + }) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) +} + +func getMappedModelName(modelName string, mapping 
map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false +} + +func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { + if resp == nil { + if meta.ChannelType == channeltype.AwsClaude { + return false + } + return true + } + if resp.StatusCode != http.StatusOK && + // replicate return 201 to create a task + resp.StatusCode != http.StatusCreated { + return true + } + if meta.ChannelType == channeltype.DeepL { + // skip stream check for deepl + return false + } + + if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && + // Even if stream mode is enabled, replicate will first return a task info in JSON format, + // requiring the client to request the stream endpoint in the task info + meta.ChannelType != channeltype.Replicate { + return true + } + return false +} + +func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { + if prompt == "" { + return false + } + if len(request.Messages) == 0 { + return false + } + if request.Messages[0].Role == role.System { + request.Messages[0].Content = prompt + logger.Infof(ctx, "rewrite system prompt") + return true + } + request.Messages = append([]relaymodel.Message{{ + Role: role.System, + Content: prompt, + }}, request.Messages...) 
+ logger.Infof(ctx, "add system prompt") + return true +} diff --git a/relay/controller/image.go b/relay/controller/image.go new file mode 100644 index 0000000..9a980a1 --- /dev/null +++ b/relay/controller/image.go @@ -0,0 +1,238 @@ +package controller + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func isValidImagePromptLength(model string, promptLength int) bool { + maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] + return !ok || promptLength <= maxPromptLength +} + +func isWithinRange(element string, value int) bool { + amounts, ok := billingratio.ImageGenerationAmounts[element] + return !ok || (value >= amounts[0] && value <= amounts[1]) +} + +func getImageSizeRatio(model string, size string) float64 { + 
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { + return ratio + } + return 1 +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + + // model validation + if !isValidImageSize(imageRequest.Model, imageRequest.Size) { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + + if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil +} + +func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + imageRequest, err := getImageRequest(c, meta.Mode) + if err != nil { + logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) + } + + // map model name + var isModelMapped bool + meta.OriginModelName = imageRequest.Model + imageRequest.Model, isModelMapped = 
getMappedModelName(imageRequest.Model, meta.ModelMapping) + meta.ActualModelName = imageRequest.Model + + // model validation + bizErr := validateImageRequest(imageRequest, meta) + if bizErr != nil { + return bizErr + } + + imageCostRatio, err := getImageCostRatio(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) + } + + imageModel := imageRequest.Model + // Convert the original image model + imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) + c.Set("response_format", imageRequest.ResponseFormat) + + var requestBody io.Reader + if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body + jsonStr, err := json.Marshal(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(meta) + + // these adaptors need to convert the request + switch meta.ChannelType { + case channeltype.Zhipu, + channeltype.Ali, + channeltype.Replicate, + channeltype.Baidu: + finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + + modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) + groupRatio := billingratio.GetGroupRatio(meta.Group) + ratio := modelRatio * groupRatio + userQuota, err := 
model.CacheGetUserQuota(ctx, meta.UserId) + + var quota int64 + switch meta.ChannelType { + case channeltype.Replicate: + // replicate always return 1 image + quota = int64(ratio * imageCostRatio * 1000) + default: + quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + } + + if userQuota-quota < 0 { + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + defer func(ctx context.Context) { + if resp != nil && + resp.StatusCode != http.StatusCreated && // replicate returns 201 + resp.StatusCode != http.StatusOK { + return + } + + err := model.PostConsumeTokenQuota(meta.TokenId, quota) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, meta.UserId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString(ctxkey.TokenName) + logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, &model.Log{ + UserId: meta.UserId, + ChannelId: meta.ChannelId, + PromptTokens: 0, + CompletionTokens: 0, + ModelName: imageRequest.Model, + TokenName: tokenName, + Quota: int(quota), + Content: logContent, + }) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + channelId := c.GetInt(ctxkey.ChannelId) + model.UpdateChannelUsedQuota(channelId, quota) + } + }(c.Request.Context()) + + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr + } + + return nil +} diff --git a/relay/controller/proxy.go b/relay/controller/proxy.go new file mode 100644 index 
0000000..dcaf15a --- /dev/null +++ b/relay/controller/proxy.go @@ -0,0 +1,41 @@ +// Package controller is a package for handling the relay controller +package controller + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// RelayProxyHelper is a helper function to proxy the request to the upstream service +func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(meta) + + resp, err := adaptor.DoRequest(c, meta, c.Request.Body) + if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr + } + + return nil +} diff --git a/relay/controller/text.go b/relay/controller/text.go new file mode 100644 index 0000000..f912498 --- /dev/null +++ b/relay/controller/text.go @@ -0,0 +1,115 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/billing" 
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + // get & validate textRequest + textRequest, err := getAndValidateTextRequest(c, meta.Mode) + if err != nil { + logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + } + meta.IsStream = textRequest.Stream + + // map model name + meta.OriginModelName = textRequest.Model + textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) + meta.ActualModelName = textRequest.Model + // set system prompt if not empty + systemPromptReset := setSystemPrompt(ctx, textRequest, meta.ForcedSystemPrompt) + // get model ratio & group ratio + modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) + groupRatio := billingratio.GetGroupRatio(meta.Group) + ratio := modelRatio * groupRatio + // pre-consume quota + promptTokens := getPromptTokens(textRequest, meta.Mode) + meta.PromptTokens = promptTokens + preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) + if bizErr != nil { + logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr) + return bizErr + } + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(meta) + + // get request body + requestBody, err := getRequestBody(c, meta, textRequest, adaptor) + if err != nil { + return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) + } + + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + 
logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + if isErrorHappened(meta, resp) { + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return RelayErrorHandler(resp) + } + + // do response + usage, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return respErr + } + // post-consume quota + go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) + return nil +} + +func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { + if !config.EnforceIncludeUsage && + meta.APIType == apitype.OpenAI && + meta.OriginModelName == meta.ActualModelName && + meta.ChannelType != channeltype.Baichuan && + meta.ForcedSystemPrompt == "" { + // no need to convert request for openai + return c.Request.Body, nil + } + + // get request body + var requestBody io.Reader + convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest) + if err != nil { + logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error()) + return nil, err + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error()) + return nil, err + } + logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData)) + requestBody = bytes.NewBuffer(jsonData) + return requestBody, nil +} diff --git a/relay/controller/validator/validation.go b/relay/controller/validator/validation.go new file mode 100644 index 0000000..8ff520b --- /dev/null +++ b/relay/controller/validator/validation.go @@ -0,0 +1,37 @@ +package validator + +import ( + "errors" + 
"github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "math" +) + +func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error { + if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + return errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return errors.New("model is required") + } + switch relayMode { + case relaymode.Completions: + if textRequest.Prompt == "" { + return errors.New("field prompt is required") + } + case relaymode.ChatCompletions: + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return errors.New("field messages is required") + } + case relaymode.Embeddings: + case relaymode.Moderations: + if textRequest.Input == "" { + return errors.New("field input is required") + } + case relaymode.Edits: + if textRequest.Instruction == "" { + return errors.New("field instruction is required") + } + } + return nil +} diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go new file mode 100644 index 0000000..8c74ef8 --- /dev/null +++ b/relay/meta/relay_meta.go @@ -0,0 +1,66 @@ +package meta + +import ( + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Meta struct { + Mode int + ChannelType int + ChannelId int + TokenId int + TokenName string + UserId int + Group string + ModelMapping map[string]string + // BaseURL is the proxy url set in the channel config + BaseURL string + APIKey string + APIType int + Config model.ChannelConfig + IsStream bool + // OriginModelName is the model name from the raw user request + OriginModelName string + // ActualModelName is the model name after mapping + ActualModelName string + RequestURLPath string + PromptTokens int // only for DoResponse + ForcedSystemPrompt string + 
// GetByContext builds a request-scoped Meta from values that earlier
// middleware (auth, channel distribution) stored on the gin context.
func GetByContext(c *gin.Context) *Meta {
	meta := Meta{
		Mode:               relaymode.GetByPath(c.Request.URL.Path),
		ChannelType:        c.GetInt(ctxkey.Channel),
		ChannelId:          c.GetInt(ctxkey.ChannelId),
		TokenId:            c.GetInt(ctxkey.TokenId),
		TokenName:          c.GetString(ctxkey.TokenName),
		UserId:             c.GetInt(ctxkey.Id),
		Group:              c.GetString(ctxkey.Group),
		ModelMapping:       c.GetStringMapString(ctxkey.ModelMapping),
		OriginModelName:    c.GetString(ctxkey.RequestModel),
		BaseURL:            c.GetString(ctxkey.BaseURL),
		// the client's bearer token, with the "Bearer " scheme prefix stripped
		APIKey:             strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		RequestURLPath:     c.Request.URL.String(),
		ForcedSystemPrompt: c.GetString(ctxkey.SystemPrompt),
		StartTime:          time.Now(),
	}
	// channel config is optional: only set when the distribute middleware
	// stored one on the context
	cfg, ok := c.Get(ctxkey.Config)
	if ok {
		meta.Config = cfg.(model.ChannelConfig)
	}
	// fall back to the default base URL for this channel type when the
	// channel has no proxy URL configured
	if meta.BaseURL == "" {
		meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
	}
	meta.APIType = channeltype.ToAPIType(meta.ChannelType)
	return &meta
}
+type GeneralOpenAIRequest struct { + // https://platform.openai.com/docs/api-reference/chat/create + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store *bool `json:"store,omitempty"` + ReasoningEffort *string `json:"reasoning_effort,omitempty"` + Metadata any `json:"metadata,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias any `json:"logit_bias,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` + // https://platform.openai.com/docs/api-reference/embeddings/create + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // https://platform.openai.com/docs/api-reference/images/create + Prompt any `json:"prompt,omitempty"` + Quality *string `json:"quality,omitempty"` + Size string 
`json:"size,omitempty"` + Style *string `json:"style,omitempty"` + // Others + Instruction string `json:"instruction,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` +} + +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} diff --git a/relay/model/image.go b/relay/model/image.go new file mode 100644 index 0000000..bab8425 --- /dev/null +++ b/relay/model/image.go @@ -0,0 +1,12 @@ +package model + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} diff --git a/relay/model/message.go b/relay/model/message.go new file mode 100644 index 0000000..5ff7b7a --- /dev/null +++ b/relay/model/message.go @@ -0,0 +1,91 @@ +package model + +type Message struct { + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + ReasoningContent any `json:"reasoning_content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` +} + +func (m Message) IsStringContent() bool { + _, ok := m.Content.(string) + return ok +} + +func (m Message) StringContent() string { + content, ok := m.Content.(string) + if ok { + return content + } + contentList, ok := m.Content.([]any) + if ok { + var contentStr string + for _, contentItem := range contentList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + if 
contentMap["type"] == ContentTypeText { + if subStr, ok := contentMap["text"].(string); ok { + contentStr += subStr + } + } + } + return contentStr + } + return "" +} + +func (m Message) ParseContent() []MessageContent { + var contentList []MessageContent + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeImageURL, + ImageURL: &ImageURL{ + Url: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + +type ImageURL struct { + Url string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type MessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} diff --git a/relay/model/misc.go b/relay/model/misc.go new file mode 100644 index 0000000..fdba01e --- /dev/null +++ b/relay/model/misc.go @@ -0,0 +1,27 @@ +package model + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` +} + +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + 
// Relay modes understood by the router. Unknown must stay first so the zero
// value means "unrecognized path".
const (
	Unknown = iota
	ChatCompletions
	Completions
	Embeddings
	Moderations
	ImagesGenerations
	Edits
	AudioSpeech
	AudioTranscription
	AudioTranslation
	// Proxy is a special relay mode for proxying requests to custom upstream
	Proxy
)

// GetByPath maps a request URL path to its relay mode. Matching is ordered:
// the first matching rule wins, and unmatched paths map to Unknown.
func GetByPath(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return ChatCompletions
	case strings.HasPrefix(path, "/v1/completions"):
		return Completions
	// either the canonical /v1/embeddings path or any engine-scoped path
	// ending in "embeddings" (e.g. /v1/engines/:model/embeddings)
	case strings.HasPrefix(path, "/v1/embeddings"), strings.HasSuffix(path, "embeddings"):
		return Embeddings
	case strings.HasPrefix(path, "/v1/moderations"):
		return Moderations
	case strings.HasPrefix(path, "/v1/images/generations"):
		return ImagesGenerations
	case strings.HasPrefix(path, "/v1/edits"):
		return Edits
	case strings.HasPrefix(path, "/v1/audio/speech"):
		return AudioSpeech
	case strings.HasPrefix(path, "/v1/audio/transcriptions"):
		return AudioTranscription
	case strings.HasPrefix(path, "/v1/audio/translations"):
		return AudioTranslation
	case strings.HasPrefix(path, "/v1/oneapi/proxy"):
		return Proxy
	default:
		return Unknown
	}
}
middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) + apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) + + userRoute := apiRouter.Group("/user") + { + userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) + userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login) + userRoute.GET("/logout", controller.Logout) + + selfRoute := userRoute.Group("/") + selfRoute.Use(middleware.UserAuth()) + { + selfRoute.GET("/dashboard", controller.GetUserDashboard) + selfRoute.GET("/self", controller.GetSelf) + selfRoute.PUT("/self", controller.UpdateSelf) + selfRoute.DELETE("/self", controller.DeleteSelf) + selfRoute.GET("/token", controller.GenerateAccessToken) + selfRoute.GET("/aff", controller.GetAffCode) + selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) + } + + adminRoute := userRoute.Group("/") + adminRoute.Use(middleware.AdminAuth()) + { + adminRoute.GET("/", controller.GetAllUsers) + adminRoute.GET("/search", controller.SearchUsers) + adminRoute.GET("/:id", controller.GetUser) + adminRoute.POST("/", controller.CreateUser) + adminRoute.POST("/manage", controller.ManageUser) + adminRoute.PUT("/", controller.UpdateUser) + adminRoute.DELETE("/:id", controller.DeleteUser) + } + } + optionRoute := apiRouter.Group("/option") + optionRoute.Use(middleware.RootAuth()) + { + optionRoute.GET("/", controller.GetOptions) + optionRoute.PUT("/", controller.UpdateOption) + } + channelRoute := apiRouter.Group("/channel") + 
channelRoute.Use(middleware.AdminAuth()) + { + channelRoute.GET("/", controller.GetAllChannels) + channelRoute.GET("/search", controller.SearchChannels) + channelRoute.GET("/models", controller.ListAllModels) + channelRoute.GET("/:id", controller.GetChannel) + channelRoute.GET("/test", controller.TestChannels) + channelRoute.GET("/test/:id", controller.TestChannel) + channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) + channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) + channelRoute.POST("/", controller.AddChannel) + channelRoute.PUT("/", controller.UpdateChannel) + channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) + channelRoute.DELETE("/:id", controller.DeleteChannel) + } + tokenRoute := apiRouter.Group("/token") + tokenRoute.Use(middleware.UserAuth()) + { + tokenRoute.GET("/", controller.GetAllTokens) + tokenRoute.GET("/search", controller.SearchTokens) + tokenRoute.GET("/:id", controller.GetToken) + tokenRoute.POST("/", controller.AddToken) + tokenRoute.PUT("/", controller.UpdateToken) + tokenRoute.DELETE("/:id", controller.DeleteToken) + } + redemptionRoute := apiRouter.Group("/redemption") + redemptionRoute.Use(middleware.AdminAuth()) + { + redemptionRoute.GET("/", controller.GetAllRedemptions) + redemptionRoute.GET("/search", controller.SearchRedemptions) + redemptionRoute.GET("/:id", controller.GetRedemption) + redemptionRoute.POST("/", controller.AddRedemption) + redemptionRoute.PUT("/", controller.UpdateRedemption) + redemptionRoute.DELETE("/:id", controller.DeleteRedemption) + } + logRoute := apiRouter.Group("/log") + logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) + logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) + logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) + logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) + logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) + 
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) + logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) + groupRoute := apiRouter.Group("/group") + groupRoute.Use(middleware.AdminAuth()) + { + groupRoute.GET("/", controller.GetGroups) + } + } +} diff --git a/router/dashboard.go b/router/dashboard.go new file mode 100644 index 0000000..5952d69 --- /dev/null +++ b/router/dashboard.go @@ -0,0 +1,22 @@ +package router + +import ( + "github.com/gin-contrib/gzip" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/middleware" +) + +func SetDashboardRouter(router *gin.Engine) { + apiRouter := router.Group("/") + apiRouter.Use(middleware.CORS()) + apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) + apiRouter.Use(middleware.GlobalAPIRateLimit()) + apiRouter.Use(middleware.TokenAuth()) + { + apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/dashboard/billing/usage", controller.GetUsage) + apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage) + } +} diff --git a/router/main.go b/router/main.go new file mode 100644 index 0000000..39d8c04 --- /dev/null +++ b/router/main.go @@ -0,0 +1,31 @@ +package router + +import ( + "embed" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "net/http" + "os" + "strings" +) + +func SetRouter(router *gin.Engine, buildFS embed.FS) { + SetApiRouter(router) + SetDashboardRouter(router) + SetRelayRouter(router) + frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") + if config.IsMasterNode && frontendBaseUrl != "" { + frontendBaseUrl = "" + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") + } + if frontendBaseUrl == "" { + SetWebRouter(router, buildFS) + } else { + frontendBaseUrl = 
strings.TrimSuffix(frontendBaseUrl, "/") + router.NoRoute(func(c *gin.Context) { + c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) + }) + } +} diff --git a/router/relay.go b/router/relay.go new file mode 100644 index 0000000..8f3c730 --- /dev/null +++ b/router/relay.go @@ -0,0 +1,74 @@ +package router + +import ( + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetRelayRouter(router *gin.Engine) { + router.Use(middleware.CORS()) + router.Use(middleware.GzipDecodeMiddleware()) + // https://platform.openai.com/docs/api-reference/introduction + modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.TokenAuth()) + { + modelsRouter.GET("", controller.ListModels) + modelsRouter.GET("/:model", controller.RetrieveModel) + } + relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) + { + relayV1Router.Any("/oneapi/proxy/:channelid/*target", controller.Relay) + relayV1Router.POST("/completions", controller.Relay) + relayV1Router.POST("/chat/completions", controller.Relay) + relayV1Router.POST("/edits", controller.Relay) + relayV1Router.POST("/images/generations", controller.Relay) + relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/variations", controller.RelayNotImplemented) + relayV1Router.POST("/embeddings", controller.Relay) + relayV1Router.POST("/engines/:model/embeddings", controller.Relay) + relayV1Router.POST("/audio/transcriptions", controller.Relay) + relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) + relayV1Router.GET("/files", controller.RelayNotImplemented) + relayV1Router.POST("/files", controller.RelayNotImplemented) + relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) + 
relayV1Router.GET("/files/:id", controller.RelayNotImplemented) + relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) + relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) + relayV1Router.POST("/moderations", controller.Relay) + relayV1Router.POST("/assistants", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.GET("/assistants", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id", controller.RelayNotImplemented) + relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented) + 
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) + } +} diff --git a/router/web.go b/router/web.go new file mode 100644 index 0000000..3c9b464 --- /dev/null +++ b/router/web.go @@ -0,0 +1,31 @@ +package router + +import ( + "embed" + "fmt" + "github.com/gin-contrib/gzip" + "github.com/gin-contrib/static" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/middleware" + "net/http" + "strings" +) + +func SetWebRouter(router *gin.Engine, buildFS embed.FS) { + indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme)) + router.Use(gzip.Gzip(gzip.DefaultCompression)) + router.Use(middleware.GlobalWebRateLimit()) + router.Use(middleware.Cache()) + router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme)))) + router.NoRoute(func(c *gin.Context) { + if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { + controller.RelayNotFound(c) + return + } + c.Header("Cache-Control", "no-cache") + c.Data(http.StatusOK, "text/html; charset=utf-8", indexPageData) + }) +} diff --git a/web/README.md 
b/web/README.md new file mode 100644 index 0000000..829271e --- /dev/null +++ b/web/README.md @@ -0,0 +1,47 @@ +# One API 的前端界面 + +> 每个文件夹代表一个主题,欢迎提交你的主题 + +> [!WARNING] +> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR + +## 提交新的主题 + +> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 + +1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。 +2. 把你的主题文件放到这个文件夹下。 +3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 +4. 修改 `common/config/config.go` 中的 `ValidThemes`,把你的主题名称注册进去。 +5. 修改 `web/THEMES` 文件,这里也需要同步修改。 + +## 主题列表 + +### 主题:default + +默认主题,由 [JustSong](https://github.com/songquanpeng) 开发。 + +预览: +|![image](https://github.com/songquanpeng/one-api/assets/39998050/ccfbc668-3a7f-4bc1-87da-7eacfd7bf371)|![image](https://github.com/songquanpeng/one-api/assets/39998050/a63ed547-44b9-45db-b43a-ecea07d60840)| +|:---:|:---:| + +### 主题:berry + +由 [MartialBE](https://github.com/MartialBE) 开发。 + +预览: +||| +|:---:|:---:| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/36aff5c6-c5ff-4a90-8e3d-33d5cff34cbf)|![image](https://github.com/songquanpeng/one-api/assets/42402987/9ac63b36-5140-4064-8fad-fc9d25821509)| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/fb2b1c64-ef24-4027-9b80-0cd9d945a47f)|![image](https://github.com/songquanpeng/one-api/assets/42402987/b6b649ec-2888-4324-8b2d-d5e11554eed6)| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/6d3b22e0-436b-4e26-8911-bcc993c6a2bd)|![image](https://github.com/songquanpeng/one-api/assets/42402987/eef1e224-7245-44d7-804e-9d1c8fa3f29c)| + +### 主题:air +由 [Calon](https://github.com/Calcium-Ion) 开发。 +|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1ddb274b-a715-4e81-858b-857d520b6ff4)|![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/163b0b8e-1f73-49cb-b632-3dcb986b56d5)| +|:---:|:---:| + + +#### 开发说明 + +请查看 
[web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) diff --git a/web/THEMES b/web/THEMES new file mode 100644 index 0000000..149e869 --- /dev/null +++ b/web/THEMES @@ -0,0 +1,3 @@ +default +berry +air diff --git a/web/air/.gitignore b/web/air/.gitignore new file mode 100644 index 0000000..2b5bba7 --- /dev/null +++ b/web/air/.gitignore @@ -0,0 +1,26 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.idea +package-lock.json +yarn.lock \ No newline at end of file diff --git a/web/air/README.md b/web/air/README.md new file mode 100644 index 0000000..1b1031a --- /dev/null +++ b/web/air/README.md @@ -0,0 +1,21 @@ +# React Template + +## Basic Usages + +```shell +# Runs the app in the development mode +npm start + +# Builds the app for production to the `build` folder +npm run build +``` + +If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, +for example: `REACT_APP_SERVER=http://your.domain.com`. + +Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. + +## Reference + +1. https://github.com/OIerDb-ng/OIerDb +2. 
https://github.com/cornflourblue/react-hooks-redux-registration-login-example \ No newline at end of file diff --git a/web/air/package.json b/web/air/package.json new file mode 100644 index 0000000..3bdf395 --- /dev/null +++ b/web/air/package.json @@ -0,0 +1,60 @@ +{ + "name": "react-template", + "version": "0.1.0", + "private": true, + "dependencies": { + "@douyinfe/semi-icons": "^2.46.1", + "@douyinfe/semi-ui": "^2.46.1", + "@visactor/react-vchart": "~1.8.8", + "@visactor/vchart": "~1.8.8", + "@visactor/vchart-semi-theme": "~1.8.8", + "axios": "^0.27.2", + "history": "^5.3.0", + "marked": "^4.1.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-dropzone": "^14.2.3", + "react-fireworks": "^1.0.4", + "react-router-dom": "^6.3.0", + "react-scripts": "5.0.1", + "react-telegram-login": "^1.1.2", + "react-toastify": "^9.0.8", + "react-turnstile": "^1.0.5", + "semantic-ui-css": "^2.5.0", + "semantic-ui-react": "^2.1.3", + "usehooks-ts": "^2.9.1" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build && mv -f build ../build/air", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "devDependencies": { + "prettier": "2.8.8", + "typescript": "4.4.2" + }, + "prettier": { + "singleQuote": true, + "jsxSingleQuote": true + }, + "proxy": "http://localhost:3000" +} diff --git a/web/air/public/favicon.ico b/web/air/public/favicon.ico new file mode 100644 index 0000000..c2c8de0 Binary files /dev/null and b/web/air/public/favicon.ico differ diff --git a/web/air/public/index.html b/web/air/public/index.html new file mode 100644 index 0000000..e0de002 --- /dev/null +++ b/web/air/public/index.html @@ -0,0 +1,18 @@ + + + + + + + + + One API + + + +
+ + diff --git a/web/air/public/logo.png b/web/air/public/logo.png new file mode 100644 index 0000000..0f237a2 Binary files /dev/null and b/web/air/public/logo.png differ diff --git a/web/air/public/robots.txt b/web/air/public/robots.txt new file mode 100644 index 0000000..e9e57dc --- /dev/null +++ b/web/air/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html +User-agent: * +Disallow: diff --git a/web/air/src/App.js b/web/air/src/App.js new file mode 100644 index 0000000..5a67318 --- /dev/null +++ b/web/air/src/App.js @@ -0,0 +1,242 @@ +import React, { lazy, Suspense, useContext, useEffect } from 'react'; +import { Route, Routes } from 'react-router-dom'; +import Loading from './components/Loading'; +import User from './pages/User'; +import { PrivateRoute } from './components/PrivateRoute'; +import RegisterForm from './components/RegisterForm'; +import LoginForm from './components/LoginForm'; +import NotFound from './pages/NotFound'; +import Setting from './pages/Setting'; +import EditUser from './pages/User/EditUser'; +import { getLogo, getSystemName } from './helpers'; +import PasswordResetForm from './components/PasswordResetForm'; +import GitHubOAuth from './components/GitHubOAuth'; +import PasswordResetConfirm from './components/PasswordResetConfirm'; +import { UserContext } from './context/User'; +import Channel from './pages/Channel'; +import Token from './pages/Token'; +import EditChannel from './pages/Channel/EditChannel'; +import Redemption from './pages/Redemption'; +import TopUp from './pages/TopUp'; +import Log from './pages/Log'; +import Chat from './pages/Chat'; +import { Layout } from '@douyinfe/semi-ui'; +import Midjourney from './pages/Midjourney'; +import Detail from './pages/Detail'; + +const Home = lazy(() => import('./pages/Home')); +const About = lazy(() => import('./pages/About')); + +function App() { + const [userState, userDispatch] = useContext(UserContext); + // const [statusState, statusDispatch] = 
useContext(StatusContext); + + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({ type: 'login', payload: data }); + } + }; + + useEffect(() => { + loadUser(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; + } + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector('link[rel~=\'icon\']'); + if (linkElement) { + linkElement.href = logo; + } + } + }, []); + + return ( + + + + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + } /> + + + + ); +} + +export default App; diff --git a/web/air/src/components/ChannelsTable.js b/web/air/src/components/ChannelsTable.js new file mode 100644 index 0000000..c384d50 --- /dev/null +++ b/web/air/src/components/ChannelsTable.js @@ -0,0 +1,738 @@ +import React, { useEffect, useState } from 'react'; +import { API, isMobile, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; + +import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumberWithPoint, renderQuota } from '../helpers/render'; +import { + Button, + Dropdown, + Form, + InputNumber, + Popconfirm, + Space, + SplitButtonGroup, + Switch, + Table, + Tag, + Tooltip, + Typography +} from '@douyinfe/semi-ui'; +import EditChannel from '../pages/Channel/EditChannel'; +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +let type2label = undefined; + +function renderType(type) { + if (!type2label) { + type2label = new Map(); + for (let i = 0; i < 
CHANNEL_OPTIONS.length; i++) { + type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; + } + type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; + } + return {type2label[type]?.text}; +} + +const ChannelsTable = () => { + const columns = [ + // { + // title: '', + // dataIndex: 'checkbox', + // className: 'checkbox', + // }, + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + // { + // title: '分组', + // dataIndex: 'group', + // render: (text, record, index) => { + // return ( + //
+ // + // { + // text.split(',').map((item, index) => { + // return (renderGroup(item)); + // }) + // } + // + //
+ // ); + // } + // }, + { + title: '类型', + dataIndex: 'type', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '状态', + dataIndex: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '响应时间', + dataIndex: 'response_time', + render: (text, record, index) => { + return ( +
+ {renderResponseTime(text)} +
+ ); + } + }, + { + title: '已用/剩余', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ + + {renderQuota(record.used_quota)} + + + { + updateChannelBalance(record); + }}>${renderNumberWithPoint(record.balance)} + + +
+ ); + } + }, + { + title: '优先级', + dataIndex: 'priority', + render: (text, record, index) => { + return ( +
+ { + manageChannel(record.id, 'priority', record, e.target.value); + }} + keepFocus={true} + innerButtons + defaultValue={record.priority} + min={-999} + /> +
+ ); + } + }, + // { + // title: '权重', + // dataIndex: 'weight', + // render: (text, record, index) => { + // return ( + //
+ // { + // manageChannel(record.id, 'weight', record, e.target.value); + // }} + // keepFocus={true} + // innerButtons + // defaultValue={record.weight} + // min={0} + // /> + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ {/* + + + + + */} + + { + manageChannel(record.id, 'delete', record).then( + () => { + removeRecord(record.id); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [channels, setChannels] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [idSort, setIdSort] = useState(false); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchGroup, setSearchGroup] = useState(''); + const [searchModel, setSearchModel] = useState(''); + const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showPrompt, setShowPrompt] = useState(shouldShowPrompt('channel-test')); + const [channelCount, setChannelCount] = useState(pageSize); + const [groupOptions, setGroupOptions] = useState([]); + const [showEdit, setShowEdit] = useState(false); + const [enableBatchDelete, setEnableBatchDelete] = useState(false); + const [editingChannel, setEditingChannel] = useState({ + id: undefined + }); + const [selectedChannels, setSelectedChannels] = useState([]); + + const removeRecord = id => { + let newDataSource = [...channels]; + if (id != null) { + let idx = newDataSource.findIndex(data => data.id === id); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setChannels(newDataSource); + } + } + }; + + const setChannelFormat = (channels) => { + for (let i = 0; i < channels.length; i++) { + channels[i].key = '' + channels[i].id; + let test_models = []; + channels[i].models.split(',').forEach((item, index) => { + test_models.push({ + node: 'item', + name: item, + onClick: () => { + testChannel(channels[i], item); + } + }); + }); + channels[i].test_models = test_models; + } + // data.key = '' + data.id + setChannels(channels); + if (channels.length >= pageSize) { + setChannelCount(channels.length + pageSize); + } else { + setChannelCount(channels.length); + } + }; + + const loadChannels = async (startIdx, pageSize, idSort) => { + setLoading(true); + const res = await 
API.get(`/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setChannelFormat(data); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * pageSize, data.length, ...data); + setChannelFormat(newChannels); + } + } else { + showError(message); + } + setLoading(false); + }; + + const refresh = async () => { + await loadChannels(activePage - 1, pageSize, idSort); + }; + + useEffect(() => { + // console.log('default effect') + const localIdSort = localStorage.getItem('id-sort') === 'true'; + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setIdSort(localIdSort); + setPageSize(localPageSize); + loadChannels(0, localPageSize, localIdSort) + .then() + .catch((reason) => { + showError(reason); + }); + fetchGroups().then(); + }, []); + + const manageChannel = async (id, action, record, value) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/channel/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/channel/', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/channel/', data); + break; + case 'priority': + if (value === '') { + return; + } + data.priority = parseInt(value); + res = await API.put('/api/channel/', data); + break; + case 'weight': + if (value === '') { + return; + } + data.weight = parseInt(value); + if (data.weight < 0) { + data.weight = 0; + } + res = await API.put('/api/channel/', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let channel = res.data.data; + let newChannels = [...channels]; + if (action === 'delete') { + + } else { + record.status = channel.status; + } + setChannels(newChannels); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 
已启用; + case 2: + return ( + + 已禁用 + + ); + case 3: + return ( + + 自动禁用 + + ); + default: + return ( + + 未知状态 + + ); + } + }; + + const renderResponseTime = (responseTime) => { + let time = responseTime / 1000; + time = time.toFixed(2) + ' 秒'; + if (responseTime === 0) { + return 未测试; + } else if (responseTime <= 1000) { + return {time}; + } else if (responseTime <= 3000) { + return {time}; + } else if (responseTime <= 5000) { + return {time}; + } else { + return {time}; + } + }; + + const searchChannels = async (searchKeyword, searchGroup, searchModel) => { + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + // if keyword is blank, load files instead. + await loadChannels(0, pageSize, idSort); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}`); + const { success, message, data } = res.data; + if (success) { + setChannels(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const testChannel = async (record, model) => { + const res = await API.get(`/api/channel/test/${record.id}?model=${model}`); + const { success, message, time } = res.data; + if (success) { + record.response_time = time * 1000; + record.test_time = Date.now() / 1000; + showInfo(`渠道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + } else { + showError(message); + } + }; + + const testChannels = async (scope) => { + const res = await API.get(`/api/channel/test?scope=${scope}`); + const { success, message } = res.data; + if (success) { + showInfo('已成功开始测试渠道,请刷新页面查看结果。'); + } else { + showError(message); + } + }; + + const deleteAllDisabledChannels = async () => { + const res = await API.delete(`/api/channel/disabled`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); + await refresh(); + } else { + showError(message); + } + }; + + const updateChannelBalance 
= async (record) => { + const res = await API.get(`/api/channel/update_balance/${record.id}/`); + const { success, message, balance } = res.data; + if (success) { + record.balance = balance; + record.balance_updated_time = Date.now() / 1000; + showInfo(`渠道 ${record.name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用渠道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; + + const batchDeleteChannels = async () => { + if (selectedChannels.length === 0) { + showError('请先选择要删除的渠道!'); + return; + } + setLoading(true); + let ids = []; + selectedChannels.forEach((channel) => { + ids.push(channel.id); + }); + const res = await API.post(`/api/channel/batch`, { ids: ids }); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除 ${data} 个渠道!`); + await refresh(); + } else { + showError(message); + } + setLoading(false); + }; + + const fixChannelsAbilities = async () => { + const res = await API.post(`/api/channel/fix`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已修复 ${data} 个渠道!`); + await refresh(); + } else { + showError(message); + } + }; + + let pageData = channels.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(channels.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ loadChannels(page - 1, pageSize, idSort).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadChannels(0, size, idSort) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const fetchGroups = async () => { + try { + let res = await API.get(`/api/group/`); + // add 'all' option + // res.data.data.unshift('all'); + setGroupOptions(res.data.data.map((group) => ({ + label: group, + value: group + }))); + } catch (error) { + showError(error.message); + } + }; + + const closeEdit = () => { + setShowEdit(false); + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + + return ( + <> + +
+
{ + searchChannels(searchKeyword, searchGroup, searchModel); + }} labelPosition="left"> +
+ + { + setSearchKeyword(v.trim()); + }} + /> + {/* { + setSearchModel(v.trim()); + }} + /> + { + setSearchGroup(v); + searchChannels(searchKeyword, v, searchModel); + }} /> */} + + +
+
+
+ + + { testChannels("all") }} + position={isMobile() ? 'top' : 'left'} + > + + + { testChannels("disabled") }} + position={isMobile() ? 'top' : 'left'} + > + + + {/* + + */} + + + + + + + {/*
*/} + + {/*
*/} +
+ {/*
+ + 开启批量删除 + { + setEnableBatchDelete(v); + }}> + + + + + + + +
+
+ + + 使用ID排序 + { + localStorage.setItem('id-sort', v + ''); + setIdSort(v); + loadChannels(0, pageSize, v) + .then() + .catch((reason) => { + showError(reason); + }); + }}> + + +
*/} +
+ '', + onPageSizeChange: (size) => { + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} loading={loading} onRow={handleRow} rowSelection={ + enableBatchDelete ? + { + onChange: (selectedRowKeys, selectedRows) => { + // console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows); + setSelectedChannels(selectedRows); + } + } : null + } /> + + ); +}; + +export default ChannelsTable; diff --git a/web/air/src/components/Footer.js b/web/air/src/components/Footer.js new file mode 100644 index 0000000..6fd0fa5 --- /dev/null +++ b/web/air/src/components/Footer.js @@ -0,0 +1,64 @@ +import React, { useEffect, useState } from 'react'; + +import { Container, Segment } from 'semantic-ui-react'; +import { getFooterHTML, getSystemName } from '../helpers'; + +const Footer = () => { + const systemName = getSystemName(); + const [footer, setFooter] = useState(getFooterHTML()); + let remainCheckTimes = 5; + + const loadFooter = () => { + let footer_html = localStorage.getItem('footer_html'); + if (footer_html) { + setFooter(footer_html); + } + }; + + useEffect(() => { + const timer = setInterval(() => { + if (remainCheckTimes <= 0) { + clearInterval(timer); + return; + } + remainCheckTimes--; + loadFooter(); + }, 200); + return () => clearTimeout(timer); + }, []); + + return ( + + + {footer ? ( +
+ ) : ( +
+ + {systemName} {process.env.REACT_APP_VERSION}{' '} + + 由{' '} + + JustSong + {' '} + 构建,主题 air 来自{' '} + + Calon + {' '},源代码遵循{' '} + + MIT 协议 + +
+ )} +
+
+ ); +}; + +export default Footer; diff --git a/web/air/src/components/GitHubOAuth.js b/web/air/src/components/GitHubOAuth.js new file mode 100644 index 0000000..4e3b93b --- /dev/null +++ b/web/air/src/components/GitHubOAuth.js @@ -0,0 +1,58 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; +import { useNavigate, useSearchParams } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; +import { UserContext } from '../context/User'; + +const GitHubOAuth = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GitHubOAuth; diff --git a/web/air/src/components/HeaderBar.js b/web/air/src/components/HeaderBar.js new file mode 100644 index 0000000..eaf36c4 --- /dev/null +++ 
b/web/air/src/components/HeaderBar.js @@ -0,0 +1,161 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; + +import { API, getLogo, getSystemName, showSuccess } from '../helpers'; +import '../index.css'; + +import fireworks from 'react-fireworks'; + +import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons'; +import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; +import { stringToColor } from '../helpers/render'; + +// HeaderBar Buttons +let headerButtons = [ + { + text: '关于', + itemKey: 'about', + to: '/about', + icon: + } +]; + +if (localStorage.getItem('chat_link')) { + headerButtons.splice(1, 0, { + name: '聊天', + to: '/chat', + icon: 'comments' + }); +} + +const HeaderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [showSidebar, setShowSidebar] = useState(false); + const [dark, setDark] = useState(false); + const systemName = getSystemName(); + const logo = getLogo(); + var themeMode = localStorage.getItem('theme-mode'); + const currentDate = new Date(); + // enable fireworks on new year(1.1 and 2.9-2.24) + const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24); + + async function logout() { + setShowSidebar(false); + await API.get('/api/user/logout'); + showSuccess('注销成功!'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } + + const handleNewYearClick = () => { + fireworks.init('root', {}); + fireworks.start(); + setTimeout(() => { + fireworks.stop(); + setTimeout(() => { + window.location.reload(); + }, 10000); + }, 3000); + }; + + useEffect(() => { + if (themeMode === 'dark') { + switchMode(true); + } + if (isNewYear) { + console.log('Happy New Year!'); + } + }, []); + + 
const switchMode = (model) => { + const body = document.body; + if (!model) { + body.removeAttribute('theme-mode'); + localStorage.setItem('theme-mode', 'light'); + } else { + body.setAttribute('theme-mode', 'dark'); + localStorage.setItem('theme-mode', 'dark'); + } + setDark(model); + }; + return ( + <> + +
+ +
+
+ + ); +}; + +export default HeaderBar; diff --git a/web/air/src/components/Loading.js b/web/air/src/components/Loading.js new file mode 100644 index 0000000..bacb53b --- /dev/null +++ b/web/air/src/components/Loading.js @@ -0,0 +1,14 @@ +import React from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; + +const Loading = ({ prompt: name = 'page' }) => { + return ( + + + 加载{name}中... + + + ); +}; + +export default Loading; diff --git a/web/air/src/components/LoginForm.js b/web/air/src/components/LoginForm.js new file mode 100644 index 0000000..3cbeb52 --- /dev/null +++ b/web/air/src/components/LoginForm.js @@ -0,0 +1,254 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import { onGitHubOAuthClicked } from './utils'; +import Turnstile from 'react-turnstile'; +import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui'; +import Title from '@douyinfe/semi-ui/lib/es/typography/title'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; +import TelegramLoginButton from 'react-telegram-login'; + +import { IconGithubLogo } from '@douyinfe/semi-icons'; +import WeChatIcon from './WeChatIcon'; + +const LoginForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + wechat_verification_code: '' + }); + const [searchParams, setSearchParams] = useSearchParams(); + const [submitted, setSubmitted] = useState(false); + const { username, password } = inputs; + const [userState, userDispatch] = useContext(UserContext); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + let navigate = useNavigate(); + const [status, setStatus] = useState({}); 
+ const logo = getLogo(); + + useEffect(() => { + if (searchParams.get('expired')) { + showError('未登录或登录已过期,请重新登录!'); + } + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }, []); + + const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + + const onWeChatLoginClicked = () => { + setShowWeChatLoginModal(true); + }; + + const onSubmitWeChatVerificationCode = async () => { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + const res = await API.get( + `/api/oauth/wechat?code=${inputs.wechat_verification_code}` + ); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + navigate('/'); + showSuccess('登录成功!'); + setShowWeChatLoginModal(false); + } else { + showError(message); + } + }; + + function handleChange(name, value) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setSubmitted(true); + if (username && password) { + const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, { + username, + password + }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + if (username === 'root' && password === '123456') { + Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true }); + } + navigate('/token'); + } else { + showError(message); + } + } else { + showError('请输入用户名和密码!'); + } + } + + // 添加Telegram登录处理函数 + const onTelegramLoginClicked = async (response) => { + const 
fields = ['id', 'first_name', 'last_name', 'username', 'photo_url', 'auth_date', 'hash', 'lang']; + const params = {}; + fields.forEach((field) => { + if (response[field]) { + params[field] = response[field]; + } + }); + const res = await API.get(`/api/oauth/telegram/login`, { params }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } else { + showError(message); + } + }; + + return ( +
+ + + + +
+
+ + + 用户登录 + +
+ handleChange('username', value)} + /> + handleChange('password', value)} + /> + + + +
+ + 没有账号请先 注册账号 + + + 忘记密码 点击重置 + +
+ {status.github_oauth || status.wechat_login || status.telegram_oauth ? ( + <> + + 第三方登录 + +
+ {status.github_oauth ? ( +
+ + ) : ( + <> + )} + setShowWeChatLoginModal(false)} + okText={'登录'} + size={'small'} + centered={true} + > +
+ +
+
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+
+ handleChange('wechat_verification_code', value)} + /> + +
+
+ {turnstileEnabled ? ( +
+ { + setTurnstileToken(token); + }} + /> +
+ ) : ( + <> + )} +
+
+ +
+
+
+ ); +}; + +export default LoginForm; diff --git a/web/air/src/components/LogsTable.js b/web/air/src/components/LogsTable.js new file mode 100644 index 0000000..7d372d4 --- /dev/null +++ b/web/air/src/components/LogsTable.js @@ -0,0 +1,403 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Avatar, Button, Form, Layout, Modal, Select, Space, Spin, Table, Tag } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderNumber, renderQuota, stringToColor } from '../helpers/render'; +import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; + +const { Header } = Layout; + +function renderTimestamp(timestamp) { + return (<> + {timestamp2string(timestamp)} + ); +} + +const MODE_OPTIONS = [{ key: 'all', text: '全部用户', value: 'all' }, { key: 'self', text: '当前用户', value: 'self' }]; + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', 'light-blue', 'lime', 'orange', 'pink', 'purple', 'red', 'teal', 'violet', 'yellow']; + +function renderType(type) { + switch (type) { + case 1: + return 充值 ; + case 2: + return 消费 ; + case 3: + return 管理 ; + case 4: + return 系统 ; + case 5: + return 测试 ; + default: + return 未知 ; + } +} + +function renderIsStream(bool) { + if (bool) { + return ; + } else { + return 非流; + } +} + +function renderUseTime(type) { + const time = parseInt(type); + if (time < 101) { + return {time} s ; + } else if (time < 300) { + return {time} s ; + } else { + return {time} s ; + } +} + +const LogsTable = () => { + const columns = [{ + title: '时间', dataIndex: 'timestamp2string' + }, { + title: '渠道', + dataIndex: 'channel', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ? record.type === 0 || record.type === 2 ?
+ { {text} } +
: <> : <>); + } + }, { + title: '用户', + dataIndex: 'username', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return (isAdminUser ?
+ showUserInfo(record.user_id)}> + {typeof text === 'string' && text.slice(0, 1)} + + {text} +
: <>); + } + }, { + title: '令牌', dataIndex: 'token_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, { + title: '类型', dataIndex: 'type', render: (text, record, index) => { + return (
+ {renderType(text)} +
); + } + }, { + title: '模型', dataIndex: 'model_name', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { + copyText(text); + }}> {text} +
: <>); + } + }, + // { + // title: '用时', dataIndex: 'use_time', render: (text, record, index) => { + // return (
+ // + // {renderUseTime(text)} + // {renderIsStream(record.is_stream)} + // + //
); + // } + // }, + { + title: '提示', dataIndex: 'prompt_tokens', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ { {text} } +
: <>); + } + }, { + title: '补全', dataIndex: 'completion_tokens', render: (text, record, index) => { + return (parseInt(text) > 0 && (record.type === 0 || record.type === 2) ?
+ { {text} } +
: <>); + } + }, { + title: '花费', dataIndex: 'quota', render: (text, record, index) => { + return (record.type === 0 || record.type === 2 ?
+ {renderQuota(text, 6)} +
: <>); + } + }, { + title: '详情', dataIndex: 'content', render: (text, record, index) => { + return + {text} + ; + } + }]; + + const [logs, setLogs] = useState([]); + const [showStat, setShowStat] = useState(false); + const [loading, setLoading] = useState(false); + const [loadingStat, setLoadingStat] = useState(false); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + username: '', + token_name: '', + model_name: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), + channel: '' + }); + const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; + + const [stat, setStat] = useState({ + quota: 0, token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const getLogSelfStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const getLogStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await 
API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const handleEyeClick = async () => { + setLoadingStat(true); + if (isAdminUser) { + await getLogStat(); + } else { + await getLogSelfStat(); + } + setShowStat(true); + setLoadingStat(false); + }; + + const showUserInfo = async (userId) => { + if (!isAdminUser) { + return; + } + const res = await API.get(`/api/user/${userId}`); + const { success, message, data } = res.data; + if (success) { + Modal.info({ + title: '用户信息', content:
+

用户名: {data.username}

+

余额: {renderQuota(data.quota)}

+

已用额度:{renderQuota(data.used_quota)}

+

请求次数:{renderNumber(data.request_count)}

+
, centered: true + }); + } else { + showError(message); + } + }; + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx, pageSize, logType = 0) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + if (isAdminUser) { + url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; + } else { + url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * pageSize, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * pageSize, activePage * pageSize); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ loadLogs(page - 1, pageSize).then(r => { + }); + } + }; + + const handlePageSizeChange = async (size) => { + localStorage.setItem('page-size', size + ''); + setPageSize(size); + setActivePage(1); + loadLogs(0, size) + .then() + .catch((reason) => { + showError(reason); + }); + }; + + const refresh = async (localLogType) => { + // setLoading(true); + setActivePage(1); + await loadLogs(0, pageSize, localLogType); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + // console.log('default effect') + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + setPageSize(localPageSize); + loadLogs(0, localPageSize) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const searchLogs = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadLogs(0, pageSize); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setLogs(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + return (<> + +
+ +

使用明细(总消耗额度: + {showStat ? renderQuota(stat.quota) : '点击查看'} + ) +

+
+
+
+ <> + handleInputChange(value, 'token_name')} /> + handleInputChange(value, 'model_name')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + {isAdminUser && <> + handleInputChange(value, 'channel')} /> + handleInputChange(value, 'username')} /> + } + + + + + +
{ + handlePageSizeChange(size).then(); + }, + onPageChange: handlePageChange + }} /> + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/MjLogsTable.js b/web/air/src/components/MjLogsTable.js new file mode 100644 index 0000000..6a6fbd9 --- /dev/null +++ b/web/air/src/components/MjLogsTable.js @@ -0,0 +1,454 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { Banner, Button, Form, ImagePreview, Layout, Modal, Progress, Table, Tag, Typography } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; + + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +]; + +function renderType(type) { + switch (type) { + case 'IMAGINE': + return 绘图; + case 'UPSCALE': + return 放大; + case 'VARIATION': + return 变换; + case 'HIGH_VARIATION': + return 强变换; + case 'LOW_VARIATION': + return 弱变换; + case 'PAN': + return 平移; + case 'DESCRIBE': + return 图生文; + case 'BLEND': + return 图混合; + case 'SHORTEN': + return 缩词; + case 'REROLL': + return 重绘; + case 'INPAINT': + return 局部重绘-提交; + case 'ZOOM': + return 变焦; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; + case 'SWAP_FACE': + return 换脸; + default: + return 未知; + } +} + + +function renderCode(code) { + switch (code) { + case 1: + return 已提交; + case 21: + return 等待中; + case 22: + return 重复提交; + case 0: + return 未提交; + default: + return 未知; + } +} + + +function renderStatus(type) { + // Ensure all cases are string literals by adding quotes. 
+ switch (type) { + case 'SUCCESS': + return 成功; + case 'NOT_START': + return 未启动; + case 'SUBMITTED': + return 队列中; + case 'IN_PROGRESS': + return 执行中; + case 'FAILURE': + return 失败; + case 'MODAL': + return 窗口等待; + default: + return 未知; + } +} + +const renderTimestamp = (timestampInSeconds) => { + const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒 + + const year = date.getFullYear(); // 获取年份 + const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数 + const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数 + const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数 + const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数 + const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数 + + return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出 +}; + + +const LogsTable = () => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [modalContent, setModalContent] = useState(''); + const columns = [ + { + title: '提交时间', + dataIndex: 'submit_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text / 1000)} +
+ ); + } + }, + { + title: '渠道', + dataIndex: 'channel_id', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( + +
+ { + copyText(text); // 假设copyText是用于文本复制的函数 + }}> {text} +
+ + ); + } + }, + { + title: '类型', + dataIndex: 'action', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + } + }, + { + title: '任务ID', + dataIndex: 'mj_id', + render: (text, record, index) => { + return ( +
+ {text} +
+ ); + } + }, + { + title: '提交结果', + dataIndex: 'code', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderCode(text)} +
+ ); + } + }, + { + title: '任务状态', + dataIndex: 'status', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '进度', + dataIndex: 'progress', + render: (text, record, index) => { + return ( +
+ { + // 转换例如100%为数字100,如果text未定义,返回0 + + } +
+ ); + } + }, + { + title: '结果图片', + dataIndex: 'image_url', + render: (text, record, index) => { + if (!text) { + return '无'; + } + return ( + + ); + } + }, + { + title: 'Prompt', + dataIndex: 'prompt', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: 'PromptEn', + dataIndex: 'prompt_en', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + }, + { + title: '失败原因', + dataIndex: 'fail_reason', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + } + + ]; + + const [logs, setLogs] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const [showBanner, setShowBanner] = useState(false); + + // 定义模态框图片URL的状态和更新函数 + const [modalImageUrl, setModalImageUrl] = useState(''); + let now = new Date(); + // 初始化start_timestamp为前一天 + const [inputs, setInputs] = useState({ + channel_id: '', + mj_id: '', + start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + }); + const { channel_id, mj_id, start_timestamp, end_timestamp } = inputs; + + const [stat, setStat] = useState({ + quota: 0, + token: 0 + }); + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + 
logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + }; + + const loadLogs = async (startIdx) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp); + let localEndTimestamp = Date.parse(end_timestamp); + if (isAdminUser) { + url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1).then(r => { + }); + } + }; + + const refresh = async () => { + // setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + refresh().then(); + }, [logType]); + + useEffect(() => { + const mjNotifyEnabled = localStorage.getItem('mj_notify_enabled'); + if (mjNotifyEnabled !== 'true') { + setShowBanner(true); + } + }, []); + + return ( + <> + + + {isAdminUser && showBanner ? : <> + } +
+ <> + handleInputChange(value, 'channel_id')} /> + handleInputChange(value, 'mj_id')} /> + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + + + + + + +
+ setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式 + width={800} // 设置模态框宽度 + > +

{modalContent}

+
+ setIsModalOpenurl(visible)} + /> + + + + ); +}; + +export default LogsTable; diff --git a/web/air/src/components/OperationSetting.js b/web/air/src/components/OperationSetting.js new file mode 100644 index 0000000..6356ac6 --- /dev/null +++ b/web/air/src/components/OperationSetting.js @@ -0,0 +1,389 @@ +import React, { useEffect, useState } from 'react'; +import { Divider, Form, Grid, Header } from 'semantic-ui-react'; +import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; + +const OperationSetting = () => { + let now = new Date(); + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + CompletionRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + QuotaPerUnit: 0, + AutomaticDisableChannelEnabled: '', + AutomaticEnableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '', + DisplayInCurrencyEnabled: '', + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', + RetryTimes: 0 + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'CompletionRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + if (item.value === '{}') { + item.value = ''; + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + if 
(key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) { + if (!verifyJSON(inputs.CompletionRatio)) { + showError('补全倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('CompletionRatio', inputs.CompletionRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await 
updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { + await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); + } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } + break; + } + }; + + const deleteHistoryLogs = async () => { + console.log(inputs); + const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`${data} 条日志已清理!`); + return; + } + showError('日志清理失败:' + message); + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + + + + + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 日志设置 +
+ + + + + { + setHistoryTimestamp(value); + }} /> + + { + deleteHistoryLogs().then(); + }}>清理历史日志 + +
+ 监控设置 +
+ + + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ); +}; + +export default OperationSetting; diff --git a/web/air/src/components/OtherSetting.js b/web/air/src/components/OtherSetting.js new file mode 100644 index 0000000..ae924d9 --- /dev/null +++ b/web/air/src/components/OtherSetting.js @@ -0,0 +1,225 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react'; +import { API, showError, showSuccess } from '../helpers'; +import { marked } from 'marked'; +import { Link } from 'react-router-dom'; + +const OtherSetting = () => { + let [inputs, setInputs] = useState({ + Footer: '', + Notice: '', + About: '', + SystemName: '', + Logo: '', + HomePageContent: '', + Theme: '' + }); + let [loading, setLoading] = useState(false); + const [showUpdateModal, setShowUpdateModal] = useState(false); + const [updateData, setUpdateData] = useState({ + tag_name: '', + content: '' + }); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key in inputs) { + newInputs[item.key] = item.value; + } + }); + setInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const submitNotice = async () => { + await updateOption('Notice', inputs.Notice); + }; + + const submitFooter = async () => { + await updateOption('Footer', inputs.Footer); + }; + + const submitSystemName = async () => { + await 
updateOption('SystemName', inputs.SystemName); + }; + + const submitTheme = async () => { + await updateOption('Theme', inputs.Theme); + }; + + const submitLogo = async () => { + await updateOption('Logo', inputs.Logo); + }; + + const submitAbout = async () => { + await updateOption('About', inputs.About); + }; + + const submitOption = async (key) => { + await updateOption(key, inputs[key]); + }; + + const openGitHubRelease = () => { + window.location = + 'https://github.com/songquanpeng/one-api/releases/latest'; + }; + + const checkUpdate = async () => { + const res = await API.get( + 'https://api.github.com/repos/songquanpeng/one-api/releases/latest' + ); + const { tag_name, body } = res.data; + if (tag_name === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body) + }); + setShowUpdateModal(true); + } + }; + + return ( + + +
+
通用设置
+ 检查更新 + + + + 保存公告 + +
个性化设置
+ + + + 设置系统名称 + + 主题名称(当前可用主题)} + placeholder='请输入主题名称' + value={inputs.Theme} + name='Theme' + onChange={handleInputChange} + /> + + 设置主题(重启生效) + + + + 设置 Logo + + + + submitOption('HomePageContent')}>保存首页内容 + + + + 保存关于 + 移除 One API + 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。 + + + + 设置页脚 + +
+ setShowUpdateModal(false)} + onOpen={() => setShowUpdateModal(true)} + open={showUpdateModal} + > + 新版本:{updateData.tag_name} + + +
+
+
+ + + + + + +
+ ); +}; + +export default PasswordResetConfirm; diff --git a/web/air/src/components/PasswordResetForm.js b/web/air/src/components/PasswordResetForm.js new file mode 100644 index 0000000..ff3eaad --- /dev/null +++ b/web/air/src/components/PasswordResetForm.js @@ -0,0 +1,102 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; +import { API, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const PasswordResetForm = () => { + const [inputs, setInputs] = useState({ + email: '' + }); + const { email } = inputs; + + const [loading, setLoading] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); + }, [disableButton, countdown]); + + function handleChange(e) { + const { name, value } = e.target; + setInputs(inputs => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + setDisableButton(true); + if (!email) return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/reset_password?email=${email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('重置邮件发送成功,请检查邮箱!'); + setInputs({ ...inputs, email: '' }); + } else { + showError(message); + } + setLoading(false); + } + + return ( + + +
+ 密码重置 +
+
+ + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + + +
+
+ ); +}; + +export default PasswordResetForm; diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js new file mode 100644 index 0000000..ef4acf1 --- /dev/null +++ b/web/air/src/components/PersonalSetting.js @@ -0,0 +1,653 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { API, copy, isRoot, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; +import { UserContext } from '../context/User'; +import { onGitHubOAuthClicked } from './utils'; +import { + Avatar, + Banner, + Button, + Card, + Descriptions, + Image, + Input, + InputNumber, + Layout, + Modal, + Space, + Tag, + Typography +} from '@douyinfe/semi-ui'; +import { getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor } from '../helpers/render'; +import TelegramLoginButton from 'react-telegram-login'; + +const PersonalSetting = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [inputs, setInputs] = useState({ + wechat_verification_code: '', + email_verification_code: '', + email: '', + self_account_deletion_confirmation: '', + set_new_password: '', + set_new_password_confirmation: '' + }); + const [status, setStatus] = useState({}); + const [showChangePasswordModal, setShowChangePasswordModal] = useState(false); + const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); + const [showEmailBindModal, setShowEmailBindModal] = useState(false); + const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + 
const [affLink, setAffLink] = useState(''); + const [systemToken, setSystemToken] = useState(''); + const [models, setModels] = useState([]); + const [openTransfer, setOpenTransfer] = useState(false); + const [transferAmount, setTransferAmount] = useState(0); + + useEffect(() => { + // let user = localStorage.getItem('user'); + // if (user) { + // userDispatch({ type: 'login', payload: user }); + // } + // console.log(localStorage.getItem('user')) + + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + getUserData().then( + (res) => { + console.log(userState); + } + ); + loadModels().then(); + getAffLink().then(); + setTransferAmount(getQuotaPerUnit()); + }, []); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); + + const handleInputChange = (name, value) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const generateAccessToken = async () => { + const res = await API.get('/api/user/token'); + const { success, message, data } = res.data; + if (success) { + setSystemToken(data); + await copy(data); + showSuccess(`令牌已重置并已复制到剪贴板`); + } else { + showError(message); + } + }; + + const getAffLink = async () => { + const res = await API.get('/api/user/aff'); + const { success, message, data } = res.data; + if (success) { + let link = `${window.location.origin}/register?aff=${data}`; + setAffLink(link); + } else { + showError(message); + } + }; + + const getUserData = async () => { + let res = await API.get(`/api/user/self`); + const { success, message, data } 
= res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + } else { + showError(message); + } + }; + + const loadModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModels(data); + console.log(data); + } else { + showError(message); + } + }; + + const handleAffLinkClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`邀请链接已复制到剪切板`); + }; + + const handleSystemTokenClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`系统令牌已复制到剪切板`); + }; + + const deleteAccount = async () => { + if (inputs.self_account_deletion_confirmation !== userState.user.username) { + showError('请输入你的账户名以确认删除!'); + return; + } + + const res = await API.delete('/api/user/self'); + const { success, message } = res.data; + + if (success) { + showSuccess('账户已删除!'); + await API.get('/api/user/logout'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } else { + showError(message); + } + }; + + const bindWeChat = async () => { + if (inputs.wechat_verification_code === '') return; + const res = await API.get( + `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('微信账户绑定成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + }; + + const changePassword = async () => { + if (inputs.set_new_password !== inputs.set_new_password_confirmation) { + showError('两次输入的密码不一致!'); + return; + } + const res = await API.put( + `/api/user/self`, + { + password: inputs.set_new_password + } + ); + const { success, message } = res.data; + if (success) { + showSuccess('密码修改成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + setShowChangePasswordModal(false); + }; + + const transfer = async () => { + if (transferAmount < getQuotaPerUnit()) { + 
showError('划转金额最低为' + renderQuota(getQuotaPerUnit())); + return; + } + const res = await API.post( + `/api/user/aff_transfer`, + { + quota: transferAmount + } + ); + const { success, message } = res.data; + if (success) { + showSuccess(message); + setOpenTransfer(false); + getUserData().then(); + } else { + showError(message); + } + }; + + const sendVerificationCode = async () => { + if (inputs.email === '') { + showError('请输入邮箱!'); + return; + } + setDisableButton(true); + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + const bindEmail = async () => { + if (inputs.email_verification_code === '') { + showError('请输入邮箱验证码!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('邮箱账户绑定成功!'); + setShowEmailBindModal(false); + userState.user.email = inputs.email; + } else { + showError(message); + } + setLoading(false); + }; + + const getUsername = () => { + if (userState.user) { + return userState.user.username; + } else { + return 'null'; + } + }; + + const handleCancel = () => { + setOpenTransfer(false); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + return ( +
+ + + +
+ {`可用额度${renderQuotaWithPrompt(userState?.user?.aff_quota)}`} + +
+
+ {`划转额度${renderQuotaWithPrompt(transferAmount)} 最低` + renderQuota(getQuotaPerUnit())} +
+ setTransferAmount(value)} disabled={false}> +
+
+
+
+ + {typeof getUsername() === 'string' && getUsername().slice(0, 1)} + } + title={{getUsername()}} + description={isRoot() ? 管理员 : 普通用户} + > + } + headerExtraContent={ + <> + + {'ID: ' + userState?.user?.id} + {userState?.user?.group} + + + } + footer={ + + {renderQuota(userState?.user?.quota)} + {renderQuota(userState?.user?.used_quota)} + {userState.user?.request_count} + + } + > + 调用信息 +

可用模型(可点击复制)

+
+ + {models.map((model) => ( + { + copyText(model); + }}> + {model} + + ))} + +
+
+ {/* + 邀请链接 + +
+ } + > + 邀请信息 +
+ + + + { + renderQuota(userState?.user?.aff_quota) + } + + + + {renderQuota(userState?.user?.aff_history_quota)} + {userState?.user?.aff_count} + +
+ */} + + 邀请链接 + + + + 个人信息 +
+ 邮箱 +
+
+ +
+
+ +
+
+
+
+ 微信 +
+
+ +
+
+ +
+
+
+
+ GitHub +
+
+ +
+
+ +
+
+
+ + {/*
+ Telegram +
+
+ +
+
+ {status.telegram_oauth ? + userState.user.telegram_id !== '' ? + : + : + } +
+
+
*/} + +
+ + + + + + + {systemToken && ( + + )} + { + status.wechat_login && ( + + ) + } + setShowWeChatBindModal(false)} + // onOpen={() => setShowWeChatBindModal(true)} + visible={showWeChatBindModal} + size={'mini'} + > + +
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+ handleInputChange('wechat_verification_code', v)} + /> + +
+
+
+ setShowEmailBindModal(false)} + // onOpen={() => setShowEmailBindModal(true)} + onOk={bindEmail} + visible={showEmailBindModal} + size={'small'} + centered={true} + maskClosable={false} + > + 绑定邮箱地址 +
+ handleInputChange('email', value)} + name="email" + type="email" + /> + +
+
+ handleInputChange('email_verification_code', value)} + /> +
+ {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+ setShowAccountDeleteModal(false)} + visible={showAccountDeleteModal} + size={'small'} + centered={true} + onOk={deleteAccount} + > +
+ +
+
+ handleInputChange('self_account_deletion_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+ setShowChangePasswordModal(false)} + visible={showChangePasswordModal} + size={'small'} + centered={true} + onOk={changePassword} + > +
+ handleInputChange('set_new_password', value)} + /> + handleInputChange('set_new_password_confirmation', value)} + /> + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+
+
+ + + + + ); +}; + +export default PersonalSetting; diff --git a/web/air/src/components/PrivateRoute.js b/web/air/src/components/PrivateRoute.js new file mode 100644 index 0000000..9ef826c --- /dev/null +++ b/web/air/src/components/PrivateRoute.js @@ -0,0 +1,13 @@ +import { Navigate } from 'react-router-dom'; + +import { history } from '../helpers'; + + +function PrivateRoute({ children }) { + if (!localStorage.getItem('user')) { + return ; + } + return children; +} + +export { PrivateRoute }; \ No newline at end of file diff --git a/web/air/src/components/RedemptionsTable.js b/web/air/src/components/RedemptionsTable.js new file mode 100644 index 0000000..89e4ce2 --- /dev/null +++ b/web/air/src/components/RedemptionsTable.js @@ -0,0 +1,406 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Form, Modal, Popconfirm, Popover, Table, Tag } from '@douyinfe/semi-ui'; +import EditRedemption from '../pages/Redemption/EditRedemption'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status) { + switch (status) { + case 1: + return 未使用; + case 2: + return 已禁用 ; + case 3: + return 已使用 ; + default: + return 未知状态 ; + } +} + +const RedemptionsTable = () => { + const columns = [ + { + title: 'ID', + dataIndex: 'id' + }, + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + } + }, + { + title: '额度', + dataIndex: 'quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + // { + // title: '兑换人ID', + // dataIndex: 'used_user_id', + // render: (text, record, index) => { + // return ( + //
+ // {text === 0 ? '无' : text} + //
+ // ); + // } + // }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + { + manageRedemption(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [redemptions, setRedemptions] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [tokenCount, setTokenCount] = useState(ITEMS_PER_PAGE); + const [selectedKeys, setSelectedKeys] = useState([]); + const [editingRedemption, setEditingRedemption] = useState({ + id: undefined + }); + const [showEdit, setShowEdit] = useState(false); + + const closeEdit = () => { + setShowEdit(false); + }; + + // const setCount = (data) => { + // if (data.length >= (activePage) * ITEMS_PER_PAGE) { + // setTokenCount(data.length + 1); + // } else { + // setTokenCount(data.length); + // } + // } + + const setRedemptionFormat = (redeptions) => { + // for (let i = 0; i < redeptions.length; i++) { + // redeptions[i].key = '' + redeptions[i].id; + // } + // data.key = '' + data.id + setRedemptions(redeptions); + if (redeptions.length >= (activePage) * ITEMS_PER_PAGE) { + setTokenCount(redeptions.length + 1); + } else { + setTokenCount(redeptions.length); + } + }; + + const loadRedemptions = async (startIdx) => { + const res = await API.get(`/api/redemption/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setRedemptionFormat(data); + } else { + let newRedemptions = redemptions; + newRedemptions.push(...data); + setRedemptionFormat(newRedemptions); + } + } else { + showError(message); + } + setLoading(false); + }; + + const removeRecord = key => { + let newDataSource = [...redemptions]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setRedemptions(newDataSource); + } + } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', 
content: text }); + } + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadRedemptions(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadRedemptions(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const refresh = async () => { + await loadRedemptions(activePage - 1); + }; + + const manageRedemption = async (id, action, record) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/redemption/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/redemption/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/redemption/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let redemption = res.data.data; + let newRedemptions = [...redemptions]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = redemption.status; + } + setRedemptions(newRedemptions); + } else { + showError(message); + } + }; + + const searchRedemptions = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. 
+ await loadRedemptions(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/redemption/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setRedemptions(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortRedemption = (key) => { + if (redemptions.length === 0) return; + setLoading(true); + let sortedRedemptions = [...redemptions]; + sortedRedemptions.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedRedemptions[0].id === redemptions[0].id) { + sortedRedemptions.reverse(); + } + setRedemptions(sortedRedemptions); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadRedemptions(page - 1).then(r => { + }); + } + }; + + let pageData = redemptions.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + return ( + <> + +
+ + + +
`第 ${page.currentStart} - ${page.currentEnd} 条,共 ${redemptions.length} 条`, + // onPageSizeChange: (size) => { + // setPageSize(size); + // setActivePage(1); + // }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + + ); +}; + +export default RedemptionsTable; diff --git a/web/air/src/components/RegisterForm.js b/web/air/src/components/RegisterForm.js new file mode 100644 index 0000000..1f26b63 --- /dev/null +++ b/web/air/src/components/RegisterForm.js @@ -0,0 +1,194 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; +import { Link, useNavigate } from 'react-router-dom'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const RegisterForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + password2: '', + email: '', + verification_code: '' + }); + const { username, password, password2 } = inputs; + const [showEmailVerification, setShowEmailVerification] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const logo = getLogo(); + let affCode = new URLSearchParams(window.location.search).get('aff'); + if (affCode) { + localStorage.setItem('aff', affCode); + } + + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setShowEmailVerification(status.email_verification); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }); + + let navigate = useNavigate(); + + function handleChange(e) { + const { name, value } = e.target; + console.log(name, value); + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (password.length < 8) { + showInfo('密码长度不得小于 8 位!'); + return; + } + if (password !== password2) { + showInfo('两次输入的密码不一致'); + return; + } + if (username && password) { + if (turnstileEnabled && 
turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + if (!affCode) { + affCode = localStorage.getItem('aff'); + } + inputs.aff_code = affCode; + const res = await API.post( + `/api/user/register?turnstile=${turnstileToken}`, + inputs + ); + const { success, message } = res.data; + if (success) { + navigate('/login'); + showSuccess('注册成功!'); + } else { + showError(message); + } + setLoading(false); + } + } + + const sendVerificationCode = async () => { + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查你的邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + return ( + + +
+ 新用户注册 +
+
+ + + + + {showEmailVerification ? ( + <> + + 获取验证码 + + } + /> + + + ) : ( + <> + )} + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+ + 已有账户? + + 点击登录 + + +
+
+ ); +}; + +export default RegisterForm; diff --git a/web/air/src/components/SiderBar.js b/web/air/src/components/SiderBar.js new file mode 100644 index 0000000..b3da272 --- /dev/null +++ b/web/air/src/components/SiderBar.js @@ -0,0 +1,214 @@ +import React, { useContext, useEffect, useMemo, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { StatusContext } from '../context/Status'; + +import { API, getLogo, getSystemName, isAdmin, isMobile, showError } from '../helpers'; +import '../index.css'; + +import { + IconCalendarClock, + IconComment, + IconCreditCard, + IconGift, + IconHistogram, + IconHome, + IconImage, + IconKey, + IconLayers, + IconSetting, + IconUser +} from '@douyinfe/semi-icons'; +import { Layout, Nav } from '@douyinfe/semi-ui'; + +// HeaderBar Buttons + +const SiderBar = () => { + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); + const defaultIsCollapsed = isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'; + + let navigate = useNavigate(); + const [selectedKeys, setSelectedKeys] = useState(['home']); + const systemName = getSystemName(); + const logo = getLogo(); + const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed); + + const headerButtons = useMemo(() => [ + { + text: '首页', + itemKey: 'home', + to: '/', + icon: + }, + { + text: '渠道', + itemKey: 'channel', + to: '/channel', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '聊天', + itemKey: 'chat', + to: '/chat', + icon: , + className: localStorage.getItem('chat_link') ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '令牌', + itemKey: 'token', + to: '/token', + icon: + }, + { + text: '兑换', + itemKey: 'redemption', + to: '/redemption', + icon: , + className: isAdmin() ? 
'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '充值', + itemKey: 'topup', + to: '/topup', + icon: + }, + { + text: '用户', + itemKey: 'user', + to: '/user', + icon: , + className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '日志', + itemKey: 'log', + to: '/log', + icon: + }, + { + text: '数据看板', + itemKey: 'detail', + to: '/detail', + icon: , + className: localStorage.getItem('enable_data_export') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '绘图', + itemKey: 'midjourney', + to: '/midjourney', + icon: , + className: localStorage.getItem('enable_drawing') === 'true' ? 'semi-navigation-item-normal' : 'tableHiddle' + }, + { + text: '设置', + itemKey: 'setting', + to: '/setting', + icon: + } + // { + // text: '关于', + // itemKey: 'about', + // to: '/about', + // icon: + // } + ], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]); + + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_data_export', data.enable_data_export); + localStorage.setItem('data_export_default_time', data.data_export_default_time); + localStorage.setItem('default_collapse_sidebar', data.default_collapse_sidebar); + localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + 
localStorage.removeItem('chat_link'); + } + if (data.chat_link2) { + localStorage.setItem('chat_link2', data.chat_link2); + } else { + localStorage.removeItem('chat_link2'); + } + } else { + showError('无法正常连接至服务器!'); + } + }; + + useEffect(() => { + loadStatus().then(() => { + setIsCollapsed(isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'); + }); + }, []); + + return ( + <> + +
+ +
+
+ + ); +}; + +export default SiderBar; diff --git a/web/air/src/components/SystemSetting.js b/web/air/src/components/SystemSetting.js new file mode 100644 index 0000000..09b9866 --- /dev/null +++ b/web/air/src/components/SystemSetting.js @@ -0,0 +1,590 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react'; +import { API, removeTrailingSlash, showError } from '../helpers'; + +const SystemSetting = () => { + let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', + PasswordRegisterEnabled: '', + EmailVerificationEnabled: '', + GitHubOAuthEnabled: '', + GitHubClientId: '', + GitHubClientSecret: '', + Notice: '', + SMTPServer: '', + SMTPPort: '', + SMTPAccount: '', + SMTPFrom: '', + SMTPToken: '', + ServerAddress: '', + Footer: '', + WeChatAuthEnabled: '', + WeChatServerAddress: '', + WeChatServerToken: '', + WeChatAccountQRCodeImageURL: '', + MessagePusherAddress: '', + MessagePusherToken: '', + TurnstileCheckEnabled: '', + TurnstileSiteKey: '', + TurnstileSecretKey: '', + RegisterEnabled: '', + EmailDomainRestrictionEnabled: '', + EmailDomainWhitelist: '' + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); + const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); + const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + newInputs[item.key] = item.value; + }); + setInputs({ + ...newInputs, + EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') + }); + setOriginInputs(newInputs); + + setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { + return { key: item, text: item, 
value: item }; + })); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + switch (key) { + case 'PasswordLoginEnabled': + case 'PasswordRegisterEnabled': + case 'EmailVerificationEnabled': + case 'GitHubOAuthEnabled': + case 'WeChatAuthEnabled': + case 'TurnstileCheckEnabled': + case 'EmailDomainRestrictionEnabled': + case 'RegisterEnabled': + value = inputs[key] === 'true' ? 'false' : 'true'; + break; + default: + break; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + if (key === 'EmailDomainWhitelist') { + value = value.split(','); + } + setInputs((inputs) => ({ + ...inputs, [key]: value + })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { + // block disabling password login + setShowPasswordWarningModal(true); + return; + } + if ( + name === 'Notice' || + name.startsWith('SMTP') || + name === 'ServerAddress' || + name === 'GitHubClientId' || + name === 'GitHubClientSecret' || + name === 'WeChatServerAddress' || + name === 'WeChatServerToken' || + name === 'WeChatAccountQRCodeImageURL' || + name === 'TurnstileSiteKey' || + name === 'TurnstileSecretKey' || + name === 'EmailDomainWhitelist' + ) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } else { + await updateOption(name, value); + } + }; + + const submitServerAddress = async () => { + let ServerAddress = removeTrailingSlash(inputs.ServerAddress); + await updateOption('ServerAddress', ServerAddress); + }; + + const submitSMTP = async () => { + if (originInputs['SMTPServer'] !== inputs.SMTPServer) { + await updateOption('SMTPServer', inputs.SMTPServer); + } + if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) { + await updateOption('SMTPAccount', 
inputs.SMTPAccount); + } + if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) { + await updateOption('SMTPFrom', inputs.SMTPFrom); + } + if ( + originInputs['SMTPPort'] !== inputs.SMTPPort && + inputs.SMTPPort !== '' + ) { + await updateOption('SMTPPort', inputs.SMTPPort); + } + if ( + originInputs['SMTPToken'] !== inputs.SMTPToken && + inputs.SMTPToken !== '' + ) { + await updateOption('SMTPToken', inputs.SMTPToken); + } + }; + + + const submitEmailDomainWhitelist = async () => { + if ( + originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && + inputs.SMTPToken !== '' + ) { + await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); + } + }; + + const submitWeChat = async () => { + if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { + await updateOption( + 'WeChatServerAddress', + removeTrailingSlash(inputs.WeChatServerAddress) + ); + } + if ( + originInputs['WeChatAccountQRCodeImageURL'] !== + inputs.WeChatAccountQRCodeImageURL + ) { + await updateOption( + 'WeChatAccountQRCodeImageURL', + inputs.WeChatAccountQRCodeImageURL + ); + } + if ( + originInputs['WeChatServerToken'] !== inputs.WeChatServerToken && + inputs.WeChatServerToken !== '' + ) { + await updateOption('WeChatServerToken', inputs.WeChatServerToken); + } + }; + + const submitMessagePusher = async () => { + if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { + await updateOption( + 'MessagePusherAddress', + removeTrailingSlash(inputs.MessagePusherAddress) + ); + } + if ( + originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && + inputs.MessagePusherToken !== '' + ) { + await updateOption('MessagePusherToken', inputs.MessagePusherToken); + } + }; + + const submitGitHubOAuth = async () => { + if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { + await updateOption('GitHubClientId', inputs.GitHubClientId); + } + if ( + originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret 
&& + inputs.GitHubClientSecret !== '' + ) { + await updateOption('GitHubClientSecret', inputs.GitHubClientSecret); + } + }; + + const submitTurnstile = async () => { + if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { + await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); + } + if ( + originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey && + inputs.TurnstileSecretKey !== '' + ) { + await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey); + } + }; + + const submitNewRestrictedDomain = () => { + const localDomainList = inputs.EmailDomainWhitelist; + if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { + setRestrictedDomainInput(''); + setInputs({ + ...inputs, + EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], + }); + setEmailDomainWhitelist([...EmailDomainWhitelist, { + key: restrictedDomainInput, + text: restrictedDomainInput, + value: restrictedDomainInput, + }]); + } + } + + return ( + + +
+
通用设置
+ + + + + 更新服务器地址 + + +
配置登录注册
+ + + { + showPasswordWarningModal && + setShowPasswordWarningModal(false)} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 警告 + +

取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?

+
+ + + + +
+ } + + + + +
+ + + + + +
+ 配置邮箱域名白名单 + 用以防止恶意用户利用临时邮箱批量注册 +
+ + + + + + { + submitNewRestrictedDomain(); + }}>填入 + } + onKeyDown={(e) => { + if (e.key === 'Enter') { + submitNewRestrictedDomain(); + } + }} + autoComplete='new-password' + placeholder='输入新的允许的邮箱域名' + value={restrictedDomainInput} + onChange={(e, { value }) => { + setRestrictedDomainInput(value); + }} + /> + + 保存邮箱域名白名单设置 + +
+ 配置 SMTP + 用以支持系统的邮件发送 +
+ + + + + + + + + + 保存 SMTP 设置 + +
+ 配置 GitHub OAuth App + + 用以支持通过 GitHub 进行登录注册, + + 点击此处 + + 管理你的 GitHub OAuth App + +
+ + Homepage URL 填 {inputs.ServerAddress} + ,Authorization callback URL 填{' '} + {`${inputs.ServerAddress}/oauth/github`} + + + + + + + 保存 GitHub OAuth 设置 + + +
+ 配置 WeChat Server + + 用以支持通过微信进行登录注册, + + 点击此处 + + 了解 WeChat Server + +
+ + + + + + + 保存 WeChat Server 设置 + + +
+ 配置 Message Pusher + + 用以推送报警信息, + + 点击此处 + + 了解 Message Pusher + +
+ + + + + + 保存 Message Pusher 设置 + + +
+ 配置 Turnstile + + 用以支持用户校验, + + 点击此处 + + 管理你的 Turnstile Sites,推荐选择 Invisible Widget Type + +
+ + + + + + 保存 Turnstile 设置 + + +
+
+ ); +}; + +export default SystemSetting; diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js new file mode 100644 index 0000000..48836c8 --- /dev/null +++ b/web/air/src/components/TokensTable.js @@ -0,0 +1,636 @@ +import React, { useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; +import { Button, Dropdown, Form, Modal, Popconfirm, Popover, SplitButtonGroup, Table, Tag } from '@douyinfe/semi-ui'; + +import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; +import EditToken from '../pages/Token/EditToken'; + +const COPY_OPTIONS = [ + { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }, +]; + +const OPEN_LINK_OPTIONS = [ + { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' } +]; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status, model_limits_enabled = false) { + switch (status) { + case 1: + if (model_limits_enabled) { + return 已启用:限制模型; + } else { + return 已启用; + } + case 2: + return 已禁用 ; + case 3: + return 已过期 ; + case 4: + return 已耗尽 ; + default: + return 未知状态 ; + } +} + +const TokensTable = () => { + + const link_menu = [ + { + node: 'item', key: 'next', name: 'ChatGPT Next Web', onClick: () => { + onOpenLink('next'); + } + }, + { node: 'item', key: 'ama', name: 'AMA 问天', value: 'ama' }, + { + node: 'item', key: 'next-mj', name: 'ChatGPT Web & Midjourney', value: 'next-mj', onClick: () => { + onOpenLink('next-mj'); + } + }, + { node: 'item', key: 'opencat', name: 
'OpenCat', value: 'opencat' }, + { + node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { + onOpenLink('lobechat'); + } + } + ]; + + const columns = [ + { + title: '名称', + dataIndex: 'name' + }, + { + title: '状态', + dataIndex: 'status', + key: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text, record.model_limits_enabled)} +
+ ); + } + }, + { + title: '已用额度', + dataIndex: 'used_quota', + render: (text, record, index) => { + return ( +
+ {renderQuota(parseInt(text))} +
+ ); + } + }, + { + title: '剩余额度', + dataIndex: 'remain_quota', + render: (text, record, index) => { + return ( +
+ {record.unlimited_quota ? 无限制 : + {renderQuota(parseInt(text))}} +
+ ); + } + }, + { + title: '创建时间', + dataIndex: 'created_time', + render: (text, record, index) => { + return ( +
+ {renderTimestamp(text)} +
+ ); + } + }, + { + title: '过期时间', + dataIndex: 'expired_time', + render: (text, record, index) => { + return ( +
+ {record.expired_time === -1 ? '永不过期' : renderTimestamp(text)} +
+ ); + } + }, + { + title: '', + dataIndex: 'operate', + render: (text, record, index) => ( +
+ + + + + + + { + onOpenLink('next', record.key); + } + }, + { + node: 'item', + key: 'next-mj', + disabled: !localStorage.getItem('chat_link2'), + name: 'ChatGPT Web & Midjourney', + onClick: () => { + onOpenLink('next-mj', record.key); + } + }, + { + node: 'item', key: 'ama', name: 'AMA 问天(BotGem)', onClick: () => { + onOpenLink('ama', record.key); + } + }, + { + node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { + onOpenLink('opencat', record.key); + } + }, + { + node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { + onOpenLink('lobechat'); + } + } + ] + } + > + + + + { + manageToken(record.id, 'delete', record).then( + () => { + removeRecord(record.key); + } + ); + }} + > + + + { + record.status === 1 ? + : + + } + +
+ ) + } + ]; + + const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showEdit, setShowEdit] = useState(false); + const [tokens, setTokens] = useState([]); + const [selectedKeys, setSelectedKeys] = useState([]); + const [tokenCount, setTokenCount] = useState(pageSize); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searchToken, setSearchToken] = useState(''); + const [searching, setSearching] = useState(false); + const [showTopUpModal, setShowTopUpModal] = useState(false); + const [targetTokenIdx, setTargetTokenIdx] = useState(0); + const [editingToken, setEditingToken] = useState({ + id: undefined + }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); + + const closeEdit = () => { + setShowEdit(false); + setTimeout(() => { + setEditingToken({ + id: undefined + }); + }, 500); + }; + + const setTokensFormat = (tokens) => { + setTokens(tokens); + if (tokens.length >= pageSize) { + setTokenCount(tokens.length + pageSize); + } else { + setTokenCount(tokens.length); + } + }; + + let pageData = tokens.slice((activePage - 1) * pageSize, activePage * pageSize); + const loadTokens = async (startIdx) => { + setLoading(true); + const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}&order=${orderBy}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setTokensFormat(data); + } else { + let newTokens = [...tokens]; + newTokens.splice(startIdx * pageSize, data.length, ...data); + setTokensFormat(newTokens); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. 
+ await loadTokens(activePage - 1, orderBy); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + await loadTokens(activePage - 1); + }; + + const onCopy = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const nextLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let nextUrl; + + if (nextLink) { + nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } else { + nextUrl = `https://app.nextchat.dev/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + + let url; + switch (type) { + case 'ama': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next': + url = nextUrl; + break; + default: + url = `sk-${key}`; + } + // if (await copy(url)) { + // showSuccess('已复制到剪贴板!'); + // } else { + // showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); + // setSearchKeyword(url); + // } + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制到剪贴板!'); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + const onOpenLink = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const chatLink = localStorage.getItem('chat_link'); + const mjLink = localStorage.getItem('chat_link2'); + let defaultUrl; 
+ + if (chatLink) { + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + let url; + switch (type) { + case 'ama': + url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next-mj': + url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + break; + case 'lobechat': + url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`; + break; + default: + if (!chatLink) { + showError('管理员未设置聊天链接'); + return; + } + url = defaultUrl; + } + + window.open(url, '_blank'); + }; + + useEffect(() => { + loadTokens(0, orderBy) + .then() + .catch((reason) => { + showError(reason); + }); + }, [pageSize, orderBy]); + + const removeRecord = key => { + let newDataSource = [...tokens]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.key === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setTokensFormat(newDataSource); + } + } + }; + + const manageToken = async (id, action, record) => { + setLoading(true); + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/token/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/token/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/token/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let token = res.data.data; + let newTokens = [...tokens]; + // let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + + } else { + record.status = token.status; + // newTokens[realIdx].status = token.status; + } + setTokensFormat(newTokens); + } else { + showError(message); + } + setLoading(false); + }; + + const searchTokens = async () => { + if 
(searchKeyword === '' && searchToken === '') { + // if keyword is blank, load files instead. + await loadTokens(0); + setActivePage(1); + setOrderBy(''); + return; + } + setSearching(true); + const res = await API.get(`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`); + const { success, message, data } = res.data; + if (success) { + setTokensFormat(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const handleSearchTokenChange = async (value) => { + setSearchToken(value.trim()); + }; + + const sortToken = (key) => { + if (tokens.length === 0) return; + setLoading(true); + let sortedTokens = [...tokens]; + sortedTokens.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedTokens[0].id === tokens[0].id) { + sortedTokens.reverse(); + } + setTokens(sortedTokens); + setLoading(false); + }; + + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(tokens.length / pageSize) + 1) { + // In this case we have to load more data and then append them. + loadTokens(page - 1).then(r => { + }); + } + }; + + const rowSelection = { + onSelect: (record, selected) => { + }, + onSelectAll: (selected, selectedRows) => { + }, + onChange: (selectedRowKeys, selectedRows) => { + setSelectedKeys(selectedRows); + } + }; + + const handleRow = (record, index) => { + if (record.status !== 1) { + return { + style: { + background: 'var(--semi-color-disabled-border)' + } + }; + } else { + return {}; + } + }; + + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'remain_quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + default: + return '默认排序'; + } + }; + + return ( + <> + +
+ + {/* */} + + + + `第 ${page.currentStart} - ${page.currentEnd} 条,共 ${tokens.length} 条`, + onPageSizeChange: (size) => { + setPageSize(size); + setActivePage(1); + }, + onPageChange: handlePageChange + }} loading={loading} rowSelection={rowSelection} onRow={handleRow}> +
+ + + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'remain_quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + + } + > + + + + ); +}; + +export default TokensTable; diff --git a/web/air/src/components/UsersTable.js b/web/air/src/components/UsersTable.js new file mode 100644 index 0000000..4fc16ba --- /dev/null +++ b/web/air/src/components/UsersTable.js @@ -0,0 +1,376 @@ +import React, { useEffect, useState } from 'react'; +import { API, showError, showSuccess } from '../helpers'; +import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip, Dropdown } from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumber, renderQuota } from '../helpers/render'; +import AddUser from '../pages/User/AddUser'; +import EditUser from '../pages/User/EditUser'; + +function renderRole(role) { + switch (role) { + case 1: + return 普通用户; + case 10: + return 管理员; + case 100: + return 超级管理员; + default: + return 未知身份; + } +} + +const UsersTable = () => { + const columns = [{ + title: 'ID', dataIndex: 'id' + }, { + title: '用户名', dataIndex: 'username' + }, { + title: '分组', dataIndex: 'group', render: (text, record, index) => { + return (
+ {renderGroup(text)} +
); + } + }, { + title: '统计信息', dataIndex: 'info', render: (text, record, index) => { + return (
+ + + {renderQuota(record.quota)} + + + {renderQuota(record.used_quota)} + + + {renderNumber(record.request_count)} + + +
); + } + }, + // { + // title: '邀请信息', dataIndex: 'invite', render: (text, record, index) => { + // return (
+ // + // + // {renderNumber(record.aff_count)} + // + // + // {renderQuota(record.aff_history_quota)} + // + // + // {record.inviter_id === 0 ? : + // {record.inviter_id}} + // + // + //
); + // } + // }, + { + title: '角色', dataIndex: 'role', render: (text, record, index) => { + return (
+ {renderRole(text)} +
); + } + }, + { + title: '状态', dataIndex: 'status', render: (text, record, index) => { + return (
+ {renderStatus(text)} +
); + } + }, + { + title: '', dataIndex: 'operate', render: (text, record, index) => (
+ <> + { + manageUser(record.username, 'promote', record); + }} + > + + + { + manageUser(record.username, 'demote', record); + }} + > + + + {record.status === 1 ? + : + } + + + { + manageUser(record.username, 'delete', record).then(() => { + removeRecord(record.id); + }); + }} + > + + +
) + }]; + + const [users, setUsers] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [userCount, setUserCount] = useState(ITEMS_PER_PAGE); + const [showAddUser, setShowAddUser] = useState(false); + const [showEditUser, setShowEditUser] = useState(false); + const [editingUser, setEditingUser] = useState({ + id: undefined + }); + const [orderBy, setOrderBy] = useState(''); + const [dropdownVisible, setDropdownVisible] = useState(false); + + const setCount = (data) => { + if (data.length >= (activePage) * ITEMS_PER_PAGE) { + setUserCount(data.length + 1); + } else { + setUserCount(data.length); + } + }; + + const removeRecord = key => { + console.log(key); + let newDataSource = [...users]; + if (key != null) { + let idx = newDataSource.findIndex(data => data.id === key); + + if (idx > -1) { + newDataSource.splice(idx, 1); + setUsers(newDataSource); + } + } + }; + + const loadUsers = async (startIdx) => { + const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setUsers(data); + setCount(data); + } else { + let newUsers = users; + newUsers.push(...data); + setUsers(newUsers); + setCount(newUsers); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. 
+ await loadUsers(activePage - 1, orderBy); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadUsers(0, orderBy) + .then() + .catch((reason) => { + showError(reason); + }); + }, [orderBy]); + + const manageUser = async (username, action, record) => { + const res = await API.post('/api/user/manage', { + username, action + }); + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let user = res.data.data; + let newUsers = [...users]; + if (action === 'delete') { + + } else { + record.status = user.status; + record.role = user.role; + } + setUsers(newUsers); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return 已激活; + case 2: + return ( + 已封禁 + ); + default: + return ( + 未知状态 + ); + } + }; + + const searchUsers = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadUsers(0); + setActivePage(1); + setOrderBy(''); + return; + } + setSearching(true); + const res = await API.get(`/api/user/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setUsers(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (value) => { + setSearchKeyword(value.trim()); + }; + + const sortUser = (key) => { + if (users.length === 0) return; + setLoading(true); + let sortedUsers = [...users]; + sortedUsers.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedUsers[0].id === users[0].id) { + sortedUsers.reverse(); + } + setUsers(sortedUsers); + setLoading(false); + }; + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. 
+ loadUsers(page - 1).then(r => { + }); + } + }; + + const pageData = users.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const closeAddUser = () => { + setShowAddUser(false); + }; + + const closeEditUser = () => { + setShowEditUser(false); + setEditingUser({ + id: undefined + }); + }; + + const refresh = async () => { + if (searchKeyword === '') { + await loadUsers(activePage - 1); + } else { + await searchUsers(); + } + }; + + const handleOrderByChange = (e, { value }) => { + setOrderBy(value); + setActivePage(1); + setDropdownVisible(false); + }; + + const renderSelectedOption = (orderBy) => { + switch (orderBy) { + case 'quota': + return '按剩余额度排序'; + case 'used_quota': + return '按已用额度排序'; + case 'request_count': + return '按请求次数排序'; + default: + return '默认排序'; + } + }; + + return ( + <> + + +
+ handleKeywordChange(value)} + /> + + + + + setDropdownVisible(visible)} + render={ + + handleOrderByChange('', { value: '' })}>默认排序 + handleOrderByChange('', { value: 'quota' })}>按剩余额度排序 + handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序 + handleOrderByChange('', { value: 'request_count' })}>按请求次数排序 + + } + > + + + + ); +}; + +export default UsersTable; diff --git a/web/air/src/components/WeChatIcon.js b/web/air/src/components/WeChatIcon.js new file mode 100644 index 0000000..22210d9 --- /dev/null +++ b/web/air/src/components/WeChatIcon.js @@ -0,0 +1,24 @@ +import React from 'react'; +import { Icon } from '@douyinfe/semi-ui'; + +const WeChatIcon = () => { + function CustomIcon() { + return + + + ; + } + + return ( +
+ } /> +
+ ); +}; + +export default WeChatIcon; diff --git a/web/air/src/components/utils.js b/web/air/src/components/utils.js new file mode 100644 index 0000000..5363ba5 --- /dev/null +++ b/web/air/src/components/utils.js @@ -0,0 +1,20 @@ +import { API, showError } from '../helpers'; + +export async function getOAuthState() { + const res = await API.get('/api/oauth/state'); + const { success, message, data } = res.data; + if (success) { + return data; + } else { + showError(message); + return ''; + } +} + +export async function onGitHubOAuthClicked(github_client_id) { + const state = await getOAuthState(); + if (!state) return; + window.open( + `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` + ); +} \ No newline at end of file diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js new file mode 100644 index 0000000..00a1d52 --- /dev/null +++ b/web/air/src/constants/channel.constants.js @@ -0,0 +1,52 @@ +export const CHANNEL_OPTIONS = [ + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 33, text: 'AWS', value: 33, color: 'black' }, + { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, + { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, + { key: 41, text: 'Novita', value: 41, color: 'purple' }, + {key: 40, text: '字节火山引擎', value: 40, color: 'blue'}, + { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, + { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, + { key: 23, text: '腾讯混元', value: 23, color: 
'teal' }, + { key: 26, text: '百川大模型', value: 26, color: 'orange' }, + { key: 27, text: 'MiniMax', value: 27, color: 'red' }, + { key: 29, text: 'Groq', value: 29, color: 'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, + { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, + { key: 34, text: 'Coze', value: 34, color: 'blue' }, + { key: 35, text: 'Cohere', value: 35, color: 'blue' }, + { key: 36, text: 'DeepSeek', value: 36, color: 'black' }, + { key: 37, text: 'Cloudflare', value: 37, color: 'orange' }, + { key: 38, text: 'DeepL', value: 38, color: 'black' }, + { key: 39, text: 'together.ai', value: 39, color: 'blue' }, + { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, + { key: 43, text: 'Proxy', value: 43, color: 'blue' }, + { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, + { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, + { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, + {key: 20, text: 'OpenRouter', value: 20, color: 'black'}, + { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, + { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, + { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, + { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, + { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, + { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, + { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, + { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } +]; + +for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { + CHANNEL_OPTIONS[i].label = CHANNEL_OPTIONS[i].text; +} diff --git a/web/air/src/constants/common.constant.js b/web/air/src/constants/common.constant.js 
new file mode 100644 index 0000000..1a37d5f --- /dev/null +++ b/web/air/src/constants/common.constant.js @@ -0,0 +1 @@ +export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! diff --git a/web/air/src/constants/index.js b/web/air/src/constants/index.js new file mode 100644 index 0000000..e83152b --- /dev/null +++ b/web/air/src/constants/index.js @@ -0,0 +1,4 @@ +export * from './toast.constants'; +export * from './user.constants'; +export * from './common.constant'; +export * from './channel.constants'; \ No newline at end of file diff --git a/web/air/src/constants/toast.constants.js b/web/air/src/constants/toast.constants.js new file mode 100644 index 0000000..5068472 --- /dev/null +++ b/web/air/src/constants/toast.constants.js @@ -0,0 +1,7 @@ +export const toastConstants = { + SUCCESS_TIMEOUT: 1500, + INFO_TIMEOUT: 3000, + ERROR_TIMEOUT: 5000, + WARNING_TIMEOUT: 10000, + NOTICE_TIMEOUT: 20000 +}; diff --git a/web/air/src/constants/user.constants.js b/web/air/src/constants/user.constants.js new file mode 100644 index 0000000..2680d8e --- /dev/null +++ b/web/air/src/constants/user.constants.js @@ -0,0 +1,19 @@ +export const userConstants = { + REGISTER_REQUEST: 'USERS_REGISTER_REQUEST', + REGISTER_SUCCESS: 'USERS_REGISTER_SUCCESS', + REGISTER_FAILURE: 'USERS_REGISTER_FAILURE', + + LOGIN_REQUEST: 'USERS_LOGIN_REQUEST', + LOGIN_SUCCESS: 'USERS_LOGIN_SUCCESS', + LOGIN_FAILURE: 'USERS_LOGIN_FAILURE', + + LOGOUT: 'USERS_LOGOUT', + + GETALL_REQUEST: 'USERS_GETALL_REQUEST', + GETALL_SUCCESS: 'USERS_GETALL_SUCCESS', + GETALL_FAILURE: 'USERS_GETALL_FAILURE', + + DELETE_REQUEST: 'USERS_DELETE_REQUEST', + DELETE_SUCCESS: 'USERS_DELETE_SUCCESS', + DELETE_FAILURE: 'USERS_DELETE_FAILURE' +}; diff --git a/web/air/src/context/Status/index.js b/web/air/src/context/Status/index.js new file mode 100644 index 0000000..71f0682 --- /dev/null +++ b/web/air/src/context/Status/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from 
'react'; +import { initialState, reducer } from './reducer'; + +export const StatusContext = React.createContext({ + state: initialState, + dispatch: () => null, +}); + +export const StatusProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState); + + return ( + + {children} + + ); +}; \ No newline at end of file diff --git a/web/air/src/context/Status/reducer.js b/web/air/src/context/Status/reducer.js new file mode 100644 index 0000000..ec9ac6a --- /dev/null +++ b/web/air/src/context/Status/reducer.js @@ -0,0 +1,20 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'set': + return { + ...state, + status: action.payload, + }; + case 'unset': + return { + ...state, + status: undefined, + }; + default: + return state; + } +}; + +export const initialState = { + status: undefined, +}; diff --git a/web/air/src/context/User/index.js b/web/air/src/context/User/index.js new file mode 100644 index 0000000..c667159 --- /dev/null +++ b/web/air/src/context/User/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from "react" +import { reducer, initialState } from "./reducer" + +export const UserContext = React.createContext({ + state: initialState, + dispatch: () => null +}) + +export const UserProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState) + + return ( + + { children } + + ) +} \ No newline at end of file diff --git a/web/air/src/context/User/reducer.js b/web/air/src/context/User/reducer.js new file mode 100644 index 0000000..9ed1d80 --- /dev/null +++ b/web/air/src/context/User/reducer.js @@ -0,0 +1,21 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'login': + return { + ...state, + user: action.payload + }; + case 'logout': + return { + ...state, + user: undefined + }; + + default: + return state; + } +}; + +export const initialState = { + user: undefined +}; \ No newline at end of file diff --git 
a/web/air/src/helpers/api.js b/web/air/src/helpers/api.js new file mode 100644 index 0000000..35fdb1e --- /dev/null +++ b/web/air/src/helpers/api.js @@ -0,0 +1,13 @@ +import { showError } from './utils'; +import axios from 'axios'; + +export const API = axios.create({ + baseURL: process.env.REACT_APP_SERVER ? process.env.REACT_APP_SERVER : '', +}); + +API.interceptors.response.use( + (response) => response, + (error) => { + showError(error); + } +); diff --git a/web/air/src/helpers/auth-header.js b/web/air/src/helpers/auth-header.js new file mode 100644 index 0000000..a8fe5f5 --- /dev/null +++ b/web/air/src/helpers/auth-header.js @@ -0,0 +1,10 @@ +export function authHeader() { + // return authorization header with jwt token + let user = JSON.parse(localStorage.getItem('user')); + + if (user && user.token) { + return { 'Authorization': 'Bearer ' + user.token }; + } else { + return {}; + } +} \ No newline at end of file diff --git a/web/air/src/helpers/history.js b/web/air/src/helpers/history.js new file mode 100644 index 0000000..629039e --- /dev/null +++ b/web/air/src/helpers/history.js @@ -0,0 +1,3 @@ +import { createBrowserHistory } from 'history'; + +export const history = createBrowserHistory(); \ No newline at end of file diff --git a/web/air/src/helpers/index.js b/web/air/src/helpers/index.js new file mode 100644 index 0000000..505a8cf --- /dev/null +++ b/web/air/src/helpers/index.js @@ -0,0 +1,4 @@ +export * from './history'; +export * from './auth-header'; +export * from './utils'; +export * from './api'; \ No newline at end of file diff --git a/web/air/src/helpers/render.js b/web/air/src/helpers/render.js new file mode 100644 index 0000000..62fb0dc --- /dev/null +++ b/web/air/src/helpers/render.js @@ -0,0 +1,170 @@ +import {Label} from 'semantic-ui-react'; +import {Tag} from "@douyinfe/semi-ui"; + +export function renderText(text, limit) { + if (text.length > limit) { + return text.slice(0, limit - 3) + '...'; + } + return text; +} + +export function 
renderGroup(group) { + if (group === '') { + return default; + } + let groups = group.split(','); + groups.sort(); + return <> + {groups.map((group) => { + if (group === 'vip' || group === 'pro') { + return {group}; + } else if (group === 'svip' || group === 'premium') { + return {group}; + } + if (group === 'default') { + return {group}; + } else { + return {group}; + } + })} + ; +} + +export function renderNumber(num) { + if (num >= 1000000000) { + return (num / 1000000000).toFixed(1) + 'B'; + } else if (num >= 1000000) { + return (num / 1000000).toFixed(1) + 'M'; + } else if (num >= 10000) { + return (num / 1000).toFixed(1) + 'k'; + } else { + return num; + } +} + +export function renderQuotaNumberWithDigit(num, digits = 2) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + num = num.toFixed(digits); + if (displayInCurrency) { + return '$' + num; + } + return num; +} + +export function renderNumberWithPoint(num) { + num = num.toFixed(2); + if (num >= 100000) { + // Convert number to string to manipulate it + let numStr = num.toString(); + // Find the position of the decimal point + let decimalPointIndex = numStr.indexOf('.'); + + let wholePart = numStr; + let decimalPart = ''; + + // If there is a decimal point, split the number into whole and decimal parts + if (decimalPointIndex !== -1) { + wholePart = numStr.slice(0, decimalPointIndex); + decimalPart = numStr.slice(decimalPointIndex); + } + + // Take the first two and last two digits of the whole number part + let shortenedWholePart = wholePart.slice(0, 2) + '..' 
+ wholePart.slice(-2); + + // Return the formatted number + return shortenedWholePart + decimalPart; + } + + // If the number is less than 100,000, return it unmodified + return num; +} + +export function getQuotaPerUnit() { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return quotaPerUnit; +} + +export function getQuotaWithUnit(quota, digits = 6) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + return (quota / quotaPerUnit).toFixed(digits); +} + +export function renderQuota(quota, digits = 2) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + let displayInCurrency = localStorage.getItem('display_in_currency'); + quotaPerUnit = parseFloat(quotaPerUnit); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + (quota / quotaPerUnit).toFixed(digits); + } + return renderNumber(quota); +} + +export function renderQuotaWithPrompt(quota, digits) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return `(等价金额:${renderQuota(quota, digits)})`; + } + return ''; +} + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +] + +export const modelColorMap = { + 'dall-e': 'rgb(147,112,219)', // 深紫色 + 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调 + 'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调 + 'midjourney': 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调 + 'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色 + 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色 + 'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿 + 'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿 + 'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色 + 'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色 + 'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色 + 'gpt-4': 
'rgb(135,206,235)', // 天蓝色 + 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色 + 'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝 + 'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝 + 'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝 + 'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝 + 'gpt-4-32k': 'rgb(104,111,238)', // 中紫色 + 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色 + 'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色 + 'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝 + 'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色 + 'gpt-4-vision-preview': 'rgb(25,25,112)', // 午夜蓝 + 'text-ada-001': 'rgb(255,192,203)', // 粉红色 + 'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色 + 'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色 + 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色 + 'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色(与Curie相同,表示同一个系列) + 'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色 + 'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红 + 'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别) + 'text-moderation-latest': 'rgb(255,130,171)', // 强粉色 + 'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(与Babbage相同,表示同一类功能) + 'tts-1': 'rgb(255,140,0)', // 深橙色 + 'tts-1-1106': 'rgb(255,165,0)', // 橙色 + 'tts-1-hd': 'rgb(255,215,0)', // 金色 + 'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别) + 'whisper-1': 'rgb(245,245,220)' // 米色 +} + +export function stringToColor(str) { + let sum = 0; + // 对字符串中的每个字符进行操作 + for (let i = 0; i < str.length; i++) { + // 将字符的ASCII值加到sum中 + sum += str.charCodeAt(i); + } + // 使用模运算得到个位数 + let i = sum % colors.length; + return colors[i]; +} \ No newline at end of file diff --git a/web/air/src/helpers/utils.js b/web/air/src/helpers/utils.js new file mode 100644 index 0000000..580c77c --- /dev/null +++ b/web/air/src/helpers/utils.js @@ -0,0 +1,233 @@ +import { Toast } from '@douyinfe/semi-ui'; +import { toastConstants } from '../constants'; +import React from 'react'; +import {toast} from "react-toastify"; + +const HTMLToastContent = ({ htmlContent }) => { + return
; +}; +export default HTMLToastContent; +export function isAdmin() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 10; +} + +export function isRoot() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 100; +} + +export function getSystemName() { + let system_name = localStorage.getItem('system_name'); + if (!system_name) return 'One API'; + return system_name; +} + +export function getLogo() { + let logo = localStorage.getItem('logo'); + if (!logo) return '/logo.png'; + return logo +} + +export function getFooterHTML() { + return localStorage.getItem('footer_html'); +} + +export async function copy(text) { + let okay = true; + try { + await navigator.clipboard.writeText(text); + } catch (e) { + okay = false; + console.error(e); + } + return okay; +} + +export function isMobile() { + return window.innerWidth <= 600; +} + +let showErrorOptions = { autoClose: toastConstants.ERROR_TIMEOUT }; +let showWarningOptions = { autoClose: toastConstants.WARNING_TIMEOUT }; +let showSuccessOptions = { autoClose: toastConstants.SUCCESS_TIMEOUT }; +let showInfoOptions = { autoClose: toastConstants.INFO_TIMEOUT }; +let showNoticeOptions = { autoClose: false }; + +if (isMobile()) { + showErrorOptions.position = 'top-center'; + // showErrorOptions.transition = 'flip'; + + showSuccessOptions.position = 'top-center'; + // showSuccessOptions.transition = 'flip'; + + showInfoOptions.position = 'top-center'; + // showInfoOptions.transition = 'flip'; + + showNoticeOptions.position = 'top-center'; + // showNoticeOptions.transition = 'flip'; +} + +export function showError(error) { + console.error(error); + if (error.message) { + if (error.name === 'AxiosError') { + switch (error.response.status) { + case 401: + // toast.error('错误:未登录或登录已过期,请重新登录!', showErrorOptions); + window.location.href = '/login?expired=true'; + break; + case 429: + 
Toast.error('错误:请求次数过多,请稍后再试!'); + break; + case 500: + Toast.error('错误:服务器内部错误,请联系管理员!'); + break; + case 405: + Toast.info('本站仅作演示之用,无服务端!'); + break; + default: + Toast.error('错误:' + error.message); + } + return; + } + Toast.error('错误:' + error.message); + } else { + Toast.error('错误:' + error); + } +} + +export function showWarning(message) { + Toast.warning(message); +} + +export function showSuccess(message) { + Toast.success(message); +} + +export function showInfo(message) { + Toast.info(message); +} + +export function showNotice(message, isHTML = false) { + if (isHTML) { + toast(, showNoticeOptions); + } else { + Toast.info(message); + } +} + +export function openPage(url) { + window.open(url); +} + +export function removeTrailingSlash(url) { + if (url.endsWith('/')) { + return url.slice(0, -1); + } else { + return url; + } +} + +export function timestamp2string(timestamp) { + let date = new Date(timestamp * 1000); + let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + let minute = date.getMinutes().toString(); + let second = date.getSeconds().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + if (minute.length === 1) { + minute = '0' + minute; + } + if (second.length === 1) { + second = '0' + second; + } + return ( + year + + '-' + + month + + '-' + + day + + ' ' + + hour + + ':' + + minute + + ':' + + second + ); +} + +export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') { + let date = new Date(timestamp * 1000); + // let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if 
(hour.length === 1) { + hour = '0' + hour; + } + let str = month + '-' + day + if (dataExportDefaultTime === 'hour') { + str += ' ' + hour + ":00" + } else if (dataExportDefaultTime === 'week') { + let nextWeek = new Date(timestamp * 1000 + 6 * 24 * 60 * 60 * 1000); + let nextMonth = (nextWeek.getMonth() + 1).toString(); + let nextDay = nextWeek.getDate().toString(); + if (nextMonth.length === 1) { + nextMonth = '0' + nextMonth; + } + if (nextDay.length === 1) { + nextDay = '0' + nextDay; + } + str += ' - ' + nextMonth + '-' + nextDay + } + return str; +} + +export function downloadTextAsFile(text, filename) { + let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); + let url = URL.createObjectURL(blob); + let a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); +} + +export const verifyJSON = (str) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + +export function shouldShowPrompt(id) { + let prompt = localStorage.getItem(`prompt-${id}`); + return !prompt; + +} + +export function setPromptShown(id) { + localStorage.setItem(`prompt-${id}`, 'true'); +} \ No newline at end of file diff --git a/web/air/src/index.css b/web/air/src/index.css new file mode 100644 index 0000000..271f14e --- /dev/null +++ b/web/air/src/index.css @@ -0,0 +1,116 @@ +body { + margin: 0; + padding-top: 55px; + overflow-y: scroll; + font-family: Lato, 'Helvetica Neue', Arial, Helvetica, "Microsoft YaHei", sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + scrollbar-width: none; + color: var(--semi-color-text-0) !important; + background-color: var( --semi-color-bg-0) !important; + height: 100%; +} + +#root { + height: 100%; +} + +@media only screen and (max-width: 767px) { + .semi-table-tbody, .semi-table-row, .semi-table-row-cell { + display: block!important; + width: auto!important; + padding: 2px!important; + } + .semi-table-row-cell { + border-bottom: 0!important; + } 
+ .semi-table-tbody>.semi-table-row { + border-bottom: 1px solid rgba(0,0,0,.1); + } + .semi-space { + /*display: block!important;*/ + display: flex; + flex-direction: row; + flex-wrap: wrap; + row-gap: 3px; + column-gap: 10px; + } +} + +.semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 14px; +} + +.channel-table { + .semi-table-tbody > .semi-table-row > .semi-table-row-cell { + padding: 16px 8px; + } +} + +.semi-layout { + height: 100%; +} + +.tableShow { + display: revert; +} + +.tableHiddle { + display: none !important; +} + +body::-webkit-scrollbar { + display: none; +} + +code { + font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; +} + +.semi-navigation-vertical { + /*display: flex;*/ + /*flex-direction: column;*/ +} + +.semi-navigation-item { + margin-bottom: 0; +} + +.semi-navigation-vertical { + /*flex: 0 0 auto;*/ + /*display: flex;*/ + /*flex-direction: column;*/ + /*width: 100%;*/ + height: 100%; + overflow: hidden; +} + +.main-content { + padding: 4px; + height: 100%; +} + +.small-icon .icon { + font-size: 1em !important; +} + +.custom-footer { + font-size: 1.1em; +} + +@media only screen and (max-width: 600px) { + .hide-on-mobile { + display: none !important; + } +} + + +/* 隐藏浏览器默认的滚动条 */ +body { + overflow: hidden; +} + +/* 自定义滚动条样式 */ +body::-webkit-scrollbar { + width: 0; /* 隐藏滚动条的宽度 */ +} \ No newline at end of file diff --git a/web/air/src/index.js b/web/air/src/index.js new file mode 100644 index 0000000..25b1d39 --- /dev/null +++ b/web/air/src/index.js @@ -0,0 +1,54 @@ +import { initVChartSemiTheme } from '@visactor/vchart-semi-theme'; +import React from 'react'; +import ReactDOM from 'react-dom/client'; +import {BrowserRouter} from 'react-router-dom'; +import App from './App'; +import HeaderBar from './components/HeaderBar'; +import Footer from './components/Footer'; +import 'semantic-ui-css/semantic.min.css'; +import './index.css'; +import {UserProvider} from './context/User'; +import 
{ToastContainer} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import {StatusProvider} from './context/Status'; +import {Layout} from "@douyinfe/semi-ui"; +import SiderBar from "./components/SiderBar"; + +// initialization +initVChartSemiTheme({ + isWatchingThemeSwitch: true, +}); + +const root = ReactDOM.createRoot(document.getElementById('root')); +const {Sider, Content, Header} = Layout; +root.render( + + + + + + + + + +
+ +
+ + + + +
+
+
+ +
+
+
+
+
+); diff --git a/web/air/src/pages/About/index.js b/web/air/src/pages/About/index.js new file mode 100644 index 0000000..ec13f15 --- /dev/null +++ b/web/air/src/pages/About/index.js @@ -0,0 +1,58 @@ +import React, { useEffect, useState } from 'react'; +import { Header, Segment } from 'semantic-ui-react'; +import { API, showError } from '../../helpers'; +import { marked } from 'marked'; + +const About = () => { + const [about, setAbout] = useState(''); + const [aboutLoaded, setAboutLoaded] = useState(false); + + const displayAbout = async () => { + setAbout(localStorage.getItem('about') || ''); + const res = await API.get('/api/about'); + const { success, message, data } = res.data; + if (success) { + let aboutContent = data; + if (!data.startsWith('https://')) { + aboutContent = marked.parse(data); + } + setAbout(aboutContent); + localStorage.setItem('about', aboutContent); + } else { + showError(message); + setAbout('加载关于内容失败...'); + } + setAboutLoaded(true); + }; + + useEffect(() => { + displayAbout().then(); + }, []); + + return ( + <> + { + aboutLoaded && about === '' ? <> + +
关于
+

可在设置页面设置关于内容,支持 HTML & Markdown

+ 项目仓库地址: + + https://github.com/songquanpeng/one-api + +
+ : <> + { + about.startsWith('https://') ?