Merge branch 'main' into pr
This commit is contained in:
10
.github/workflows/docker-image-amd64.yml
vendored
10
.github/workflows/docker-image-amd64.yml
vendored
@@ -18,20 +18,20 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Save version info
|
||||
run: |
|
||||
git describe --tags > VERSION
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -39,14 +39,14 @@ jobs:
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
|
||||
15
.github/workflows/docker-image-arm64.yml
vendored
15
.github/workflows/docker-image-arm64.yml
vendored
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
@@ -19,26 +18,26 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Save version info
|
||||
run: |
|
||||
git describe --tags > VERSION
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -46,14 +45,14 @@ jobs:
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,3 +10,4 @@ web/dist
|
||||
.env
|
||||
one-api
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
222
README.md
222
README.md
@@ -7,7 +7,6 @@
|
||||
|
||||
# New API
|
||||
|
||||
|
||||
🍥新一代大模型网关与AI资产管理系统
|
||||
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
@@ -37,199 +36,154 @@
|
||||
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
|
||||
|
||||
> [!IMPORTANT]
|
||||
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||
> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
|
||||
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
## 📚 文档
|
||||
|
||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
|
||||
## ✨ 主要特性
|
||||
|
||||
1. 🎨 全新的UI界面(部分界面还待更新)
|
||||
2. 🌍 多语言支持(待完善)
|
||||
3. 🎨 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口支持,[对接文档](Midjourney.md)
|
||||
4. 💰 支持在线充值功能,可在系统设置中设置:
|
||||
- [x] 易支付
|
||||
5. 🔍 支持用key查询使用额度:
|
||||
- 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
|
||||
6. 📑 分页支持选择每页显示数量
|
||||
7. 🔄 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
8. 💵 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
9. ⚖️ 支持渠道**加权随机**
|
||||
10. 📈 数据看板(控制台)
|
||||
11. 🔒 可设置令牌能调用的模型
|
||||
12. 🤖 支持Telegram授权登录:
|
||||
1. 系统设置-配置登录注册-允许通过Telegram登录
|
||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
13. 🎵 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口支持,[对接文档](Suno.md)
|
||||
14. 🔄 支持Rerank模型,目前兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md)
|
||||
15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道
|
||||
16. 支持使用路由/chat2link 进入聊天界面
|
||||
17. 🧠 支持通过模型名称后缀设置 reasoning effort:
|
||||
New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
|
||||
|
||||
1. 🎨 全新的UI界面
|
||||
2. 🌍 多语言支持
|
||||
3. 💰 支持在线充值功能(易支付)
|
||||
4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
|
||||
5. 🔄 兼容原版One API的数据库
|
||||
6. 💵 支持模型按次数收费
|
||||
7. ⚖️ 支持渠道加权随机
|
||||
8. 📈 数据看板(控制台)
|
||||
9. 🔒 令牌分组、模型限制
|
||||
10. 🤖 支持更多授权登陆方式(LinuxDO,Telegram、OIDC)
|
||||
11. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
12. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime)
|
||||
13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
|
||||
14. 支持使用路由/chat2link进入聊天界面
|
||||
15. 🧠 支持通过模型名称后缀设置 reasoning effort:
|
||||
1. OpenAI o系列模型
|
||||
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
|
||||
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
|
||||
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
|
||||
2. Claude 思考模型
|
||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
|
||||
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
|
||||
20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
16. 🔄 思考转内容功能
|
||||
17. 🔄 针对用户的模型限流功能
|
||||
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||
3. 支持的渠道:
|
||||
- [x] OpenAI
|
||||
- [x] Azure
|
||||
- [x] DeepSeek
|
||||
- [ ] Claude
|
||||
- [x] Claude
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
|
||||
此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api):
|
||||
|
||||
1. 第三方模型 **gpts** (gpt-4-gizmo-*)
|
||||
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||
3. 自定义渠道,支持填入完整调用地址
|
||||
4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
||||
5. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md)
|
||||
6. Dify
|
||||
2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image)
|
||||
3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
|
||||
4. 自定义渠道,支持填入完整调用地址
|
||||
5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
|
||||
7. Dify,当前仅支持chatflow
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
## 环境变量配置
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`。
|
||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。
|
||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`。
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true`,建议开启,不影响客户端传入stream_options参数返回结果。
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
|
||||
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
|
||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
## 已废弃的环境变量
|
||||
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
|
||||
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true`
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true`
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
|
||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
|
||||
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
|
||||
|
||||
## 部署
|
||||
|
||||
详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation):
|
||||
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
> 默认账号root 密码123456
|
||||
|
||||
### 多机部署
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取。
|
||||
### 多机部署注意事项
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取
|
||||
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
||||
- 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录)
|
||||
- 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6
|
||||
|
||||
### 使用宝塔面板Docker功能部署
|
||||
安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装
|
||||
安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装
|
||||
安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装
|
||||
### 部署方式
|
||||
|
||||
#### 使用宝塔面板Docker功能部署
|
||||
安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
|
||||
[图文教程](BT.md)
|
||||
|
||||
### 基于 Docker 进行部署
|
||||
|
||||
> [!TIP]
|
||||
> 默认管理员账号root 密码123456
|
||||
|
||||
### 使用 Docker Compose 部署(推荐)
|
||||
#### 使用Docker Compose部署(推荐)
|
||||
```shell
|
||||
# 下载项目
|
||||
git clone https://github.com/Calcium-Ion/new-api.git
|
||||
cd new-api
|
||||
# 按需编辑 docker-compose.yml
|
||||
# nano docker-compose.yml
|
||||
# vim docker-compose.yml
|
||||
# 按需编辑docker-compose.yml
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 更新版本
|
||||
#### 直接使用Docker镜像
|
||||
```shell
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### 直接使用 Docker 镜像
|
||||
```shell
|
||||
# 使用 SQLite 的部署命令:
|
||||
# 使用SQLite
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
|
||||
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
|
||||
# 例如:
|
||||
# 使用MySQL
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
|
||||
#### 更新版本
|
||||
```shell
|
||||
# 拉取最新镜像
|
||||
docker pull calciumion/new-api:latest
|
||||
# 停止并删除旧容器
|
||||
docker stop new-api
|
||||
docker rm new-api
|
||||
# 使用相同参数运行新容器
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
|
||||
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
|
||||
```shell
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
```
|
||||
|
||||
## 渠道重试
|
||||
## 渠道重试与缓存
|
||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
|
||||
如果开启了重试功能,重试使用下一个优先级,以此类推。
|
||||
|
||||
### 缓存设置方法
|
||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||
### 为什么有的时候没有重试
|
||||
这些错误码不会重试:400,504,524
|
||||
### 我想让400也重试
|
||||
在`渠道->编辑`中,将`状态码复写`改为
|
||||
```json
|
||||
{
|
||||
"400": "500"
|
||||
}
|
||||
```
|
||||
可以实现400错误转为500错误,从而重试
|
||||
1. `REDIS_CONN_STRING`:设置Redis作为缓存
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置)
|
||||
|
||||
## Midjourney接口设置文档
|
||||
[对接文档](Midjourney.md)
|
||||
## 接口文档
|
||||
|
||||
## Suno接口设置文档
|
||||
[对接文档](Suno.md)
|
||||
详细接口文档请参考[接口文档](https://docs.newapi.pro/api):
|
||||
|
||||
## 界面截图
|
||||

|
||||
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
|
||||
- [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat)
|
||||
- [图像接口(Image)](https://docs.newapi.pro/api/openai-image)
|
||||
- [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
- [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime)
|
||||
- [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat)
|
||||
|
||||
## 相关项目
|
||||
- [One API](https://github.com/songquanpeng/one-api):原版项目
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案
|
||||
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
|
||||
|
||||
其他基于New API的项目:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版,专注于高并发优化,并支持Claude格式
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本,闭源免费
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
||||
|
||||
## 帮助支持
|
||||
|
||||
如有问题,请参考[帮助支持](https://docs.newapi.pro/support):
|
||||
- [社区交流](https://docs.newapi.pro/support/community-interaction)
|
||||
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [常见问题](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
//"os"
|
||||
//"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -63,8 +63,8 @@ var EmailDomainWhitelist = []string{
|
||||
"foxmail.com",
|
||||
}
|
||||
|
||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||
var DebugEnabled bool
|
||||
var MemoryCacheEnabled bool
|
||||
|
||||
var LogConsumeEnabled = true
|
||||
|
||||
@@ -77,7 +77,6 @@ var SMTPToken = ""
|
||||
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDOClientId = ""
|
||||
var LinuxDOClientSecret = ""
|
||||
|
||||
@@ -104,22 +103,22 @@ var RetryTimes = 0
|
||||
|
||||
//var RootUserEmail = ""
|
||||
|
||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var IsMasterNode bool
|
||||
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
var requestInterval int
|
||||
var RequestInterval time.Duration
|
||||
|
||||
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||
var SyncFrequency int // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
var GeminiSafetySetting string
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
var CohereSafetySetting string
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
@@ -146,13 +145,13 @@ var (
|
||||
// All duration's unit is seconds
|
||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||
var (
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
GlobalApiRateLimitEnable bool
|
||||
GlobalApiRateLimitNum int
|
||||
GlobalApiRateLimitDuration int64
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
GlobalWebRateLimitEnable bool
|
||||
GlobalWebRateLimitNum int
|
||||
GlobalWebRateLimitDuration int64
|
||||
|
||||
UploadRateLimitNum = 10
|
||||
UploadRateLimitDuration int64 = 60
|
||||
@@ -235,6 +234,8 @@ const (
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -287,4 +288,6 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer(
|
||||
"\r", "\\r")
|
||||
|
||||
var dataReplacer = strings.NewReplacer(
|
||||
"\n", "\ndata:",
|
||||
"\n", "\n",
|
||||
"\r", "\\r")
|
||||
|
||||
type CustomEvent struct {
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -66,4 +68,31 @@ func LoadEnv() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize variables from constants.go that were using environment variables
|
||||
DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
|
||||
// Parse requestInterval and set RequestInterval
|
||||
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
// Initialize variables with GetEnvOrDefault
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
|
||||
// Initialize rate limit variables
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
}
|
||||
|
||||
18
common/json.go
Normal file
18
common/json.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func DecodeJson(data []byte, v any) error {
|
||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
}
|
||||
|
||||
func DecodeJsonStr(data string, v any) error {
|
||||
return DecodeJson(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
89
common/limiter/limiter.go
Normal file
89
common/limiter/limiter.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:embed lua/rate_limit.lua
|
||||
var rateLimitScript string
|
||||
|
||||
type RedisLimiter struct {
|
||||
client *redis.Client
|
||||
limitScriptSHA string
|
||||
}
|
||||
|
||||
var (
|
||||
instance *RedisLimiter
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
|
||||
once.Do(func() {
|
||||
// 预加载脚本
|
||||
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
|
||||
}
|
||||
instance = &RedisLimiter{
|
||||
client: r,
|
||||
limitScriptSHA: limitSHA,
|
||||
}
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
|
||||
// 默认配置
|
||||
config := &Config{
|
||||
Capacity: 10,
|
||||
Rate: 1,
|
||||
Requested: 1,
|
||||
}
|
||||
|
||||
// 应用选项模式
|
||||
for _, opt := range opts {
|
||||
opt(config)
|
||||
}
|
||||
|
||||
// 执行限流
|
||||
result, err := rl.client.EvalSha(
|
||||
ctx,
|
||||
rl.limitScriptSHA,
|
||||
[]string{key},
|
||||
config.Requested,
|
||||
config.Rate,
|
||||
config.Capacity,
|
||||
).Int()
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("rate limit failed: %w", err)
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// Config 配置选项模式
|
||||
type Config struct {
|
||||
Capacity int64
|
||||
Rate int64
|
||||
Requested int64
|
||||
}
|
||||
|
||||
type Option func(*Config)
|
||||
|
||||
func WithCapacity(c int64) Option {
|
||||
return func(cfg *Config) { cfg.Capacity = c }
|
||||
}
|
||||
|
||||
func WithRate(r int64) Option {
|
||||
return func(cfg *Config) { cfg.Rate = r }
|
||||
}
|
||||
|
||||
func WithRequested(n int64) Option {
|
||||
return func(cfg *Config) { cfg.Requested = n }
|
||||
}
|
||||
44
common/limiter/lua/rate_limit.lua
Normal file
44
common/limiter/lua/rate_limit.lua
Normal file
@@ -0,0 +1,44 @@
|
||||
-- 令牌桶限流器
|
||||
-- KEYS[1]: 限流器唯一标识
|
||||
-- ARGV[1]: 请求令牌数 (通常为1)
|
||||
-- ARGV[2]: 令牌生成速率 (每秒)
|
||||
-- ARGV[3]: 桶容量
|
||||
|
||||
local key = KEYS[1]
|
||||
local requested = tonumber(ARGV[1])
|
||||
local rate = tonumber(ARGV[2])
|
||||
local capacity = tonumber(ARGV[3])
|
||||
|
||||
-- 获取当前时间(Redis服务器时间)
|
||||
local now = redis.call('TIME')
|
||||
local nowInSeconds = tonumber(now[1])
|
||||
|
||||
-- 获取桶状态
|
||||
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
|
||||
local tokens = tonumber(bucket[1])
|
||||
local last_time = tonumber(bucket[2])
|
||||
|
||||
-- 初始化桶(首次请求或过期)
|
||||
if not tokens or not last_time then
|
||||
tokens = capacity
|
||||
last_time = nowInSeconds
|
||||
else
|
||||
-- 计算新增令牌
|
||||
local elapsed = nowInSeconds - last_time
|
||||
local add_tokens = elapsed * rate
|
||||
tokens = math.min(capacity, tokens + add_tokens)
|
||||
last_time = nowInSeconds
|
||||
end
|
||||
|
||||
-- 判断是否允许请求
|
||||
local allowed = false
|
||||
if tokens >= requested then
|
||||
tokens = tokens - requested
|
||||
allowed = true
|
||||
end
|
||||
|
||||
---- 更新桶状态并设置过期时间
|
||||
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
|
||||
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
|
||||
|
||||
return allowed and 1 or 0
|
||||
@@ -4,32 +4,39 @@ import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
||||
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
|
||||
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
|
||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
|
||||
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var AzureDefaultAPIVersion string
|
||||
var GeminiVisionMaxImageNum int
|
||||
var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
@@ -43,6 +50,3 @@ func InitEnv() {
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
|
||||
3
constant/setup.go
Normal file
3
constant/setup.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package constant
|
||||
|
||||
var Setup = false
|
||||
@@ -1,11 +1,12 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -103,11 +103,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
|
||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||
logInfo := *info
|
||||
logInfo.ApiKey = ""
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -125,7 +133,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp)
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
}
|
||||
@@ -143,10 +151,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
return err, nil
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||
@@ -184,10 +189,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
return testRequest
|
||||
}
|
||||
// 并非Embedding 模型
|
||||
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
|
||||
if strings.HasPrefix(model, "o") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
testRequest.MaxTokens = 50
|
||||
if !strings.Contains(model, "claude") {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 300
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
|
||||
@@ -119,6 +119,9 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -139,7 +142,11 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
ids = append(ids, model.ID)
|
||||
id := model.ID
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
9
controller/image.go
Normal file
9
controller/image.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetImage(c *gin.Context) {
|
||||
|
||||
}
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -34,40 +36,44 @@ func GetStatus(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
240
controller/oidc.go
Normal file
240
controller/oidc.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -51,6 +52,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDOOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDOClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -81,6 +90,15 @@ func UpdateOption(c *gin.Context) {
|
||||
"success": false,
|
||||
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
case "TelegramOAuthEnabled":
|
||||
if option.Value == "true" && common.TelegramBotToken == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
@@ -92,6 +110,7 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
|
||||
@@ -148,6 +148,50 @@ func WssRelay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func RelayClaude(c *gin.Context) {
|
||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
@@ -162,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.ClaudeHelper(c)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
|
||||
173
controller/setup.go
Normal file
173
controller/setup.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting/operation_setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Setup struct {
|
||||
Status bool `json:"status"`
|
||||
RootInit bool `json:"root_init"`
|
||||
DatabaseType string `json:"database_type"`
|
||||
}
|
||||
|
||||
type SetupRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
ConfirmPassword string `json:"confirmPassword"`
|
||||
SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
|
||||
DemoSiteEnabled bool `json:"DemoSiteEnabled"`
|
||||
}
|
||||
|
||||
func GetSetup(c *gin.Context) {
|
||||
setup := Setup{
|
||||
Status: constant.Setup,
|
||||
}
|
||||
if constant.Setup {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": setup,
|
||||
})
|
||||
return
|
||||
}
|
||||
setup.RootInit = model.RootUserExists()
|
||||
if common.UsingMySQL {
|
||||
setup.DatabaseType = "mysql"
|
||||
}
|
||||
if common.UsingPostgreSQL {
|
||||
setup.DatabaseType = "postgres"
|
||||
}
|
||||
if common.UsingSQLite {
|
||||
setup.DatabaseType = "sqlite"
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": setup,
|
||||
})
|
||||
}
|
||||
|
||||
func PostSetup(c *gin.Context) {
|
||||
// Check if setup is already completed
|
||||
if constant.Setup {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "系统已经初始化完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if root user already exists
|
||||
rootExists := model.RootUserExists()
|
||||
|
||||
var req SetupRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数有误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If root doesn't exist, validate and create admin account
|
||||
if !rootExists {
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "两次输入的密码不一致",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "密码长度至少为8个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create root user
|
||||
hashedPassword, err := common.Password2Hash(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rootUser := model.User{
|
||||
Username: req.Username,
|
||||
Password: hashedPassword,
|
||||
Role: common.RoleRootUser,
|
||||
Status: common.UserStatusEnabled,
|
||||
DisplayName: "Root User",
|
||||
AccessToken: nil,
|
||||
Quota: 100000000,
|
||||
}
|
||||
err = model.DB.Create(&rootUser).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "创建管理员账号失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set operation modes
|
||||
operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
|
||||
operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
|
||||
|
||||
// Save operation modes to database for persistence
|
||||
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存自用模式设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存演示站点模式设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update setup status
|
||||
constant.Setup = true
|
||||
|
||||
setup := model.Setup{
|
||||
Version: common.Version,
|
||||
InitializedAt: time.Now().Unix(),
|
||||
}
|
||||
err = model.DB.Create(&setup).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统初始化失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "系统初始化成功",
|
||||
})
|
||||
}
|
||||
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
@@ -913,11 +913,12 @@ func TopUp(c *gin.Context) {
|
||||
}
|
||||
|
||||
type UpdateUserSettingRequest struct {
|
||||
QuotaWarningType string `json:"notify_type"`
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
QuotaWarningType string `json:"notify_type"`
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
}
|
||||
|
||||
func UpdateUserSetting(c *gin.Context) {
|
||||
@@ -993,6 +994,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
settings := map[string]interface{}{
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
|
||||
@@ -15,6 +15,7 @@ services:
|
||||
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
|
||||
|
||||
3. thinking_to_content
|
||||
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
|
||||
- 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
|
||||
- 类型为布尔值,设置为 true 时启用思考内容转换
|
||||
|
||||
--------------------------------------------------------------
|
||||
|
||||
218
dto/claude.go
Normal file
218
dto/claude.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
// tool_calls
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetText(s string) {
|
||||
c.Text = &s
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetText() string {
|
||||
if c.Text == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.Text
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
||||
var content string
|
||||
return json.Unmarshal(c.Content, &content) == nil
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetStringContent() string {
|
||||
var content string
|
||||
if err := json.Unmarshal(c.Content, &content); err == nil {
|
||||
return content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
||||
jsonContent, _ := json.Marshal(c)
|
||||
return string(jsonContent)
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetContent(content any) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
c.Content = jsonContent
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||
var mediaContent []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
|
||||
return mediaContent
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) IsStringContent() bool {
|
||||
_, ok := c.Content.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) GetStringContent() string {
|
||||
if c.IsStringContent() {
|
||||
return c.Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.Content)
|
||||
var contentList []ClaudeMediaMessage
|
||||
err := json.Unmarshal(jsonContent, &contentList)
|
||||
if err != nil {
|
||||
return make([]ClaudeMediaMessage, 0), err
|
||||
}
|
||||
return contentList, nil
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties any `json:"properties,omitempty"`
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||
_, ok := c.System.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) GetStringSystem() string {
|
||||
if c.IsStringSystem() {
|
||||
return c.System.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SetStringSystem(system string) {
|
||||
c.System = system
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.System)
|
||||
var contentList []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
|
||||
return contentList
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ClaudeMediaMessage `json:"content,omitempty"`
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
Delta *ClaudeMediaMessage `json:"delta,omitempty"`
|
||||
Message *ClaudeMediaMessage `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// set index
|
||||
func (c *ClaudeResponse) SetIndex(i int) {
|
||||
c.Index = &i
|
||||
}
|
||||
|
||||
// get index
|
||||
func (c *ClaudeResponse) GetIndex() int {
|
||||
if c.Index == nil {
|
||||
return 0
|
||||
}
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
19
dto/dalle.go
19
dto/dalle.go
@@ -1,14 +1,17 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
@@ -18,39 +18,40 @@ type FormatJsonSchema struct {
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
ExtraBody any `json:"extra_body,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
//Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
ExtraBody any `json:"extra_body,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
@@ -111,11 +112,38 @@ type MediaContent struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
InputAudio any `json:"input_audio,omitempty"`
|
||||
File any `json:"file,omitempty"`
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
|
||||
if m.ImageUrl != nil {
|
||||
return m.ImageUrl.(*MessageImageUrl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
|
||||
if m.InputAudio != nil {
|
||||
return m.InputAudio.(*MessageInputAudio)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MediaContent) GetFile() *MessageFile {
|
||||
if m.File != nil {
|
||||
return m.File.(*MessageFile)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
MimeType string
|
||||
}
|
||||
|
||||
func (m *MessageImageUrl) IsRemoteImage() bool {
|
||||
return strings.HasPrefix(m.Url, "http")
|
||||
}
|
||||
|
||||
type MessageInputAudio struct {
|
||||
@@ -123,10 +151,17 @@ type MessageInputAudio struct {
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
type MessageFile struct {
|
||||
FileName string `json:"filename,omitempty"`
|
||||
FileData string `json:"file_data,omitempty"`
|
||||
FileId string `json:"file_id,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
ContentTypeFile = "file"
|
||||
)
|
||||
|
||||
func (m *Message) GetPrefix() bool {
|
||||
@@ -180,6 +215,12 @@ func (m *Message) StringContent() string {
|
||||
return stringContent
|
||||
}
|
||||
|
||||
func (m *Message) SetNullContent() {
|
||||
m.Content = nil
|
||||
m.parsedStringContent = nil
|
||||
m.parsedContent = nil
|
||||
}
|
||||
|
||||
func (m *Message) SetStringContent(content string) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
m.Content = jsonContent
|
||||
@@ -244,44 +285,64 @@ func (m *Message) ParseContent() []MediaContent {
|
||||
|
||||
case ContentTypeImageURL:
|
||||
imageUrl := contentItem["image_url"]
|
||||
temp := &MessageImageUrl{
|
||||
Detail: "high",
|
||||
}
|
||||
switch v := imageUrl.(type) {
|
||||
case string:
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: v,
|
||||
Detail: "high",
|
||||
},
|
||||
})
|
||||
temp.Url = v
|
||||
case map[string]interface{}:
|
||||
url, ok1 := v["url"].(string)
|
||||
detail, ok2 := v["detail"].(string)
|
||||
if !ok2 {
|
||||
detail = "high"
|
||||
if ok2 {
|
||||
temp.Detail = detail
|
||||
}
|
||||
if ok1 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: url,
|
||||
Detail: detail,
|
||||
},
|
||||
})
|
||||
temp.Url = url
|
||||
}
|
||||
}
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: temp,
|
||||
})
|
||||
|
||||
case ContentTypeInputAudio:
|
||||
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
|
||||
data, ok1 := audioData["data"].(string)
|
||||
format, ok2 := audioData["format"].(string)
|
||||
if ok1 && ok2 {
|
||||
temp := &MessageInputAudio{
|
||||
Data: data,
|
||||
Format: format,
|
||||
}
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: data,
|
||||
Format: format,
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: temp,
|
||||
})
|
||||
}
|
||||
}
|
||||
case ContentTypeFile:
|
||||
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
|
||||
fileId, ok3 := fileData["file_id"].(string)
|
||||
if ok3 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeFile,
|
||||
File: &MessageFile{
|
||||
FileId: fileId,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
fileName, ok1 := fileData["filename"].(string)
|
||||
fileDataStr, ok2 := fileData["file_data"].(string)
|
||||
if ok1 && ok2 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeFile,
|
||||
File: &MessageFile{
|
||||
FileName: fileName,
|
||||
FileData: fileDataStr,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,8 @@
|
||||
package dto
|
||||
|
||||
type TextResponseWithError struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
@@ -38,6 +26,7 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
@@ -125,6 +114,20 @@ type ChatCompletionsStreamResponse struct {
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||
if len(c.Choices) == 0 {
|
||||
return false
|
||||
}
|
||||
return len(c.Choices[0].Delta.ToolCalls) > 0
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
||||
if c.IsToolCall() {
|
||||
return &c.Choices[0].Delta.ToolCalls[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||
copy(choices, c.Choices)
|
||||
@@ -170,3 +173,17 @@ type Usage struct {
|
||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int `json:"-"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
@@ -43,18 +43,6 @@ type RealtimeUsage struct {
|
||||
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type RealtimeSession struct {
|
||||
Modalities []string `json:"modalities"`
|
||||
Instructions string `json:"instructions"`
|
||||
|
||||
@@ -5,18 +5,29 @@ type RerankRequest struct {
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
ReturnDocuments bool `json:"return_documents,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type RerankResponseDocument struct {
|
||||
func (r *RerankRequest) GetReturnDocuments() bool {
|
||||
if r.ReturnDocuments == nil {
|
||||
return false
|
||||
}
|
||||
return *r.ReturnDocuments
|
||||
}
|
||||
|
||||
type RerankResponseResult struct {
|
||||
Document any `json:"document,omitempty"`
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
}
|
||||
|
||||
type RerankResponse struct {
|
||||
Results []RerankResponseDocument `json:"results"`
|
||||
Usage Usage `json:"usage"`
|
||||
type RerankDocument struct {
|
||||
Text any `json:"text"`
|
||||
}
|
||||
|
||||
type RerankResponse struct {
|
||||
Results []RerankResponseResult `json:"results"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
12
go.mod
12
go.mod
@@ -11,6 +11,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/sonic v1.11.6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -28,9 +29,9 @@ require (
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
golang.org/x/crypto v0.27.0
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/net v0.35.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
@@ -42,7 +43,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
@@ -84,9 +84,9 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/sys v0.27.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
|
||||
20
go.sum
20
go.sum
@@ -217,18 +217,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
|
||||
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
||||
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
||||
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
|
||||
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -239,14 +239,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
|
||||
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
|
||||
9
main.go
9
main.go
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -33,7 +34,7 @@ var indexPage []byte
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysLog("Support for .env file is disabled")
|
||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
@@ -51,6 +52,9 @@ func main() {
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
@@ -69,10 +73,13 @@ func main() {
|
||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitModelSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
model.InitOptionMap()
|
||||
|
||||
if common.RedisEnabled {
|
||||
// for compatibility with old versions
|
||||
common.MemoryCacheEnabled = true
|
||||
|
||||
@@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
// 检查path包含/v1/messages
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||
// 从x-api-key中获取key
|
||||
key := c.Request.Header.Get("x-api-key")
|
||||
if key != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
|
||||
@@ -212,6 +212,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("channel_type", channel.Type)
|
||||
c.Set("channel_setting", channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/limiter"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
|
||||
// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
|
||||
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
|
||||
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
|
||||
if err != nil {
|
||||
fmt.Println("检查总请求数限制失败:", err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
||||
return
|
||||
}
|
||||
if !allowed {
|
||||
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
||||
}
|
||||
|
||||
// 2. 检查成功请求数限制
|
||||
// 1. 检查成功请求数限制
|
||||
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
||||
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
||||
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
||||
if err != nil {
|
||||
fmt.Println("检查成功请求数限制失败:", err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
||||
@@ -103,8 +92,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
|
||||
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
|
||||
//2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
|
||||
totalKey := fmt.Sprintf("rateLimit:%s", userId)
|
||||
// 初始化
|
||||
tb := limiter.New(ctx, rdb)
|
||||
allowed, err = tb.Allow(
|
||||
ctx,
|
||||
totalKey,
|
||||
limiter.WithCapacity(int64(totalMaxCount)*duration),
|
||||
limiter.WithRate(int64(totalMaxCount)),
|
||||
limiter.WithRequested(duration),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("检查总请求数限制失败:", err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
||||
}
|
||||
|
||||
// 4. 处理请求
|
||||
c.Next()
|
||||
|
||||
@@ -35,7 +35,8 @@ type Channel struct {
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -493,8 +494,8 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
|
||||
func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
setting := make(map[string]interface{})
|
||||
if channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(channel.Setting), &setting)
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
@@ -508,7 +509,18 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
channel.Setting = string(settingBytes)
|
||||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||||
}
|
||||
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
}
|
||||
return paramOverride
|
||||
}
|
||||
|
||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
@@ -54,13 +56,40 @@ func createRootAccountIfNeed() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckSetup() {
|
||||
setup := GetSetup()
|
||||
if setup == nil {
|
||||
// No setup record exists, check if we have a root user
|
||||
if RootUserExists() {
|
||||
common.SysLog("system is not initialized, but root user exists")
|
||||
// Create setup record
|
||||
newSetup := Setup{
|
||||
Version: common.Version,
|
||||
InitializedAt: time.Now().Unix(),
|
||||
}
|
||||
err := DB.Create(&newSetup).Error
|
||||
if err != nil {
|
||||
common.SysLog("failed to create setup record: " + err.Error())
|
||||
}
|
||||
constant.Setup = true
|
||||
} else {
|
||||
common.SysLog("system is not initialized and no root user exists")
|
||||
constant.Setup = false
|
||||
}
|
||||
} else {
|
||||
// Setup record exists, system is initialized
|
||||
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
||||
constant.Setup = true
|
||||
}
|
||||
}
|
||||
|
||||
func chooseDB(envName string) (*gorm.DB, error) {
|
||||
defer func() {
|
||||
initCol()
|
||||
}()
|
||||
dsn := os.Getenv(envName)
|
||||
if dsn != "" {
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
@@ -213,8 +242,9 @@ func migrateDB() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Setup{})
|
||||
common.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
//err = createRootAccountIfNeed()
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
16
model/setup.go
Normal file
16
model/setup.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
type Setup struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Version string `json:"version" gorm:"type:varchar(50);not null"`
|
||||
InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
|
||||
}
|
||||
|
||||
func GetSetup() *Setup {
|
||||
var setup Setup
|
||||
err := DB.First(&setup).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &setup
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -24,6 +23,7 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -108,7 +108,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
|
||||
func GetMaxUserId() int {
|
||||
var user User
|
||||
DB.Last(&user)
|
||||
DB.Unscoped().Last(&user)
|
||||
return user.Id
|
||||
}
|
||||
|
||||
@@ -442,6 +442,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
}
|
||||
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -473,6 +481,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -796,3 +808,12 @@ func (user *User) FillUserByLinuxDOId() error {
|
||||
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func RootUserExists() bool {
|
||||
var user User
|
||||
err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ type Adaptor interface {
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
@@ -22,6 +22,7 @@ type Adaptor interface {
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -87,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -26,8 +26,8 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
return &imageRequest
|
||||
}
|
||||
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
|
||||
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
|
||||
|
||||
var aliResponse AliResponse
|
||||
|
||||
@@ -36,7 +36,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
@@ -58,7 +58,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
|
||||
return &response, nil, responseBody
|
||||
}
|
||||
|
||||
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
|
||||
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||
waitSeconds := 3
|
||||
step := 0
|
||||
maxStep := 20
|
||||
@@ -68,7 +68,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*Ali
|
||||
|
||||
for {
|
||||
step++
|
||||
rsp, err, body := updateTask(info, taskID, key)
|
||||
rsp, err, body := updateTask(info, taskID)
|
||||
responseBody = body
|
||||
if err != nil {
|
||||
return &taskResponse, responseBody, err
|
||||
@@ -125,8 +125,6 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
@@ -148,7 +146,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
common2 "one-api/common"
|
||||
"one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
if common2.DebugEnabled {
|
||||
println("fullRequestURL:", fullRequestURL)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
|
||||
@@ -20,6 +20,12 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -43,12 +49,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
var claudeReq *claude.ClaudeRequest
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"ap": true,
|
||||
},
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"ap": true,
|
||||
},
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"ap": true,
|
||||
},
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
},
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
"eu": "eu",
|
||||
"ap": "apac",
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type AwsClaudeRequest struct {
|
||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *claude.Thinking `json:"thinking,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []dto.ClaudeMessage `json:"messages"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
|
||||
func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
return &AwsClaudeRequest{
|
||||
AnthropicVersion: "bedrock-2023-05-31",
|
||||
System: req.System,
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
@@ -39,15 +34,37 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
|
||||
return &relaymodel.OpenAIErrorWithStatusCode{
|
||||
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: relaymodel.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func awsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
if len(parts) > 0 {
|
||||
regionPrefix = parts[0]
|
||||
}
|
||||
return regionPrefix
|
||||
}
|
||||
|
||||
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
|
||||
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
|
||||
return exists && regionSet[awsRegionPrefix]
|
||||
}
|
||||
|
||||
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
|
||||
if !find {
|
||||
return awsModelId
|
||||
}
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID, nil
|
||||
@@ -56,7 +73,7 @@ func awsModelID(requestModel string) (string, error) {
|
||||
return requestModel, nil
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@@ -67,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
@@ -77,7 +100,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
@@ -89,25 +112,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
}
|
||||
|
||||
claudeResponse := new(claude.ClaudeResponse)
|
||||
err = json.Unmarshal(awsResp.Body, claudeResponse)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
|
||||
usage := relaymodel.Usage{
|
||||
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||
}
|
||||
openaiResp.Usage = usage
|
||||
|
||||
c.JSON(http.StatusOK, openaiResp)
|
||||
return nil, &usage
|
||||
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@@ -118,6 +135,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
@@ -128,7 +151,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
@@ -143,79 +166,31 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
var usage relaymodel.Usage
|
||||
var id string
|
||||
var model string
|
||||
isFirst := true
|
||||
createdTime := common.GetTimestamp()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
for event := range stream.Events() {
|
||||
switch v := event.(type) {
|
||||
case *types.ResponseStreamMemberChunk:
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
claudeResp := new(claude.ClaudeResponse)
|
||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
|
||||
if claudeUsage != nil {
|
||||
usage.PromptTokens += claudeUsage.InputTokens
|
||||
usage.CompletionTokens += claudeUsage.OutputTokens
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if response.Id != "" {
|
||||
id = response.Id
|
||||
}
|
||||
if response.Model != "" {
|
||||
model = response.Model
|
||||
}
|
||||
response.Created = createdTime
|
||||
response.Id = id
|
||||
response.Model = model
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case *types.UnknownUnionMember:
|
||||
fmt.Println("unknown tag:", v.Tag)
|
||||
return false
|
||||
return wrapErr(errors.New("unknown response type")), nil
|
||||
default:
|
||||
fmt.Println("union is nil or unknown type")
|
||||
return false
|
||||
}
|
||||
})
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
return wrapErr(errors.New("nil or unknown response type")), nil
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
if resp != nil {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
}
|
||||
return nil, &usage
|
||||
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -104,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -62,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -22,6 +22,10 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -60,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -1,94 +1,95 @@
|
||||
package claude
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
// tool_calls
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties any `json:"properties,omitempty"`
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Content []ClaudeMediaMessage `json:"content"`
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
Index int `json:"index"` // stream only
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block"`
|
||||
Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
//
|
||||
//type ClaudeMetadata struct {
|
||||
// UserId string `json:"user_id"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMediaMessage struct {
|
||||
// Type string `json:"type"`
|
||||
// Text string `json:"text,omitempty"`
|
||||
// Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
// Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
// StopReason *string `json:"stop_reason,omitempty"`
|
||||
// PartialJson string `json:"partial_json,omitempty"`
|
||||
// Thinking string `json:"thinking,omitempty"`
|
||||
// Signature string `json:"signature,omitempty"`
|
||||
// Delta string `json:"delta,omitempty"`
|
||||
// // tool_calls
|
||||
// Id string `json:"id,omitempty"`
|
||||
// Name string `json:"name,omitempty"`
|
||||
// Input any `json:"input,omitempty"`
|
||||
// Content string `json:"content,omitempty"`
|
||||
// ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMessageSource struct {
|
||||
// Type string `json:"type"`
|
||||
// MediaType string `json:"media_type"`
|
||||
// Data string `json:"data"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMessage struct {
|
||||
// Role string `json:"role"`
|
||||
// Content any `json:"content"`
|
||||
//}
|
||||
//
|
||||
//type Tool struct {
|
||||
// Name string `json:"name"`
|
||||
// Description string `json:"description,omitempty"`
|
||||
// InputSchema map[string]interface{} `json:"input_schema"`
|
||||
//}
|
||||
//
|
||||
//type InputSchema struct {
|
||||
// Type string `json:"type"`
|
||||
// Properties any `json:"properties,omitempty"`
|
||||
// Required any `json:"required,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeRequest struct {
|
||||
// Model string `json:"model"`
|
||||
// Prompt string `json:"prompt,omitempty"`
|
||||
// System string `json:"system,omitempty"`
|
||||
// Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
// MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
// StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
// Temperature *float64 `json:"temperature,omitempty"`
|
||||
// TopP float64 `json:"top_p,omitempty"`
|
||||
// TopK int `json:"top_k,omitempty"`
|
||||
// //ClaudeMetadata `json:"metadata,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
// Tools any `json:"tools,omitempty"`
|
||||
// ToolChoice any `json:"tool_choice,omitempty"`
|
||||
// Thinking *Thinking `json:"thinking,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type Thinking struct {
|
||||
// Type string `json:"type"`
|
||||
// BudgetTokens int `json:"budget_tokens"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeError struct {
|
||||
// Type string `json:"type"`
|
||||
// Message string `json:"message"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeResponse struct {
|
||||
// Id string `json:"id"`
|
||||
// Type string `json:"type"`
|
||||
// Content []ClaudeMediaMessage `json:"content"`
|
||||
// Completion string `json:"completion"`
|
||||
// StopReason string `json:"stop_reason"`
|
||||
// Model string `json:"model"`
|
||||
// Error ClaudeError `json:"error"`
|
||||
// Usage ClaudeUsage `json:"usage"`
|
||||
// Index int `json:"index"` // stream only
|
||||
// ContentBlock *ClaudeMediaMessage `json:"content_block"`
|
||||
// Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
// Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
//}
|
||||
//
|
||||
//type ClaudeUsage struct {
|
||||
// InputTokens int `json:"input_tokens"`
|
||||
// OutputTokens int `json:"output_tokens"`
|
||||
//}
|
||||
|
||||
@@ -24,14 +24,16 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "max_tokens"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
|
||||
|
||||
claudeRequest := ClaudeRequest{
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
StopSequences: nil,
|
||||
@@ -60,17 +62,19 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
claudeTool := Tool{
|
||||
claudeTool := dto.Tool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
}
|
||||
claudeTool.InputSchema = make(map[string]interface{})
|
||||
claudeTool.InputSchema["type"] = params["type"].(string)
|
||||
if params["type"] != nil {
|
||||
claudeTool.InputSchema["type"] = params["type"].(string)
|
||||
}
|
||||
claudeTool.InputSchema["properties"] = params["properties"]
|
||||
claudeTool.InputSchema["required"] = params["required"]
|
||||
for s, a := range params {
|
||||
@@ -83,7 +87,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
}
|
||||
|
||||
claudeRequest := ClaudeRequest{
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
@@ -107,7 +111,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
claudeRequest.Thinking = &Thinking{
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
}
|
||||
@@ -165,7 +169,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
lastMessage = fmtMessage
|
||||
}
|
||||
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
@@ -186,63 +190,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
isFirstMessage = false
|
||||
if message.Role != "user" {
|
||||
// fix: first message is assistant, add user message
|
||||
claudeMessage := ClaudeMessage{
|
||||
claudeMessage := dto.ClaudeMessage{
|
||||
Role: "user",
|
||||
Content: []ClaudeMediaMessage{
|
||||
Content: []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "...",
|
||||
Text: common.GetPointer[string]("..."),
|
||||
},
|
||||
},
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
claudeMessage := ClaudeMessage{
|
||||
claudeMessage := dto.ClaudeMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.Role == "tool" {
|
||||
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
|
||||
lastMessage := claudeMessages[len(claudeMessages)-1]
|
||||
if content, ok := lastMessage.Content.(string); ok {
|
||||
lastMessage.Content = []ClaudeMediaMessage{
|
||||
lastMessage.Content = []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: content,
|
||||
Text: common.GetPointer[string](content),
|
||||
},
|
||||
}
|
||||
}
|
||||
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
|
||||
lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
Content: message.Content,
|
||||
})
|
||||
claudeMessages[len(claudeMessages)-1] = lastMessage
|
||||
continue
|
||||
} else {
|
||||
claudeMessage.Role = "user"
|
||||
claudeMessage.Content = []ClaudeMediaMessage{
|
||||
claudeMessage.Content = []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
Content: message.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if message.IsStringContent() && message.ToolCalls == nil {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
|
||||
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := ClaudeMediaMessage{
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = mediaMessage.Text
|
||||
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
|
||||
} else {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &ClaudeMessageSource{
|
||||
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
// 判断是否是url
|
||||
@@ -272,7 +276,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
|
||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||
Type: "tool_use",
|
||||
Id: toolCall.ID,
|
||||
Name: toolCall.Function.Name,
|
||||
@@ -290,13 +294,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
var claudeUsage *ClaudeUsage
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
fcIdx := 0
|
||||
if claudeResponse.Index != nil {
|
||||
fcIdx = *claudeResponse.Index - 1
|
||||
if fcIdx < 0 {
|
||||
fcIdx = 0
|
||||
}
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.SetContentString(claudeResponse.Completion)
|
||||
@@ -308,7 +318,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
if claudeResponse.Type == "message_start" {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
claudeUsage = &claudeResponse.Message.Usage
|
||||
//claudeUsage = &claudeResponse.Message.Usage
|
||||
choice.Delta.SetContentString("")
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
@@ -316,8 +326,9 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
ID: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Index: common.GetPointer(fcIdx),
|
||||
ID: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Name: claudeResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
@@ -325,17 +336,18 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
choice.Index = claudeResponse.Index
|
||||
choice.Delta.SetContentString(claudeResponse.Delta.Text)
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
switch claudeResponse.Delta.Type {
|
||||
case "input_json_delta":
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Type: "function",
|
||||
Index: common.GetPointer(fcIdx),
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: claudeResponse.Delta.PartialJson,
|
||||
Arguments: *claudeResponse.Delta.PartialJson,
|
||||
},
|
||||
})
|
||||
case "signature_delta":
|
||||
@@ -352,26 +364,23 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
claudeUsage = &claudeResponse.Usage
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
return nil, nil
|
||||
return nil
|
||||
} else {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if claudeUsage == nil {
|
||||
claudeUsage = &ClaudeUsage{}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
||||
choice.Delta.ToolCalls = tools
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
|
||||
return &response, claudeUsage
|
||||
return &response
|
||||
}
|
||||
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -379,8 +388,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
Created: common.GetTimestamp(),
|
||||
}
|
||||
var responseText string
|
||||
var responseThinking string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].Text
|
||||
responseText = claudeResponse.Content[0].GetText()
|
||||
responseThinking = claudeResponse.Content[0].Thinking
|
||||
}
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
thinkingContent := ""
|
||||
@@ -415,7 +426,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
thinkingContent = message.Thinking
|
||||
case "text":
|
||||
responseText = message.Text
|
||||
responseText = message.GetText()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -427,6 +438,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choice.SetStringContent(responseText)
|
||||
if len(responseThinking) > 0 {
|
||||
choice.ReasoningContent = responseThinking
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
choice.Message.SetToolCalls(tools)
|
||||
}
|
||||
@@ -437,126 +451,228 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage *dto.Usage
|
||||
usage = &dto.Usage{}
|
||||
responseText := ""
|
||||
createdTime := common.GetTimestamp()
|
||||
type ClaudeResponseInfo struct {
|
||||
ResponseId string
|
||||
Created int64
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if oaiResponse != nil {
|
||||
oaiResponse.Id = claudeInfo.ResponseId
|
||||
oaiResponse.Created = claudeInfo.Created
|
||||
oaiResponse.Model = claudeInfo.Model
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
if response == nil {
|
||||
return true
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Code: "stream_response_error",
|
||||
Type: claudeResponse.Error.Type,
|
||||
Message: claudeResponse.Error.Message,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
responseText += claudeResponse.Completion
|
||||
responseId = response.Id
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
responseId = claudeResponse.Message.Id
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
usage.PromptTokens = claudeUsage.InputTokens
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
responseText += claudeResponse.Delta.Text
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
usage.CompletionTokens = claudeUsage.OutputTokens
|
||||
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
return true
|
||||
} else {
|
||||
return true
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
//response.Id = responseId
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
response.Model = info.UpstreamModelName
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
|
||||
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
//resp.Body.Close()
|
||||
return nil, usage
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return err, nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
|
||||
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJson(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
||||
}
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
} else {
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
var responseData []byte
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = data
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
|
||||
if handleErr != nil {
|
||||
return handleErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
@@ -17,6 +17,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
@@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cohere
|
||||
|
||||
var ModelList = []string{
|
||||
"command-a-03-2025",
|
||||
"command-r", "command-r-plus",
|
||||
"command-r-08-2024", "command-r-plus-08-2024",
|
||||
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
|
||||
|
||||
@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
|
||||
}
|
||||
|
||||
type CohereRerankResponseResult struct {
|
||||
Results []dto.RerankResponseDocument `json:"results"`
|
||||
Meta CohereMeta `json:"meta"`
|
||||
Results []dto.RerankResponseResult `json:"results"`
|
||||
Meta CohereMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type CohereMeta struct {
|
||||
|
||||
@@ -11,11 +11,18 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -30,9 +37,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fimBaseUrl := info.BaseUrl
|
||||
if !strings.HasSuffix(info.BaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
@@ -44,7 +55,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -68,7 +79,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,6 +22,12 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -34,15 +39,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "agent") {
|
||||
a.BotType = BotTypeAgent
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
|
||||
a.BotType = BotTypeWorkFlow
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
|
||||
a.BotType = BotTypeCompletion
|
||||
} else {
|
||||
a.BotType = BotTypeChatFlow
|
||||
}
|
||||
//if strings.HasPrefix(info.UpstreamModelName, "agent") {
|
||||
// a.BotType = BotTypeAgent
|
||||
//} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
|
||||
// a.BotType = BotTypeWorkFlow
|
||||
//} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
|
||||
// a.BotType = BotTypeCompletion
|
||||
//} else {
|
||||
//}
|
||||
a.BotType = BotTypeChatFlow
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
@@ -64,11 +70,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Dify(*request), nil
|
||||
return requestOpenAI2Dify(c, info, *request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
|
||||
@@ -8,6 +8,14 @@ type DifyChatRequest struct {
|
||||
ResponseMode string `json:"response_mode"`
|
||||
User string `json:"user"`
|
||||
AutoGenerateName bool `json:"auto_generate_name"`
|
||||
Files []DifyFile `json:"files"`
|
||||
}
|
||||
|
||||
type DifyFile struct {
|
||||
Type string `json:"type"`
|
||||
TransferMode string `json:"transfer_mode"`
|
||||
URL string `json:"url,omitempty"`
|
||||
UploadFileId string `json:"upload_file_id,omitempty"`
|
||||
}
|
||||
|
||||
type DifyMetaData struct {
|
||||
@@ -17,6 +25,8 @@ type DifyMetaData struct {
|
||||
type DifyData struct {
|
||||
WorkflowId string `json:"workflow_id"`
|
||||
NodeId string `json:"node_id"`
|
||||
NodeType string `json:"node_type"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type DifyChatCompletionResponse struct {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -12,35 +14,163 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
|
||||
content := ""
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
content += "SYSTEM: \n" + message.StringContent() + "\n"
|
||||
} else if message.Role == "assistant" {
|
||||
content += "ASSISTANT: \n" + message.StringContent() + "\n"
|
||||
} else {
|
||||
content += "USER: \n" + message.StringContent() + "\n"
|
||||
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
|
||||
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
|
||||
switch media.Type {
|
||||
case dto.ContentTypeImageURL:
|
||||
// Decode base64 data
|
||||
imageMedia := media.GetImageMedia()
|
||||
base64Data := imageMedia.Url
|
||||
// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
|
||||
if idx := strings.Index(base64Data, ","); idx != -1 {
|
||||
base64Data = base64Data[idx+1:]
|
||||
}
|
||||
|
||||
// Decode base64 string
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
common.SysError("failed to decode base64: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create temporary file
|
||||
tempFile, err := os.CreateTemp("", "dify-upload-*")
|
||||
if err != nil {
|
||||
common.SysError("failed to create temp file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer tempFile.Close()
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
// Write decoded data to temp file
|
||||
if _, err := tempFile.Write(decodedData); err != nil {
|
||||
common.SysError("failed to write to temp file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create multipart form
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// Add user field
|
||||
if err := writer.WriteField("user", user); err != nil {
|
||||
common.SysError("failed to add user field: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create form file with proper mime type
|
||||
mimeType := imageMedia.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/jpeg" // default mime type
|
||||
}
|
||||
|
||||
// Create form file
|
||||
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
|
||||
if err != nil {
|
||||
common.SysError("failed to create form file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy file content to form
|
||||
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
|
||||
common.SysError("failed to copy file content: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("POST", uploadUrl, body)
|
||||
if err != nil {
|
||||
common.SysError("failed to create request: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
|
||||
// Send request
|
||||
client := service.GetImpatientHttpClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysError("failed to send request: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Parse response
|
||||
var result struct {
|
||||
Id string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
common.SysError("failed to decode response: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DifyFile{
|
||||
UploadFileId: result.Id,
|
||||
Type: "image",
|
||||
TransferMode: "local_file",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
|
||||
difyReq := DifyChatRequest{
|
||||
Inputs: make(map[string]interface{}),
|
||||
AutoGenerateName: false,
|
||||
}
|
||||
|
||||
user := request.User
|
||||
if user == "" {
|
||||
user = helper.GetResponseID(c)
|
||||
}
|
||||
difyReq.User = user
|
||||
|
||||
files := make([]DifyFile, 0)
|
||||
var content strings.Builder
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
|
||||
} else if message.Role == "assistant" {
|
||||
content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
|
||||
} else {
|
||||
parseContent := message.ParseContent()
|
||||
for _, mediaContent := range parseContent {
|
||||
switch mediaContent.Type {
|
||||
case dto.ContentTypeText:
|
||||
content.WriteString("USER: \n" + mediaContent.Text + "\n")
|
||||
case dto.ContentTypeImageURL:
|
||||
media := mediaContent.GetImageMedia()
|
||||
var file *DifyFile
|
||||
if media.IsRemoteImage() {
|
||||
file.Type = media.MimeType
|
||||
file.TransferMode = "remote_url"
|
||||
file.URL = media.Url
|
||||
} else {
|
||||
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
|
||||
}
|
||||
if file != nil {
|
||||
files = append(files, *file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
difyReq.Query = content.String()
|
||||
difyReq.Files = files
|
||||
mode := "blocking"
|
||||
if request.Stream {
|
||||
mode = "streaming"
|
||||
}
|
||||
user := request.User
|
||||
if user == "" {
|
||||
user = "api-user"
|
||||
}
|
||||
return &DifyChatRequest{
|
||||
Inputs: make(map[string]interface{}),
|
||||
Query: content,
|
||||
ResponseMode: mode,
|
||||
User: user,
|
||||
AutoGenerateName: false,
|
||||
}
|
||||
difyReq.ResponseMode = mode
|
||||
return &difyReq
|
||||
}
|
||||
|
||||
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
|
||||
@@ -50,11 +180,29 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
||||
Model: "dify",
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
|
||||
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
||||
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
||||
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
||||
if strings.HasPrefix(difyResponse.Event, "workflow_") {
|
||||
if constant.DifyDebug {
|
||||
text := "Workflow: " + difyResponse.Data.WorkflowId
|
||||
if difyResponse.Event == "workflow_finished" {
|
||||
text += " " + difyResponse.Data.Status
|
||||
}
|
||||
choice.Delta.SetReasoningContent(text + "\n")
|
||||
}
|
||||
} else if strings.HasPrefix(difyResponse.Event, "node_") {
|
||||
if constant.DifyDebug {
|
||||
text := "Node: " + difyResponse.Data.NodeType
|
||||
if difyResponse.Event == "node_finished" {
|
||||
text += " " + difyResponse.Data.Status
|
||||
}
|
||||
choice.Delta.SetReasoningContent(text + "\n")
|
||||
}
|
||||
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
|
||||
if difyResponse.Answer == "<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>\n" {
|
||||
difyResponse.Answer = "<think>"
|
||||
} else if difyResponse.Answer == "</details>" {
|
||||
difyResponse.Answer = "</think>"
|
||||
}
|
||||
|
||||
choice.Delta.SetContentString(difyResponse.Answer)
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
@@ -64,47 +212,40 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var responseText string
|
||||
usage := &dto.Usage{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
var nodeToken int
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
return true
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
if difyResponse.Event == "message_end" {
|
||||
usage = &difyResponse.MetaData.Usage
|
||||
break
|
||||
return false
|
||||
} else if difyResponse.Event == "error" {
|
||||
break
|
||||
return false
|
||||
} else {
|
||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
@@ -112,6 +253,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -21,6 +20,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -64,12 +69,28 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
|
||||
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
@@ -83,15 +104,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
ai, err := CovertGemini2OpenAI(*request)
|
||||
|
||||
geminiRequest, err := CovertGemini2OpenAI(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ai, nil
|
||||
|
||||
return geminiRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -99,8 +122,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
if request.Input == nil {
|
||||
return nil, errors.New("input is required")
|
||||
}
|
||||
|
||||
inputs := request.ParseInput()
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("input is empty")
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004":
|
||||
// except embedding-001 supports setting `OutputDimensionality`
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest.OutputDimensionality = request.Dimensions
|
||||
}
|
||||
}
|
||||
|
||||
return geminiRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
@@ -112,11 +164,30 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return GeminiImageHandler(c, resp, info)
|
||||
}
|
||||
|
||||
// check if the model is an embedding model
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return GeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = GeminiChatHandler(c, resp, info)
|
||||
}
|
||||
|
||||
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
||||
// // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
|
||||
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
|
||||
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// thinkingModelName := info.OriginModelName + "-thinking"
|
||||
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
|
||||
// info.OriginModelName = thinkingModelName
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -16,8 +16,14 @@ var ModelList = []string{
|
||||
"gemini-2.0-pro-exp",
|
||||
// thinking exp
|
||||
"gemini-2.0-flash-thinking-exp",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
// imagen models
|
||||
"imagen-3.0-generate-002",
|
||||
// embedding models
|
||||
"gemini-embedding-exp-03-07",
|
||||
"text-embedding-004",
|
||||
"embedding-001",
|
||||
}
|
||||
|
||||
var SafetySettingList = []string{
|
||||
|
||||
@@ -8,6 +8,15 @@ type GeminiChatRequest struct {
|
||||
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||
c.ThinkingBudget = &budget
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
@@ -71,15 +80,17 @@ type GeminiChatTool struct {
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
@@ -108,6 +119,7 @@ type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
}
|
||||
|
||||
// Imagen related structs
|
||||
@@ -136,3 +148,19 @@ type GeminiImagePrediction struct {
|
||||
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
|
||||
SafetyAttributes any `json:"safetyAttributes,omitempty"`
|
||||
}
|
||||
|
||||
// Embedding related structs
|
||||
type GeminiEmbeddingRequest struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
TaskType string `json:"taskType,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
type ContentEmbedding struct {
|
||||
Values []float64 `json:"values"`
|
||||
}
|
||||
|
||||
@@ -19,11 +19,10 @@ import (
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
//SafetySettings: []GeminiChatSafetySettings{},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
@@ -32,6 +31,30 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
}
|
||||
|
||||
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
|
||||
geminiRequest.GenerationConfig.ResponseModalities = []string{
|
||||
"TEXT",
|
||||
"IMAGE",
|
||||
}
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
}
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
@@ -56,6 +79,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
if ok {
|
||||
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
|
||||
@@ -65,6 +89,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clean the parameters before appending
|
||||
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
|
||||
tool.Function.Parameters = cleanedParams
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
@@ -86,11 +113,11 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
} else if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTool{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
//geminiRequest.Tools = []GeminiChatTool{
|
||||
// {
|
||||
// FunctionDeclarations: textRequest.Functions,
|
||||
// },
|
||||
//}
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
@@ -180,9 +207,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
@@ -193,7 +220,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
})
|
||||
} else {
|
||||
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
@@ -204,6 +231,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeFile {
|
||||
if part.GetFile().FileId != "" {
|
||||
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
||||
}
|
||||
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeInputAudio {
|
||||
if part.GetInputAudio().Data == "" {
|
||||
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
||||
}
|
||||
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +284,102 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
|
||||
func cleanFunctionParameters(params interface{}) interface{} {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
paramMap, ok := params.(map[string]interface{})
|
||||
if !ok {
|
||||
// Not a map, return as is (e.g., could be an array or primitive)
|
||||
return params
|
||||
}
|
||||
|
||||
// Create a copy to avoid modifying the original
|
||||
cleanedMap := make(map[string]interface{})
|
||||
for k, v := range paramMap {
|
||||
cleanedMap[k] = v
|
||||
}
|
||||
|
||||
// Remove unsupported root-level fields
|
||||
delete(cleanedMap, "default")
|
||||
delete(cleanedMap, "exclusiveMaximum")
|
||||
delete(cleanedMap, "exclusiveMinimum")
|
||||
delete(cleanedMap, "$schema")
|
||||
delete(cleanedMap, "additionalProperties")
|
||||
|
||||
// Clean properties
|
||||
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
|
||||
cleanedProps := make(map[string]interface{})
|
||||
for propName, propValue := range props {
|
||||
propMap, ok := propValue.(map[string]interface{})
|
||||
if !ok {
|
||||
cleanedProps[propName] = propValue // Keep non-map properties
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a copy of the property map
|
||||
cleanedPropMap := make(map[string]interface{})
|
||||
for k, v := range propMap {
|
||||
cleanedPropMap[k] = v
|
||||
}
|
||||
|
||||
// Remove unsupported fields
|
||||
delete(cleanedPropMap, "default")
|
||||
delete(cleanedPropMap, "exclusiveMaximum")
|
||||
delete(cleanedPropMap, "exclusiveMinimum")
|
||||
delete(cleanedPropMap, "$schema")
|
||||
delete(cleanedPropMap, "additionalProperties")
|
||||
|
||||
// Check and clean 'format' for string types
|
||||
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
|
||||
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
|
||||
if formatValue != "enum" && formatValue != "date-time" {
|
||||
delete(cleanedPropMap, "format")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively clean nested properties within this property if it's an object/array
|
||||
// Check the type before recursing
|
||||
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
|
||||
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
|
||||
} else {
|
||||
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
|
||||
}
|
||||
|
||||
}
|
||||
cleanedMap["properties"] = cleanedProps
|
||||
}
|
||||
|
||||
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
|
||||
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
|
||||
cleanedMap["items"] = cleanFunctionParameters(items)
|
||||
}
|
||||
// Also handle items if it's an array of schemas
|
||||
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
|
||||
cleanedItemsArray := make([]interface{}, len(itemsArray))
|
||||
for i, item := range itemsArray {
|
||||
cleanedItemsArray[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap["items"] = cleanedItemsArray
|
||||
}
|
||||
|
||||
// Recursively clean other schema composition keywords if necessary
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := cleanedMap[field].([]interface{}); ok {
|
||||
cleanedNested := make([]interface{}, len(nested))
|
||||
for i, item := range nested {
|
||||
cleanedNested[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap[field] = cleanedNested
|
||||
}
|
||||
}
|
||||
|
||||
return cleanedMap
|
||||
}
|
||||
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
if depth >= 5 {
|
||||
return schema
|
||||
@@ -427,9 +578,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
hasImage := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
isStop = true
|
||||
@@ -455,7 +607,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
}
|
||||
}
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
if part.InlineData != nil {
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
hasImage = true
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
@@ -483,7 +641,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Choices = choices
|
||||
return &response, isStop
|
||||
return &response, isStop, hasImage
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -491,23 +649,27 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := json.Unmarshal([]byte(data), &geminiResponse)
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if hasImage {
|
||||
imageCount++
|
||||
}
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
// responseText += response.Choices[0].Delta.GetContentString()
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
}
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
@@ -522,9 +684,15 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
||||
//usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
@@ -570,6 +738,9 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
@@ -580,3 +751,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []dto.OpenAIEmbeddingResponseItem{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: geminiResponse.Embedding.Values,
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
Model: info.UpstreamModelName,
|
||||
}
|
||||
|
||||
// calculate usage
|
||||
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
|
||||
// Google has not yet clarified how embedding models will be billed
|
||||
// refer to openai billing method to use input tokens billing
|
||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -8,13 +8,21 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/common_handler"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -43,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -61,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = JinaRerankHandler(c, resp)
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
err, usage = jinaEmbeddingHandler(c, resp)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,60 +1 @@
|
||||
package jina
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var jinaResp dto.RerankResponse
|
||||
err = json.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
||||
|
||||
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var jinaResp dto.OpenAIEmbeddingResponse
|
||||
err = json.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
||||
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -61,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
mediaMessage.ImageUrl = imageUrl.Url
|
||||
mediaMessages[j] = mediaMessage
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -51,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -43,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -69,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -19,7 +19,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@@ -14,13 +13,18 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/openrouter"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/common_handler"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -28,11 +32,39 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if !strings.Contains(request.Model, "claude") {
|
||||
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
}
|
||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.SupportStreamOptions {
|
||||
aiRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, aiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
|
||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
HasSentThinkingContent: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
@@ -101,28 +133,26 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
} else {
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
// req.Header.Set("X-Title", "One API")
|
||||
//}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
header.Set("X-Title", "New API")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") {
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
|
||||
request.Temperature = nil
|
||||
}
|
||||
request.Temperature = nil
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
@@ -135,11 +165,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
|
||||
// o系列模型developer适配(o1-mini除外)
|
||||
if !strings.HasPrefix(request.Model, "o1-mini") {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = jina.JinaRerankHandler(c, resp)
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = OpenaiHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -251,6 +283,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
return lingyiwanwu.ModelList
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.ModelList
|
||||
case common.ChannelTypeXinference:
|
||||
return xinference.ModelList
|
||||
case common.ChannelTypeOpenRouter:
|
||||
return openrouter.ModelList
|
||||
default:
|
||||
return ModelList
|
||||
}
|
||||
@@ -266,6 +302,10 @@ func (a *Adaptor) GetChannelName() string {
|
||||
return lingyiwanwu.ChannelName
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.ChannelName
|
||||
case common.ChannelTypeXinference:
|
||||
return xinference.ChannelName
|
||||
case common.ChannelTypeOpenRouter:
|
||||
return openrouter.ChannelName
|
||||
default:
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
189
relay/channel/openai/helper.go
Normal file
189
relay/channel/openai/helper.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if streamResponse.Usage != nil {
|
||||
info.ClaudeConvertInfo.Usage = streamResponse.Usage
|
||||
}
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysError("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
systemFingerprint *string, model *string, usage **dto.Usage,
|
||||
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
||||
shouldSendLastResp *bool) error {
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*responseId = lastStreamResponse.Id
|
||||
*createAt = lastStreamResponse.Created
|
||||
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||
*model = lastStreamResponse.Model
|
||||
|
||||
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||
*containStreamUsage = true
|
||||
*usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
*shouldSendLastResp = false
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
case relaycommon.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
info.ClaudeConvertInfo.Usage = usage
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
@@ -34,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
|
||||
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -66,6 +65,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
response.Choices[i].Delta.Reasoning = nil
|
||||
}
|
||||
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
||||
info.ThinkingContentInfo.HasSentThinkingContent = true
|
||||
return helper.ObjectData(c, response)
|
||||
}
|
||||
}
|
||||
@@ -77,7 +77,8 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
// Process each choice
|
||||
for i, choice := range lastStreamResponse.Choices {
|
||||
// Handle transition from thinking to content
|
||||
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent {
|
||||
// only send `</think>` tag when previous thinking content has been sent
|
||||
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
|
||||
response := lastStreamResponse.Copy()
|
||||
for j := range response.Choices {
|
||||
response.Choices[j].Delta.SetContentString("\n</think>\n")
|
||||
@@ -88,7 +89,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
|
||||
// Convert reasoning content to regular content
|
||||
// Convert reasoning content to regular content if any
|
||||
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
|
||||
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
||||
@@ -116,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
model := info.UpstreamModelName
|
||||
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var forceFormat bool
|
||||
@@ -129,17 +131,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
thinkToContent = think2Content
|
||||
}
|
||||
|
||||
toolCount := 0
|
||||
|
||||
var (
|
||||
lastStreamData string
|
||||
)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.LogError(c, "streaming error: "+err.Error())
|
||||
common.SysError("error handling stream format: " + err.Error())
|
||||
}
|
||||
}
|
||||
lastStreamData = data
|
||||
@@ -149,7 +149,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
shouldSendLastResp := true
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
||||
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
|
||||
if err == nil {
|
||||
responseId = lastStreamResponse.Id
|
||||
createAt = lastStreamResponse.Created
|
||||
@@ -168,87 +168,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldSendLastResp {
|
||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
}
|
||||
|
||||
// 计算token
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
//if service.ValidUsage(streamResponse.Usage) {
|
||||
// usage = streamResponse.Usage
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
|
||||
// handle both reasoning_content and reasoning
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
//if service.ValidUsage(streamResponse.Usage) {
|
||||
// usage = streamResponse.Usage
|
||||
// containStreamUsage = true
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
common.SysError("error processing tokens: " + err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
@@ -262,20 +190,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
helper.Done(c)
|
||||
|
||||
//resp.Body.Close()
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -284,16 +205,29 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &simpleResponse)
|
||||
err = common.DecodeJson(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if simpleResponse.Error.Type != "" {
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: simpleResponse.Error,
|
||||
Error: *simpleResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
break
|
||||
case relaycommon.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||
claudeRespStr, err := json.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
}
|
||||
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
@@ -306,19 +240,20 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
resp.Body.Close()
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, &simpleResponse.Usage
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package openrouter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
req.Set("X-Title", "New API")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
@@ -66,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -48,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -68,20 +74,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = siliconflowRerankHandler(c, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
case constant.RelayModeCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,6 +12,6 @@ type SFMeta struct {
|
||||
}
|
||||
|
||||
type SFRerankResponse struct {
|
||||
Results []dto.RerankResponseDocument `json:"results"`
|
||||
Meta SFMeta `json:"meta"`
|
||||
Results []dto.RerankResponseResult `json:"results"`
|
||||
Meta SFMeta `json:"meta"`
|
||||
}
|
||||
|
||||
@@ -23,6 +23,12 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -52,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -38,6 +38,16 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
} else {
|
||||
c.Set("request_model", request.Model)
|
||||
}
|
||||
vertexClaudeReq := copyRequest(request, anthropicVersion)
|
||||
return vertexClaudeReq, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -119,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -133,7 +143,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,7 +185,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type VertexAIClaudeRequest struct {
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *claude.Thinking `json:"thinking,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []dto.ClaudeMessage `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest {
|
||||
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
|
||||
return &VertexAIClaudeRequest{
|
||||
AnthropicVersion: version,
|
||||
System: req.System,
|
||||
|
||||
@@ -17,6 +17,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -50,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -75,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
104
relay/channel/xai/adaptor.go
Normal file
104
relay/channel/xai/adaptor.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//not available
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
request.Size = ""
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.ReasoningEffort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.ReasoningEffort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
}
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//not available
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = xAIStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = xAIHandler(c, resp, info)
|
||||
}
|
||||
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
|
||||
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
|
||||
//}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
18
relay/channel/xai/constants.go
Normal file
18
relay/channel/xai/constants.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package xai
|
||||
|
||||
var ModelList = []string{
|
||||
// grok-3
|
||||
"grok-3-beta", "grok-3-mini-beta",
|
||||
// grok-3 mini
|
||||
"grok-3-fast-beta", "grok-3-mini-fast-beta",
|
||||
// extend grok-3-mini reasoning
|
||||
"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
|
||||
"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
|
||||
// image model
|
||||
"grok-2-image",
|
||||
// legacy models
|
||||
"grok-2", "grok-2-vision",
|
||||
"grok-beta", "grok-vision-beta",
|
||||
}
|
||||
|
||||
var ChannelName = "xai"
|
||||
14
relay/channel/xai/dto.go
Normal file
14
relay/channel/xai/dto.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package xai
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
// ChatCompletionResponse represents the response from XAI chat completion API
|
||||
type ChatCompletionResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []dto.ChatCompletionsStreamResponseChoice
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
}
|
||||
119
relay/channel/xai/text.go
Normal file
119
relay/channel/xai/text.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
if xAIResp == nil {
|
||||
return nil
|
||||
}
|
||||
if xAIResp.Usage != nil {
|
||||
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
|
||||
}
|
||||
openAIResp := &dto.ChatCompletionsStreamResponse{
|
||||
Id: xAIResp.Id,
|
||||
Object: xAIResp.Object,
|
||||
Created: xAIResp.Created,
|
||||
Model: xAIResp.Model,
|
||||
Choices: xAIResp.Choices,
|
||||
Usage: xAIResp.Usage,
|
||||
}
|
||||
|
||||
return openAIResp
|
||||
}
|
||||
|
||||
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var containStreamUsage bool
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &xAIResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||
if xAIResp.Usage != nil {
|
||||
containStreamUsage = true
|
||||
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
||||
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
var response *dto.TextResponse
|
||||
err = common.DecodeJson(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
||||
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
|
||||
// new body
|
||||
encodeJson, err := common.EncodeJson(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// set new body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, &response.Usage
|
||||
}
|
||||
8
relay/channel/xinference/constant.go
Normal file
8
relay/channel/xinference/constant.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package xinference
|
||||
|
||||
var ModelList = []string{
|
||||
"bge-reranker-v2-m3",
|
||||
"jina-reranker-v2",
|
||||
}
|
||||
|
||||
var ChannelName = "xinference"
|
||||
11
relay/channel/xinference/dto.go
Normal file
11
relay/channel/xinference/dto.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package xinference
|
||||
|
||||
type XinRerankResponseDocument struct {
|
||||
Document string `json:"document,omitempty"`
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
}
|
||||
|
||||
type XinRerankResponse struct {
|
||||
Results []XinRerankResponseDocument `json:"results"`
|
||||
}
|
||||
@@ -16,6 +16,12 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -10,11 +10,18 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -29,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
||||
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -39,7 +52,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -54,11 +67,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
@@ -67,7 +78,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -79,7 +71,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if base64
|
||||
if strings.HasPrefix(imageUrl.Url, "data:image/") {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
}
|
||||
|
||||
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
|
||||
// fullTextResponse := dto.OpenAITextResponse{
|
||||
// Id: response.Id,
|
||||
// Object: "chat.completion",
|
||||
// Created: common.GetTimestamp(),
|
||||
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
|
||||
// Usage: response.Usage,
|
||||
// }
|
||||
// for i, choice := range response.TextResponseChoices {
|
||||
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
|
||||
// openaiChoice := dto.OpenAITextResponseChoice{
|
||||
// Index: i,
|
||||
// Message: dto.Message{
|
||||
// Role: choice.Role,
|
||||
// Content: content,
|
||||
// },
|
||||
// FinishReason: "",
|
||||
// }
|
||||
// if i == len(response.TextResponseChoices)-1 {
|
||||
// openaiChoice.FinishReason = "stop"
|
||||
// }
|
||||
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
// }
|
||||
// return &fullTextResponse
|
||||
//}
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
|
||||
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
|
||||
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
|
||||
choice.Index = zhipuResponse.Choices[0].Index
|
||||
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: zhipuResponse.Created,
|
||||
Model: "glm-4v",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
|
||||
response := streamResponseZhipu2OpenAI(zhipuResponse)
|
||||
return response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage *dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
|
||||
var streamResponse ZhipuV4StreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
}
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
if strings.Contains(data, "prompt_tokens") {
|
||||
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
|
||||
} else {
|
||||
response = streamResponseZhipu2OpenAI(&streamResponse)
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var textResponse ZhipuV4Response
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the HTTPClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, &textResponse.Usage
|
||||
}
|
||||
|
||||
163
relay/claude_handler.go
Normal file
163
relay/claude_handler.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||
textRequest = &dto.ClaudeRequest{}
|
||||
err = c.ShouldBindJSON(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("field model is required")
|
||||
}
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
|
||||
if openaiErr != nil {
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
}
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
//log.Printf("usage: %v", usage)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
var promptTokens int
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
default:
|
||||
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
|
||||
}
|
||||
info.PromptTokens = promptTokens
|
||||
return promptTokens, err
|
||||
}
|
||||
@@ -15,6 +15,32 @@ import (
|
||||
type ThinkingContentInfo struct {
|
||||
IsFirstThinkingContent bool
|
||||
SendLastThinkingContent bool
|
||||
HasSentThinkingContent bool
|
||||
}
|
||||
|
||||
const (
|
||||
LastMessageTypeNone = "none"
|
||||
LastMessageTypeText = "text"
|
||||
LastMessageTypeTools = "tools"
|
||||
LastMessageTypeThinking = "thinking"
|
||||
)
|
||||
|
||||
type ClaudeConvertInfo struct {
|
||||
LastMessagesType string
|
||||
Index int
|
||||
Usage *dto.Usage
|
||||
FinishReason string
|
||||
Done bool
|
||||
}
|
||||
|
||||
const (
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
)
|
||||
|
||||
type RerankerInfo struct {
|
||||
Documents []any
|
||||
ReturnDocuments bool
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
@@ -55,10 +81,15 @@ type RelayInfo struct {
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting map[string]interface{}
|
||||
ParamOverride map[string]interface{}
|
||||
UserSetting map[string]interface{}
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat string
|
||||
SendResponseCount int
|
||||
ThinkingContentInfo
|
||||
*ClaudeConvertInfo
|
||||
*RerankerInfo
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
@@ -71,6 +102,7 @@ var streamSupportedChannels = map[int]bool{
|
||||
common.ChannelTypeAzure: true,
|
||||
common.ChannelTypeVolcEngine: true,
|
||||
common.ChannelTypeOllama: true,
|
||||
common.ChannelTypeXai: true,
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
@@ -82,10 +114,31 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatClaude
|
||||
info.ShouldIncludeUsage = false
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
info.RerankerInfo = &RerankerInfo{
|
||||
Documents: req.Documents,
|
||||
ReturnDocuments: req.GetReturnDocuments(),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelSetting := c.GetStringMap("channel_setting")
|
||||
paramOverride := c.GetStringMap("param_override")
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenKey := c.GetString("token_key")
|
||||
@@ -123,6 +176,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelSetting: channelSetting,
|
||||
ParamOverride: paramOverride,
|
||||
RelayFormat: RelayFormatOpenAI,
|
||||
ThinkingContentInfo: ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
@@ -163,6 +218,10 @@ func (info *RelayInfo) SetFirstResponseTime() {
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) HasSendResponse() bool {
|
||||
return info.FirstResponseTime.After(info.StartTime)
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
*RelayInfo
|
||||
Action string
|
||||
|
||||
68
relay/common_handler/rerank.go
Normal file
68
relay/common_handler/rerank.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package common_handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("reranker response body: ", string(responseBody))
|
||||
}
|
||||
var jinaResp dto.RerankResponse
|
||||
if info.ChannelType == common.ChannelTypeXinference {
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.DecodeJson(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||
for i, result := range xinRerankResponse.Results {
|
||||
respResult := dto.RerankResponseResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
if info.ReturnDocuments {
|
||||
var document any
|
||||
if result.Document == "" {
|
||||
document = info.Documents[result.Index]
|
||||
} else {
|
||||
document = result.Document
|
||||
}
|
||||
respResult.Document = document
|
||||
}
|
||||
jinaRespResults[i] = respResult
|
||||
}
|
||||
jinaResp = dto.RerankResponse{
|
||||
Results: jinaRespResults,
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
err = common.DecodeJson(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, jinaResp)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
||||
@@ -31,6 +31,8 @@ const (
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -89,6 +91,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeBaiduV2
|
||||
case common.ChannelTypeOpenRouter:
|
||||
apiType = APITypeOpenRouter
|
||||
case common.ChannelTypeXinference:
|
||||
apiType = APITypeXinference
|
||||
case common.ChannelTypeXai:
|
||||
apiType = APITypeXai
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -19,6 +19,30 @@ func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||
jsonData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
} else {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) error {
|
||||
//str = strings.TrimPrefix(str, "data: ")
|
||||
//str = strings.TrimSuffix(str, "\r")
|
||||
@@ -31,7 +55,20 @@ func StringData(c *gin.Context, str string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func PingData(c *gin.Context) error {
|
||||
c.Writer.Write([]byte(": PING\n\n"))
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ObjectData(c *gin.Context, object interface{}) error {
|
||||
if object == nil {
|
||||
return errors.New("object is nil")
|
||||
}
|
||||
jsonData, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
@@ -16,9 +17,14 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
GroupRatio float64
|
||||
UsePrice bool
|
||||
CacheCreationRatio float64
|
||||
ShouldPreConsumedQuota int
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
@@ -26,6 +32,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
var cacheRatio float64
|
||||
var cacheCreationRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if maxTokens != 0 {
|
||||
@@ -34,26 +41,52 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var success bool
|
||||
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
if info.UserId == 1 {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
} else {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
|
||||
acceptUnsetRatio := false
|
||||
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
|
||||
b, ok := accept.(bool)
|
||||
if ok {
|
||||
acceptUnsetRatio = b
|
||||
}
|
||||
}
|
||||
if !acceptUnsetRatio {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
return PriceData{
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
CacheCreationRatio: cacheCreationRatio,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
|
||||
}
|
||||
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
func ContainPriceOrRatio(modelName string) bool {
|
||||
_, ok := operation_setting.GetModelPrice(modelName, false)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
_, ok = operation_setting.GetModelRatio(modelName)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user