Merge branch 'main' into pr

This commit is contained in:
IcedTangerine
2025-04-25 18:27:11 +08:00
committed by GitHub
201 changed files with 16703 additions and 9350 deletions

View File

@@ -18,20 +18,20 @@ jobs:
contents: read contents: read
steps: steps:
- name: Check out the repo - name: Check out the repo
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION git describe --tags > VERSION
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}
@@ -39,14 +39,14 @@ jobs:
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
id: meta id: meta
uses: docker/metadata-action@v4 uses: docker/metadata-action@v5
with: with:
images: | images: |
calciumion/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images
uses: docker/build-push-action@v3 uses: docker/build-push-action@v5
with: with:
context: . context: .
push: true push: true

View File

@@ -4,7 +4,6 @@ on:
push: push:
tags: tags:
- '*' - '*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
@@ -19,26 +18,26 @@ jobs:
contents: read contents: read
steps: steps:
- name: Check out the repo - name: Check out the repo
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION git describe --tags > VERSION
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}
@@ -46,14 +45,14 @@ jobs:
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
id: meta id: meta
uses: docker/metadata-action@v4 uses: docker/metadata-action@v5
with: with:
images: | images: |
calciumion/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images
uses: docker/build-push-action@v3 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64

1
.gitignore vendored
View File

@@ -10,3 +10,4 @@ web/dist
.env .env
one-api one-api
.DS_Store .DS_Store
tiktoken_cache

216
README.md
View File

@@ -7,7 +7,6 @@
# New API # New API
🍥新一代大模型网关与AI资产管理系统 🍥新一代大模型网关与AI资产管理系统
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
@@ -37,189 +36,137 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
> [!IMPORTANT] > [!IMPORTANT]
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。 > - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
## 📚 文档
详细文档请访问我们的官方Wiki[https://docs.newapi.pro/](https://docs.newapi.pro/)
## ✨ 主要特性 ## ✨ 主要特性
1. 🎨 全新的UI界面部分界面还待更新 New API提供了丰富的功能详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction)
2. 🌍 多语言支持(待完善)
3. 🎨 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口支持,[对接文档](Midjourney.md) 1. 🎨 全新的UI界面
4. 💰 支持在线充值功能,可在系统设置中设置: 2. 🌍 多语言支持
- [x] 易支付 3. 💰 支持在线充值功能(易支付
5. 🔍 支持用key查询使用额度 4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
- 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 5. 🔄 兼容原版One API的数据库
6. 📑 分页支持选择每页显示数量 6. 💵 支持模型按次数收费
7. 🔄 兼容原版One API的数据库可直接使用原版数据库one-api.db 7. ⚖️ 支持渠道加权随机
8. 💵 支持模型按次数收费,可在 系统设置-运营设置 中设置 8. 📈 数据看板(控制台)
9. ⚖️ 支持渠道**加权随机** 9. 🔒 令牌分组、模型限制
10. 📈 数据看板(控制台 10. 🤖 支持更多授权登陆方式LinuxDO,Telegram、OIDC
11. 🔒 可设置令牌能调用的模型 11. 🔄 支持Rerank模型Cohere和Jina[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
12. 🤖 支持Telegram授权登录 12. 支持OpenAI Realtime API包括Azure渠道[接口文档](https://docs.newapi.pro/api/openai-realtime)
1. 系统设置-配置登录注册-允许通过Telegram登录 13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain 14. 支持使用路由/chat2link进入聊天界面
3. 选择你的bot然后输入http(s)://你的网站地址/login 15. 🧠 支持通过模型名称后缀设置 reasoning effort
4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 🎵 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口支持,[对接文档](Suno.md)
14. 🔄 支持Rerank模型目前兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
1. OpenAI o系列模型 1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`) - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`) - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`) - 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型 2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`) - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。 16. 🔄 思考转内容功能
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制 17. 🔄 针对用户的模型限流功能
20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: 18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
1.`系统设置-运营设置` 中设置 `提示缓存倍率` 选项 1.`系统设置-运营设置` 中设置 `提示缓存倍率` 选项
2. 在渠道中设置 `提示缓存倍率`,范围 0-1例如设置为 0.5 表示缓存命中时按照 50% 计费 2. 在渠道中设置 `提示缓存倍率`,范围 0-1例如设置为 0.5 表示缓存命中时按照 50% 计费
3. 支持的渠道: 3. 支持的渠道:
- [x] OpenAI - [x] OpenAI
- [x] Azure - [x] Azure
- [x] DeepSeek - [x] DeepSeek
- [ ] Claude - [x] Claude
## 模型支持 ## 模型支持
此版本额外支持以下模型:
此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api)
1. 第三方模型 **gpts** gpt-4-gizmo-* 1. 第三方模型 **gpts** gpt-4-gizmo-*
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接文档](Midjourney.md) 2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[文档](https://docs.newapi.pro/api/midjourney-proxy-image)
3. 自定义渠道,支持填入完整调用地址 3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 4. 自定义渠道,支持填入完整调用地址
5. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[接文档](Rerank.md) 5. Rerank模型[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[文档](https://docs.newapi.pro/api/jinaai-rerank)
6. Dify 6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
7. Dify当前仅支持chatflow
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 ## 环境变量配置
## 比原版One API多出的配置 详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables)
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本如果渠道设置中未指定API版本则使用此版本默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 已废弃的环境变量 - `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置 - `STREAMING_TIMEOUT`流式回复超时时间默认60秒
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置 - `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true`
- `GET_MEDIA_TOKEN`是否统计图片token默认 `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`非流情况下是否统计图片token默认 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认 `true`
- `COHERE_SAFETY_SETTING`Cohere模型安全设置可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认 `16`
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小单位MB默认 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
## 部署 ## 部署
详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation)
> [!TIP] > [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest` > 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456
### 多机部署 ### 多机部署注意事项
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致 - 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
- 如果公用Redis必须设置 `CRYPTO_SECRET`否则会导致多机部署时Redis内容无法获取 - 如果公用Redis必须设置 `CRYPTO_SECRET`否则会导致多机部署时Redis内容无法获取
### 部署要求 ### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机 - 本地数据库默认SQLiteDocker部署必须挂载`/data`目录)
- 远程数据库MySQL版本 >= 5.7.8PgSQL版本 >= 9.6 - 远程数据库MySQL版本 >= 5.7.8PgSQL版本 >= 9.6
### 使用宝塔面板Docker功能部署 ### 部署方式
安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装
安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装 #### 使用宝塔面板Docker功能部署
安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装 安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
[图文教程](BT.md) [图文教程](BT.md)
### 基于 Docker 进行部署 #### 使用Docker Compose部署推荐
> [!TIP]
> 默认管理员账号root 密码123456
### 使用 Docker Compose 部署(推荐)
```shell ```shell
# 下载项目 # 下载项目
git clone https://github.com/Calcium-Ion/new-api.git git clone https://github.com/Calcium-Ion/new-api.git
cd new-api cd new-api
# 按需编辑docker-compose.yml # 按需编辑docker-compose.yml
# nano docker-compose.yml
# vim docker-compose.yml
# 启动 # 启动
docker-compose up -d docker-compose up -d
``` ```
#### 更新版本 #### 直接使用Docker镜像
```shell ```shell
docker-compose pull # 使用SQLite
docker-compose up -d
```
### 直接使用 Docker 镜像
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。 # 使用MySQL
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
``` ```
#### 更新版本 ## 渠道重试与缓存
```shell
# 拉取最新镜像
docker pull calciumion/new-api:latest
# 停止并删除旧容器
docker stop new-api
docker rm new-api
# 使用相同参数运行新容器
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,重试使用下一个优先级,以此类推。
### 缓存设置方法 ### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 1. `REDIS_CONN_STRING`设置Redis作为缓存
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` 2. `MEMORY_CACHE_ENABLED`启用内存缓存设置了Redis则无需手动设置
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## Midjourney接口设置文档 ## 接口文档
[对接文档](Midjourney.md)
## Suno接口设置文档 详细接口文档请参考[接口文档](https://docs.newapi.pro/api)
[对接文档](Suno.md)
## 界面截图 - [聊天接口Chat](https://docs.newapi.pro/api/openai-chat)
![image](https://github.com/user-attachments/assets/a0dcd349-5df8-4dc8-9acf-ca272b239919) - [图像接口Image](https://docs.newapi.pro/api/openai-image)
- [重排序接口Rerank](https://docs.newapi.pro/api/jinaai-rerank)
- [实时对话接口Realtime](https://docs.newapi.pro/api/openai-realtime)
![image](https://github.com/user-attachments/assets/c7d0f7e1-729c-43e2-ac7c-2cb73b0afc8e) - [Claude聊天接口messages](https://docs.newapi.pro/api/anthropic-chat)
![image](https://github.com/user-attachments/assets/29f81de5-33fc-4fc5-a5ff-f9b54b653c7c)
![image](https://github.com/user-attachments/assets/4fa53e18-d2c5-477a-9b26-b86e44c71e35)
## 交流群
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
## 相关项目 ## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目 - [One API](https://github.com/songquanpeng/one-api):原版项目
@@ -228,8 +175,15 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度 - [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
其他基于New API的项目 其他基于New API的项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版专注于高并发优化并支持Claude格式 - [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本,闭源免费 - [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本
## 帮助支持
如有问题,请参考[帮助支持](https://docs.newapi.pro/support)
- [社区交流](https://docs.newapi.pro/support/community-interaction)
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
- [常见问题](https://docs.newapi.pro/support/faq)
## 🌟 Star History ## 🌟 Star History

View File

@@ -1,8 +1,8 @@
package common package common
import ( import (
"os" //"os"
"strconv" //"strconv"
"sync" "sync"
"time" "time"
@@ -63,8 +63,8 @@ var EmailDomainWhitelist = []string{
"foxmail.com", "foxmail.com",
} }
var DebugEnabled = os.Getenv("DEBUG") == "true" var DebugEnabled bool
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" var MemoryCacheEnabled bool
var LogConsumeEnabled = true var LogConsumeEnabled = true
@@ -77,7 +77,6 @@ var SMTPToken = ""
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDOClientId = "" var LinuxDOClientId = ""
var LinuxDOClientSecret = "" var LinuxDOClientSecret = ""
@@ -104,22 +103,22 @@ var RetryTimes = 0
//var RootUserEmail = "" //var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var IsMasterNode bool
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var requestInterval int
var RequestInterval = time.Duration(requestInterval) * time.Second var RequestInterval time.Duration
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second var SyncFrequency int // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval int
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second var RelayTimeout int // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting string
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT // https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE") var CohereSafetySetting string
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
@@ -146,13 +145,13 @@ var (
// All duration's unit is seconds // All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration // Shouldn't larger then RateLimitKeyExpirationDuration
var ( var (
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true) GlobalApiRateLimitEnable bool
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitNum int
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180)) GlobalApiRateLimitDuration int64
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true) GlobalWebRateLimitEnable bool
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum int
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180)) GlobalWebRateLimitDuration int64
UploadRateLimitNum = 10 UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60 UploadRateLimitDuration int64 = 60
@@ -235,6 +234,8 @@ const (
ChannelTypeMokaAI = 44 ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45 ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46 ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -287,4 +288,6 @@ var ChannelBaseURLs = []string{
"https://api.moka.ai", //44 "https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45 "https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46 "https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
} }

View File

@@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer(
"\r", "\\r") "\r", "\\r")
var dataReplacer = strings.NewReplacer( var dataReplacer = strings.NewReplacer(
"\n", "\ndata:", "\n", "\n",
"\r", "\\r") "\r", "\\r")
type CustomEvent struct { type CustomEvent struct {

View File

@@ -6,6 +6,8 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time"
) )
var ( var (
@@ -66,4 +68,31 @@ func LoadEnv() {
} }
} }
} }
// Initialize variables from constants.go that were using environment variables
DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
// Parse requestInterval and set RequestInterval
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
RequestInterval = time.Duration(requestInterval) * time.Second
// Initialize variables with GetEnvOrDefault
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
// Initialize string variables with GetEnvOrDefaultString
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
// Initialize rate limit variables
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
} }

18
common/json.go Normal file
View File

@@ -0,0 +1,18 @@
package common
import (
"bytes"
"encoding/json"
)
func DecodeJson(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
}
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
}
func EncodeJson(v any) ([]byte, error) {
return json.Marshal(v)
}

89
common/limiter/limiter.go Normal file
View File

@@ -0,0 +1,89 @@
package limiter
import (
"context"
_ "embed"
"fmt"
"github.com/go-redis/redis/v8"
"one-api/common"
"sync"
)
//go:embed lua/rate_limit.lua
var rateLimitScript string
type RedisLimiter struct {
client *redis.Client
limitScriptSHA string
}
var (
instance *RedisLimiter
once sync.Once
)
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
once.Do(func() {
// 预加载脚本
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
if err != nil {
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
}
instance = &RedisLimiter{
client: r,
limitScriptSHA: limitSHA,
}
})
return instance
}
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
// 默认配置
config := &Config{
Capacity: 10,
Rate: 1,
Requested: 1,
}
// 应用选项模式
for _, opt := range opts {
opt(config)
}
// 执行限流
result, err := rl.client.EvalSha(
ctx,
rl.limitScriptSHA,
[]string{key},
config.Requested,
config.Rate,
config.Capacity,
).Int()
if err != nil {
return false, fmt.Errorf("rate limit failed: %w", err)
}
return result == 1, nil
}
// Config 配置选项模式
type Config struct {
Capacity int64
Rate int64
Requested int64
}
type Option func(*Config)
func WithCapacity(c int64) Option {
return func(cfg *Config) { cfg.Capacity = c }
}
func WithRate(r int64) Option {
return func(cfg *Config) { cfg.Rate = r }
}
func WithRequested(n int64) Option {
return func(cfg *Config) { cfg.Requested = n }
}

View File

@@ -0,0 +1,44 @@
-- 令牌桶限流器
-- KEYS[1]: 限流器唯一标识
-- ARGV[1]: 请求令牌数 (通常为1)
-- ARGV[2]: 令牌生成速率 (每秒)
-- ARGV[3]: 桶容量
local key = KEYS[1]
local requested = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local capacity = tonumber(ARGV[3])
-- 获取当前时间Redis服务器时间
local now = redis.call('TIME')
local nowInSeconds = tonumber(now[1])
-- 获取桶状态
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
local tokens = tonumber(bucket[1])
local last_time = tonumber(bucket[2])
-- 初始化桶(首次请求或过期)
if not tokens or not last_time then
tokens = capacity
last_time = nowInSeconds
else
-- 计算新增令牌
local elapsed = nowInSeconds - last_time
local add_tokens = elapsed * rate
tokens = math.min(capacity, tokens + add_tokens)
last_time = nowInSeconds
end
-- 判断是否允许请求
local allowed = false
if tokens >= requested then
tokens = tokens - requested
allowed = true
end
---- 更新桶状态并设置过期时间
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
return allowed and 1 or 0

View File

@@ -4,32 +4,39 @@ import (
"one-api/common" "one-api/common"
) )
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) var StreamingTimeout int
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) var DifyDebug bool
var MaxFileDownloadMB int
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) var ForceStreamOption bool
var GetMediaToken bool
// ForceStreamOption 覆盖请求参数强制返回usage信息 var GetMediaTokenNotStream bool
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) var UpdateTask bool
var AzureDefaultAPIVersion string
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) var GeminiVisionMaxImageNum int
var NotifyLimitCount int
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
//var GeminiModelMap = map[string]string{ //var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1", // "gemini-1.0-pro": "v1",
//} //}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() { func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" { //if modelVersionMapStr == "" {
// return // return
@@ -43,6 +50,3 @@ func InitEnv() {
// } // }
//} //}
} }
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

3
constant/setup.go Normal file
View File

@@ -0,0 +1,3 @@
package constant
var Setup = false

View File

@@ -6,6 +6,7 @@ var (
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
) )
var ( var (

View File

@@ -103,11 +103,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} }
request := buildTestRequest(testModel) request := buildTestRequest(testModel)
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info)) // 创建一个用于日志的 info 副本,移除 ApiKey
logInfo := *info
logInfo.ApiKey = ""
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil {
return err, nil
}
adaptor.Init(info) adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, info, request) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -125,7 +133,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp) err := service.RelayErrorHandler(httpResp, true)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
} }
} }
@@ -143,10 +151,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return err, nil return err, nil
} }
info.PromptTokens = usage.PromptTokens info.PromptTokens = usage.PromptTokens
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
if err != nil {
return err, nil
}
quota := 0 quota := 0
if !priceData.UsePrice { if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
@@ -184,10 +189,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
return testRequest return testRequest
} }
// 并非Embedding 模型 // 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") { if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 10 testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") { } else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50 testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
} else { } else {
testRequest.MaxTokens = 10 testRequest.MaxTokens = 10
} }

View File

@@ -119,6 +119,9 @@ func FetchUpstreamModels(c *gin.Context) {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
url := fmt.Sprintf("%s/v1/models", baseURL) url := fmt.Sprintf("%s/v1/models", baseURL)
if channel.Type == common.ChannelTypeGemini {
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -139,7 +142,11 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string var ids []string
for _, model := range result.Data { for _, model := range result.Data {
ids = append(ids, model.ID) id := model.ID
if channel.Type == common.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

9
controller/image.go Normal file
View File

@@ -0,0 +1,9 @@
package controller
import (
"github.com/gin-gonic/gin"
)
func GetImage(c *gin.Context) {
}

View File

@@ -5,9 +5,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/operation_setting" "one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -68,6 +70,10 @@ func GetStatus(c *gin.Context) {
"chats": setting.Chats, "chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled, "demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled, "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
}, },
}) })
return return

240
controller/oidc.go Normal file
View File

@@ -0,0 +1,240 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysError("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysError("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -6,6 +6,7 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/system_setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -51,6 +52,14 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "oidc.enabled":
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret",
})
return
}
case "LinuxDOOAuthEnabled": case "LinuxDOOAuthEnabled":
if option.Value == "true" && common.LinuxDOClientId == "" { if option.Value == "true" && common.LinuxDOClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -81,6 +90,15 @@ func UpdateOption(c *gin.Context) {
"success": false, "success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
}) })
return
}
case "TelegramOAuthEnabled":
if option.Value == "true" && common.TelegramBotToken == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Telegram OAuth请先填入 Telegram Bot Token",
})
return return
} }
case "GroupRatio": case "GroupRatio":
@@ -92,6 +110,7 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
} }
err = model.UpdateOption(option.Key, option.Value) err = model.UpdateOption(option.Key, option.Value)
if err != nil { if err != nil {

View File

@@ -148,6 +148,50 @@ func WssRelay(c *gin.Context) {
} }
} }
func RelayClaude(c *gin.Context) {
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var claudeErr *dto.ClaudeErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
claudeErr = claudeRequest(c, channel)
if claudeErr == nil {
return // 成功处理请求,直接返回
}
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if claudeErr != nil {
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
c.JSON(claudeErr.StatusCode, gin.H{
"type": "error",
"error": claudeErr.Error,
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
@@ -162,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode
return relay.WssHelper(c, ws) return relay.WssHelper(c, ws)
} }
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.ClaudeHelper(c)
}
func addUsedChannel(c *gin.Context, channelId int) { func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))

173
controller/setup.go Normal file
View File

@@ -0,0 +1,173 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting/operation_setting"
"time"
)
type Setup struct {
Status bool `json:"status"`
RootInit bool `json:"root_init"`
DatabaseType string `json:"database_type"`
}
type SetupRequest struct {
Username string `json:"username"`
Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"`
SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
DemoSiteEnabled bool `json:"DemoSiteEnabled"`
}
func GetSetup(c *gin.Context) {
setup := Setup{
Status: constant.Setup,
}
if constant.Setup {
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
return
}
setup.RootInit = model.RootUserExists()
if common.UsingMySQL {
setup.DatabaseType = "mysql"
}
if common.UsingPostgreSQL {
setup.DatabaseType = "postgres"
}
if common.UsingSQLite {
setup.DatabaseType = "sqlite"
}
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
}
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
c.JSON(400, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
return
}
// Check if root user already exists
rootExists := model.RootUserExists()
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(400, gin.H{
"success": false,
"message": "请求参数有误",
})
return
}
// If root doesn't exist, validate and create admin account
if !rootExists {
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
return
}
if len(req.Password) < 8 {
c.JSON(400, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
return
}
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
return
}
rootUser := model.User{
Username: req.Username,
Password: hashedPassword,
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: nil,
Quota: 100000000,
}
err = model.DB.Create(&rootUser).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
return
}
}
// Set operation modes
operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
return
}
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
return
}
// Update setup status
constant.Setup = true
setup := model.Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err = model.DB.Create(&setup).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "系统初始化成功",
})
}
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}

View File

@@ -918,6 +918,7 @@ type UpdateUserSettingRequest struct {
WebhookUrl string `json:"webhook_url,omitempty"` WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"` WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"` NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
} }
func UpdateUserSetting(c *gin.Context) { func UpdateUserSetting(c *gin.Context) {
@@ -993,6 +994,7 @@ func UpdateUserSetting(c *gin.Context) {
settings := map[string]interface{}{ settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType, constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
} }
// 如果是webhook类型,添加webhook相关设置 // 如果是webhook类型,添加webhook相关设置

View File

@@ -15,6 +15,7 @@ services:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

@@ -11,7 +11,7 @@
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址) - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content 3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回 - 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换 - 类型为布尔值,设置为 true 时启用思考内容转换
-------------------------------------------------------------- --------------------------------------------------------------

218
dto/claude.go Normal file
View File

@@ -0,0 +1,218 @@
package dto
import "encoding/json"
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
Type string `json:"type,omitempty"`
Text *string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson *string `json:"partial_json,omitempty"`
Role string `json:"role,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
func (c *ClaudeMediaMessage) SetText(s string) {
c.Text = &s
}
func (c *ClaudeMediaMessage) GetText() string {
if c.Text == nil {
return ""
}
return *c.Text
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
var content string
return json.Unmarshal(c.Content, &content) == nil
}
func (c *ClaudeMediaMessage) GetStringContent() string {
var content string
if err := json.Unmarshal(c.Content, &content); err == nil {
return content
}
return ""
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c)
return string(jsonContent)
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
var mediaContent []ClaudeMediaMessage
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type,omitempty"`
Data any `json:"data,omitempty"`
Url string `json:"url,omitempty"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
func (c *ClaudeMessage) IsStringContent() bool {
_, ok := c.Content.(string)
return ok
}
func (c *ClaudeMessage) GetStringContent() string {
if c.IsStringContent() {
return c.Content.(string)
}
return ""
}
func (c *ClaudeMessage) SetStringContent(content string) {
c.Content = content
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.Content)
var contentList []ClaudeMediaMessage
err := json.Unmarshal(jsonContent, &contentList)
if err != nil {
return make([]ClaudeMediaMessage, 0), err
}
return contentList, nil
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System any `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
func (c *ClaudeRequest) IsStringSystem() bool {
_, ok := c.System.(string)
return ok
}
func (c *ClaudeRequest) GetStringSystem() string {
if c.IsStringSystem() {
return c.System.(string)
}
return ""
}
func (c *ClaudeRequest) SetStringSystem(system string) {
c.System = system
}
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.System)
var contentList []ClaudeMediaMessage
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
return contentList
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeError struct {
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
}
type ClaudeErrorWithStatusCode struct {
Error ClaudeError `json:"error"`
StatusCode int `json:"status_code"`
LocalError bool
}
type ClaudeResponse struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Role string `json:"role,omitempty"`
Content []ClaudeMediaMessage `json:"content,omitempty"`
Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"`
Error *ClaudeError `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
Delta *ClaudeMediaMessage `json:"delta,omitempty"`
Message *ClaudeMediaMessage `json:"message,omitempty"`
}
// set index
func (c *ClaudeResponse) SetIndex(i int) {
c.Index = &i
}
// get index
func (c *ClaudeResponse) GetIndex() int {
if c.Index == nil {
return 0
}
return *c.Index
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -1,5 +1,7 @@
package dto package dto
import "encoding/json"
type ImageRequest struct { type ImageRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"` Prompt string `json:"prompt" binding:"required"`
@@ -9,6 +11,7 @@ type ImageRequest struct {
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"` Style string `json:"style,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
} }
type ImageResponse struct { type ImageResponse struct {

View File

@@ -28,6 +28,7 @@ type GeneralOpenAIRequest struct {
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
//Reasoning json.RawMessage `json:"reasoning,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
@@ -111,11 +112,38 @@ type MediaContent struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"` ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"` InputAudio any `json:"input_audio,omitempty"`
File any `json:"file,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil {
return m.ImageUrl.(*MessageImageUrl)
}
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
return m.InputAudio.(*MessageInputAudio)
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
return m.File.(*MessageFile)
}
return nil
} }
type MessageImageUrl struct { type MessageImageUrl struct {
Url string `json:"url"` Url string `json:"url"`
Detail string `json:"detail"` Detail string `json:"detail"`
MimeType string
}
func (m *MessageImageUrl) IsRemoteImage() bool {
return strings.HasPrefix(m.Url, "http")
} }
type MessageInputAudio struct { type MessageInputAudio struct {
@@ -123,10 +151,17 @@ type MessageInputAudio struct {
Format string `json:"format"` Format string `json:"format"`
} }
type MessageFile struct {
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"`
FileId string `json:"file_id,omitempty"`
}
const ( const (
ContentTypeText = "text" ContentTypeText = "text"
ContentTypeImageURL = "image_url" ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio" ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
) )
func (m *Message) GetPrefix() bool { func (m *Message) GetPrefix() bool {
@@ -180,6 +215,12 @@ func (m *Message) StringContent() string {
return stringContent return stringContent
} }
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedStringContent = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) { func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content) jsonContent, _ := json.Marshal(content)
m.Content = jsonContent m.Content = jsonContent
@@ -244,44 +285,64 @@ func (m *Message) ParseContent() []MediaContent {
case ContentTypeImageURL: case ContentTypeImageURL:
imageUrl := contentItem["image_url"] imageUrl := contentItem["image_url"]
temp := &MessageImageUrl{
Detail: "high",
}
switch v := imageUrl.(type) { switch v := imageUrl.(type) {
case string: case string:
contentList = append(contentList, MediaContent{ temp.Url = v
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: v,
Detail: "high",
},
})
case map[string]interface{}: case map[string]interface{}:
url, ok1 := v["url"].(string) url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string) detail, ok2 := v["detail"].(string)
if !ok2 { if ok2 {
detail = "high" temp.Detail = detail
} }
if ok1 { if ok1 {
temp.Url = url
}
}
contentList = append(contentList, MediaContent{ contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{ ImageUrl: temp,
Url: url,
Detail: detail,
},
}) })
}
}
case ContentTypeInputAudio: case ContentTypeInputAudio:
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok { if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string) data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string) format, ok2 := audioData["format"].(string)
if ok1 && ok2 { if ok1 && ok2 {
contentList = append(contentList, MediaContent{ temp := &MessageInputAudio{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data, Data: data,
Format: format, Format: format,
}
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: temp,
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
}, },
}) })
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
} }
} }
} }

View File

@@ -1,20 +1,8 @@
package dto package dto
type TextResponseWithError struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type SimpleResponse struct { type SimpleResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
Error OpenAIError `json:"error"` Error *OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
} }
type TextResponse struct { type TextResponse struct {
@@ -38,6 +26,7 @@ type OpenAITextResponse struct {
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error,omitempty"`
Usage `json:"usage"` Usage `json:"usage"`
} }
@@ -125,6 +114,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
} }
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
if len(c.Choices) == 0 {
return false
}
return len(c.Choices[0].Delta.ToolCalls) > 0
}
func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
if c.IsToolCall() {
return &c.Choices[0].Delta.ToolCalls[0]
}
return nil
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices) copy(choices, c.Choices)
@@ -170,3 +173,17 @@ type Usage struct {
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"` PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"` CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
} }
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int `json:"-"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
}

View File

@@ -43,18 +43,6 @@ type RealtimeUsage struct {
OutputTokenDetails OutputTokenDetails `json:"output_token_details"` OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
} }
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type RealtimeSession struct { type RealtimeSession struct {
Modalities []string `json:"modalities"` Modalities []string `json:"modalities"`
Instructions string `json:"instructions"` Instructions string `json:"instructions"`

View File

@@ -5,18 +5,29 @@ type RerankRequest struct {
Query string `json:"query"` Query string `json:"query"`
Model string `json:"model"` Model string `json:"model"`
TopN int `json:"top_n"` TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"` ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"` MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"` OverLapTokens int `json:"overlap_tokens,omitempty"`
} }
type RerankResponseDocument struct { func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil {
return false
}
return *r.ReturnDocuments
}
type RerankResponseResult struct {
Document any `json:"document,omitempty"` Document any `json:"document,omitempty"`
Index int `json:"index"` Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"` RelevanceScore float64 `json:"relevance_score"`
} }
type RerankDocument struct {
Text any `json:"text"`
}
type RerankResponse struct { type RerankResponse struct {
Results []RerankResponseDocument `json:"results"` Results []RerankResponseResult `json:"results"`
Usage Usage `json:"usage"` Usage Usage `json:"usage"`
} }

12
go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/bytedance/sonic v1.11.6
github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
@@ -28,9 +29,9 @@ require (
github.com/samber/lo v1.39.0 github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0 github.com/shopspring/decimal v1.4.0
golang.org/x/crypto v0.27.0 golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0 golang.org/x/image v0.23.0
golang.org/x/net v0.28.0 golang.org/x/net v0.35.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2 gorm.io/gorm v1.25.2
@@ -42,7 +43,6 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
@@ -84,9 +84,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.27.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect modernc.org/libc v1.22.5 // indirect

20
go.sum
View File

@@ -217,18 +217,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68= golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY= golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -239,14 +239,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=

View File

@@ -12,6 +12,7 @@ import (
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
"one-api/service" "one-api/service"
"one-api/setting/operation_setting"
"os" "os"
"strconv" "strconv"
@@ -33,7 +34,7 @@ var indexPage []byte
func main() { func main() {
err := godotenv.Load(".env") err := godotenv.Load(".env")
if err != nil { if err != nil {
common.SysLog("Support for .env file is disabled") common.SysLog("Support for .env file is disabled: " + err.Error())
} }
common.LoadEnv() common.LoadEnv()
@@ -51,6 +52,9 @@ func main() {
if err != nil { if err != nil {
common.FatalLog("failed to initialize database: " + err.Error()) common.FatalLog("failed to initialize database: " + err.Error())
} }
model.CheckSetup()
// Initialize SQL Database // Initialize SQL Database
err = model.InitLogDB() err = model.InitLogDB()
if err != nil { if err != nil {
@@ -69,10 +73,13 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error()) common.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize model settings
operation_setting.InitModelSettings()
// Initialize constants // Initialize constants
constant.InitEnv() constant.InitEnv()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions // for compatibility with old versions
common.MemoryCacheEnabled = true common.MemoryCacheEnabled = true

View File

@@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) {
} }
c.Request.Header.Set("Authorization", "Bearer "+key) c.Request.Header.Set("Authorization", "Bearer "+key)
} }
// 检查path包含/v1/messages
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
// 从x-api-key中获取key
key := c.Request.Header.Get("x-api-key")
if key != "" {
c.Request.Header.Set("Authorization", "Bearer "+key)
}
}
key := c.Request.Header.Get("Authorization") key := c.Request.Header.Get("Authorization")
parts := make([]string, 0) parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "Bearer ")

View File

@@ -212,6 +212,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type) c.Set("channel_type", channel.Type)
c.Set("channel_setting", channel.GetSetting()) c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization) c.Set("channel_organization", *channel.OpenAIOrganization)
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/limiter"
"one-api/setting" "one-api/setting"
"strconv" "strconv"
"time" "time"
@@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
ctx := context.Background() ctx := context.Background()
rdb := common.RDB rdb := common.RDB
// 1. 检查请求数限制当totalMaxCount为0时会自动跳过 // 1. 检查成功请求数限制
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil { if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error()) fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
@@ -103,8 +92,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
return return
} }
// 3. 记录总请求当totalMaxCount为0时会自动跳过 //2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过,使用令牌桶限流器
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount) totalKey := fmt.Sprintf("rateLimit:%s", userId)
// 初始化
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
totalKey,
limiter.WithCapacity(int64(totalMaxCount)*duration),
limiter.WithRate(int64(totalMaxCount)),
limiter.WithRequested(duration),
)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 4. 处理请求 // 4. 处理请求
c.Next() c.Next()

View File

@@ -35,7 +35,8 @@ type Channel struct {
AutoBan *int `json:"auto_ban" gorm:"default:1"` AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"` OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"` Tag *string `json:"tag" gorm:"index"`
Setting string `json:"setting" gorm:"type:text"` Setting *string `json:"setting" gorm:"type:text"`
ParamOverride *string `json:"param_override" gorm:"type:text"`
} }
func (channel *Channel) GetModels() []string { func (channel *Channel) GetModels() []string {
@@ -493,8 +494,8 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
func (channel *Channel) GetSetting() map[string]interface{} { func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{}) setting := make(map[string]interface{})
if channel.Setting != "" { if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(channel.Setting), &setting) err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysError("failed to unmarshal setting: " + err.Error())
} }
@@ -508,7 +509,18 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
common.SysError("failed to marshal setting: " + err.Error()) common.SysError("failed to marshal setting: " + err.Error())
return return
} }
channel.Setting = string(settingBytes) channel.Setting = common.GetPointer[string](string(settingBytes))
}
func (channel *Channel) GetParamOverride() map[string]interface{} {
paramOverride := make(map[string]interface{})
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := json.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error())
}
}
return paramOverride
} }
func GetChannelsByIds(ids []int) ([]*Channel, error) { func GetChannelsByIds(ids []int) ([]*Channel, error) {

View File

@@ -1,16 +1,18 @@
package model package model
import ( import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"log" "log"
"one-api/common" "one-api/common"
"one-api/constant"
"os" "os"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
) )
var groupCol string var groupCol string
@@ -54,13 +56,40 @@ func createRootAccountIfNeed() error {
return nil return nil
} }
func CheckSetup() {
setup := GetSetup()
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
common.SysLog("system is not initialized, but root user exists")
// Create setup record
newSetup := Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err := DB.Create(&newSetup).Error
if err != nil {
common.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
common.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
defer func() { defer func() {
initCol() initCol()
}() }()
dsn := os.Getenv(envName) dsn := os.Getenv(envName)
if dsn != "" { if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true common.UsingPostgreSQL = true
@@ -213,8 +242,9 @@ func migrateDB() error {
if err != nil { if err != nil {
return err return err
} }
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated") common.SysLog("database migrated")
err = createRootAccountIfNeed() //err = createRootAccountIfNeed()
return err return err
} }

16
model/setup.go Normal file
View File

@@ -0,0 +1,16 @@
package model
type Setup struct {
ID uint `json:"id" gorm:"primaryKey"`
Version string `json:"version" gorm:"type:varchar(50);not null"`
InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
}
func GetSetup() *Setup {
var setup Setup
err := DB.First(&setup).Error
if err != nil {
return nil
}
return &setup
}

View File

@@ -9,7 +9,6 @@ import (
"strings" "strings"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -24,6 +23,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -108,7 +108,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int { func GetMaxUserId() int {
var user User var user User
DB.Last(&user) DB.Unscoped().Last(&user)
return user.Id return user.Id
} }
@@ -442,6 +442,14 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -473,6 +481,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
} }
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsTelegramIdAlreadyTaken(telegramId string) bool { func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
} }
@@ -796,3 +808,12 @@ func (user *User) FillUserByLinuxDOId() error {
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err return err
} }
func RootUserExists() bool {
var user User
err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
if err != nil {
return false
}
return true
}

View File

@@ -13,7 +13,7 @@ type Adaptor interface {
Init(info *relaycommon.RelayInfo) Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
@@ -22,6 +22,7 @@ type Adaptor interface {
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
} }
type TaskAdaptor interface { type TaskAdaptor interface {

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -87,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
} }
return return

View File

@@ -26,8 +26,8 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
return &imageRequest return &imageRequest
} }
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) { func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID) url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
var aliResponse AliResponse var aliResponse AliResponse
@@ -36,7 +36,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
return &aliResponse, err, nil return &aliResponse, err, nil
} }
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+info.ApiKey)
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
@@ -58,7 +58,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
return &response, nil, responseBody return &response, nil, responseBody
} }
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) { func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
waitSeconds := 3 waitSeconds := 3
step := 0 step := 0
maxStep := 20 maxStep := 20
@@ -68,7 +68,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*Ali
for { for {
step++ step++
rsp, err, body := updateTask(info, taskID, key) rsp, err, body := updateTask(info, taskID)
responseBody = body responseBody = body
if err != nil { if err != nil {
return &taskResponse, responseBody, err return &taskResponse, responseBody, err
@@ -125,8 +125,6 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
} }
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format") responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse var aliTaskResponse AliResponse
@@ -148,7 +146,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
} }
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey) aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
common2 "one-api/common"
"one-api/relay/common" "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)
} }
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)

View File

@@ -20,6 +20,12 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -43,12 +49,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
var claudeReq *claude.ClaudeRequest var claudeReq *dto.ClaudeRequest
var err error var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil { if err != nil {

View File

@@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
} }
var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-3-sonnet-20240229-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-opus-20240229-v1:0": {
"us": true,
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"us": true,
"ap": true,
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"us": true,
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"us": true,
},
}
var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us",
"eu": "eu",
"ap": "apac",
}
var ChannelName = "aws" var ChannelName = "aws"

View File

@@ -1,14 +1,14 @@
package aws package aws
import ( import (
"one-api/relay/channel/claude" "one-api/dto"
) )
type AwsClaudeRequest struct { type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31" // AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"` AnthropicVersion string `json:"anthropic_version"`
System string `json:"system,omitempty"` System any `json:"system,omitempty"`
Messages []claude.ClaudeMessage `json:"messages"` Messages []dto.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
@@ -16,10 +16,10 @@ type AwsClaudeRequest struct {
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools any `json:"tools,omitempty"` Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"` Thinking *dto.Thinking `json:"thinking,omitempty"`
} }
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
return &AwsClaudeRequest{ return &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31", AnthropicVersion: "bedrock-2023-05-31",
System: req.System, System: req.System,

View File

@@ -1,21 +1,16 @@
package aws package aws
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
relaymodel "one-api/dto" "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings" "strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
@@ -39,15 +34,37 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
return client, nil return client, nil
} }
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode { func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
return &relaymodel.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Error: relaymodel.OpenAIError{ Error: dto.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()), Message: fmt.Sprintf("%s", err.Error()),
}, },
} }
} }
func awsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
regionPrefix = parts[0]
}
return regionPrefix
}
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
return exists && regionSet[awsRegionPrefix]
}
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
if !find {
return awsModelId
}
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) (string, error) { func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok { if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil return awsModelID, nil
@@ -56,7 +73,7 @@ func awsModelID(requestModel string) (string, error) {
return requestModel, nil return requestModel, nil
} }
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
awsCli, err := newAwsClient(c, info) awsCli, err := newAwsClient(c, info)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -67,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "awsModelID")), nil return wrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{ awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId), ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"), Accept: aws.String("application/json"),
@@ -77,7 +100,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return wrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*claude.ClaudeRequest) claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq) awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil { if err != nil {
@@ -89,25 +112,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "InvokeModel")), nil return wrapErr(errors.Wrap(err, "InvokeModel")), nil
} }
claudeResponse := new(claude.ClaudeResponse) claudeInfo := &claude.ClaudeResponseInfo{
err = json.Unmarshal(awsResp.Body, claudeResponse) ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
if err != nil { Created: common.GetTimestamp(),
return wrapErr(errors.Wrap(err, "unmarshal response")), nil Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
} }
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse) claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
usage := relaymodel.Usage{ return nil, claudeInfo.Usage
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
} }
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
awsCli, err := newAwsClient(c, info) awsCli, err := newAwsClient(c, info)
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -118,6 +135,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return wrapErr(errors.Wrap(err, "awsModelID")), nil return wrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId), ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"), Accept: aws.String("application/json"),
@@ -128,7 +151,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if !ok { if !ok {
return wrapErr(errors.New("request not found")), nil return wrapErr(errors.New("request not found")), nil
} }
claudeReq := claudeReq_.(*claude.ClaudeRequest) claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq) awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq)
@@ -143,79 +166,31 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
stream := awsResp.GetStream() stream := awsResp.GetStream()
defer stream.Close() defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream") claudeInfo := &claude.ClaudeResponseInfo{
var usage relaymodel.Usage ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
var id string Created: common.GetTimestamp(),
var model string Model: info.UpstreamModelName,
isFirst := true ResponseText: strings.Builder{},
createdTime := common.GetTimestamp() Usage: &dto.Usage{},
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
return false
} }
for event := range stream.Events() {
switch v := event.(type) { switch v := event.(type) {
case *types.ResponseStreamMemberChunk: case *types.ResponseStreamMemberChunk:
if isFirst { info.SetFirstResponseTime()
isFirst = false respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
info.FirstResponseTime = time.Now() if respErr != nil {
return respErr, nil
} }
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
if response == nil {
return true
}
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember: case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag) fmt.Println("unknown tag:", v.Tag)
return false return wrapErr(errors.New("unknown response type")), nil
default: default:
fmt.Println("union is nil or unknown type") fmt.Println("union is nil or unknown type")
return false return wrapErr(errors.New("nil or unknown response type")), nil
}
})
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
} }
} }
helper.Done(c)
if resp != nil { claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
err = resp.Body.Close() return nil, claudeInfo.Usage
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
} }

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -104,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -62,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -22,6 +22,10 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -60,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -1,94 +1,95 @@
package claude package claude
type ClaudeMetadata struct { //
UserId string `json:"user_id"` //type ClaudeMetadata struct {
} // UserId string `json:"user_id"`
//}
type ClaudeMediaMessage struct { //
Type string `json:"type"` //type ClaudeMediaMessage struct {
Text string `json:"text,omitempty"` // Type string `json:"type"`
Source *ClaudeMessageSource `json:"source,omitempty"` // Text string `json:"text,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` // Source *ClaudeMessageSource `json:"source,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` // Usage *ClaudeUsage `json:"usage,omitempty"`
PartialJson string `json:"partial_json,omitempty"` // StopReason *string `json:"stop_reason,omitempty"`
Thinking string `json:"thinking,omitempty"` // PartialJson string `json:"partial_json,omitempty"`
Signature string `json:"signature,omitempty"` // Thinking string `json:"thinking,omitempty"`
Delta string `json:"delta,omitempty"` // Signature string `json:"signature,omitempty"`
// tool_calls // Delta string `json:"delta,omitempty"`
Id string `json:"id,omitempty"` // // tool_calls
Name string `json:"name,omitempty"` // Id string `json:"id,omitempty"`
Input any `json:"input,omitempty"` // Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"` // Input any `json:"input,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"` // Content string `json:"content,omitempty"`
} // ToolUseId string `json:"tool_use_id,omitempty"`
//}
type ClaudeMessageSource struct { //
Type string `json:"type"` //type ClaudeMessageSource struct {
MediaType string `json:"media_type"` // Type string `json:"type"`
Data string `json:"data"` // MediaType string `json:"media_type"`
} // Data string `json:"data"`
//}
type ClaudeMessage struct { //
Role string `json:"role"` //type ClaudeMessage struct {
Content any `json:"content"` // Role string `json:"role"`
} // Content any `json:"content"`
//}
type Tool struct { //
Name string `json:"name"` //type Tool struct {
Description string `json:"description,omitempty"` // Name string `json:"name"`
InputSchema map[string]interface{} `json:"input_schema"` // Description string `json:"description,omitempty"`
} // InputSchema map[string]interface{} `json:"input_schema"`
//}
type InputSchema struct { //
Type string `json:"type"` //type InputSchema struct {
Properties any `json:"properties,omitempty"` // Type string `json:"type"`
Required any `json:"required,omitempty"` // Properties any `json:"properties,omitempty"`
} // Required any `json:"required,omitempty"`
//}
type ClaudeRequest struct { //
Model string `json:"model"` //type ClaudeRequest struct {
Prompt string `json:"prompt,omitempty"` // Model string `json:"model"`
System string `json:"system,omitempty"` // Prompt string `json:"prompt,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"` // System string `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` // Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` // MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` // MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` // StopSequences []string `json:"stop_sequences,omitempty"`
TopP float64 `json:"top_p,omitempty"` // Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` // TopP float64 `json:"top_p,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"` // TopK int `json:"top_k,omitempty"`
Stream bool `json:"stream,omitempty"` // //ClaudeMetadata `json:"metadata,omitempty"`
Tools any `json:"tools,omitempty"` // Stream bool `json:"stream,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` // Tools any `json:"tools,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"` // ToolChoice any `json:"tool_choice,omitempty"`
} // Thinking *Thinking `json:"thinking,omitempty"`
//}
type Thinking struct { //
Type string `json:"type"` //type Thinking struct {
BudgetTokens int `json:"budget_tokens"` // Type string `json:"type"`
} // BudgetTokens int `json:"budget_tokens"`
//}
type ClaudeError struct { //
Type string `json:"type"` //type ClaudeError struct {
Message string `json:"message"` // Type string `json:"type"`
} // Message string `json:"message"`
//}
type ClaudeResponse struct { //
Id string `json:"id"` //type ClaudeResponse struct {
Type string `json:"type"` // Id string `json:"id"`
Content []ClaudeMediaMessage `json:"content"` // Type string `json:"type"`
Completion string `json:"completion"` // Content []ClaudeMediaMessage `json:"content"`
StopReason string `json:"stop_reason"` // Completion string `json:"completion"`
Model string `json:"model"` // StopReason string `json:"stop_reason"`
Error ClaudeError `json:"error"` // Model string `json:"model"`
Usage ClaudeUsage `json:"usage"` // Error ClaudeError `json:"error"`
Index int `json:"index"` // stream only // Usage ClaudeUsage `json:"usage"`
ContentBlock *ClaudeMediaMessage `json:"content_block"` // Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only // ContentBlock *ClaudeMediaMessage `json:"content_block"`
Message *ClaudeResponse `json:"message"` // stream only: message_start // Delta *ClaudeMediaMessage `json:"delta"` // stream only
} // Message *ClaudeResponse `json:"message"` // stream only: message_start
//}
type ClaudeUsage struct { //
InputTokens int `json:"input_tokens"` //type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` // InputTokens int `json:"input_tokens"`
} // OutputTokens int `json:"output_tokens"`
//}

View File

@@ -24,14 +24,16 @@ func stopReasonClaude2OpenAI(reason string) string {
return "stop" return "stop"
case "max_tokens": case "max_tokens":
return "max_tokens" return "max_tokens"
case "tool_use":
return "tool_calls"
default: default:
return reason return reason
} }
} }
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := ClaudeRequest{ claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", Prompt: "",
StopSequences: nil, StopSequences: nil,
@@ -60,17 +62,19 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
return &claudeRequest return &claudeRequest
} }
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools)) claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok { if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTool := Tool{ claudeTool := dto.Tool{
Name: tool.Function.Name, Name: tool.Function.Name,
Description: tool.Function.Description, Description: tool.Function.Description,
} }
claudeTool.InputSchema = make(map[string]interface{}) claudeTool.InputSchema = make(map[string]interface{})
if params["type"] != nil {
claudeTool.InputSchema["type"] = params["type"].(string) claudeTool.InputSchema["type"] = params["type"].(string)
}
claudeTool.InputSchema["properties"] = params["properties"] claudeTool.InputSchema["properties"] = params["properties"]
claudeTool.InputSchema["required"] = params["required"] claudeTool.InputSchema["required"] = params["required"]
for s, a := range params { for s, a := range params {
@@ -83,7 +87,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
} }
claudeRequest := ClaudeRequest{ claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
StopSequences: nil, StopSequences: nil,
@@ -107,7 +111,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
// BudgetTokens 为 max_tokens 的 80% // BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{ claudeRequest.Thinking = &dto.Thinking{
Type: "enabled", Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
} }
@@ -165,7 +169,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
lastMessage = fmtMessage lastMessage = fmtMessage
} }
claudeMessages := make([]ClaudeMessage, 0) claudeMessages := make([]dto.ClaudeMessage, 0)
isFirstMessage := true isFirstMessage := true
for _, message := range formatMessages { for _, message := range formatMessages {
if message.Role == "system" { if message.Role == "system" {
@@ -186,63 +190,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
isFirstMessage = false isFirstMessage = false
if message.Role != "user" { if message.Role != "user" {
// fix: first message is assistant, add user message // fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{ claudeMessage := dto.ClaudeMessage{
Role: "user", Role: "user",
Content: []ClaudeMediaMessage{ Content: []dto.ClaudeMediaMessage{
{ {
Type: "text", Type: "text",
Text: "...", Text: common.GetPointer[string]("..."),
}, },
}, },
} }
claudeMessages = append(claudeMessages, claudeMessage) claudeMessages = append(claudeMessages, claudeMessage)
} }
} }
claudeMessage := ClaudeMessage{ claudeMessage := dto.ClaudeMessage{
Role: message.Role, Role: message.Role,
} }
if message.Role == "tool" { if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" { if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1] lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok { if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{ lastMessage.Content = []dto.ClaudeMediaMessage{
{ {
Type: "text", Type: "text",
Text: content, Text: common.GetPointer[string](content),
}, },
} }
} }
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{ lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
Type: "tool_result", Type: "tool_result",
ToolUseId: message.ToolCallId, ToolUseId: message.ToolCallId,
Content: message.StringContent(), Content: message.Content,
}) })
claudeMessages[len(claudeMessages)-1] = lastMessage claudeMessages[len(claudeMessages)-1] = lastMessage
continue continue
} else { } else {
claudeMessage.Role = "user" claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{ claudeMessage.Content = []dto.ClaudeMediaMessage{
{ {
Type: "tool_result", Type: "tool_result",
ToolUseId: message.ToolCallId, ToolUseId: message.ToolCallId,
Content: message.StringContent(), Content: message.Content,
}, },
} }
} }
} else if message.IsStringContent() && message.ToolCalls == nil { } else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent() claudeMessage.Content = message.StringContent()
} else { } else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0) claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() { for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{ claudeMediaMessage := dto.ClaudeMediaMessage{
Type: mediaMessage.Type, Type: mediaMessage.Type,
} }
if mediaMessage.Type == "text" { if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
} else { } else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) imageUrl := mediaMessage.GetImageMedia()
claudeMediaMessage.Type = "image" claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{ claudeMediaMessage.Source = &dto.ClaudeMessageSource{
Type: "base64", Type: "base64",
} }
// 判断是否是url // 判断是否是url
@@ -272,7 +276,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue continue
} }
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{ claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Type: "tool_use", Type: "tool_use",
Id: toolCall.ID, Id: toolCall.ID,
Name: toolCall.Function.Name, Name: toolCall.Function.Name,
@@ -290,13 +294,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
return &claudeRequest, nil return &claudeRequest, nil
} }
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCallResponse, 0) tools := make([]dto.ToolCallResponse, 0)
fcIdx := 0
if claudeResponse.Index != nil {
fcIdx = *claudeResponse.Index - 1
if fcIdx < 0 {
fcIdx = 0
}
}
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@@ -308,7 +318,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Type == "message_start" { if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage //claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("") choice.Delta.SetContentString("")
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
@@ -316,6 +326,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" { if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{ tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id, ID: claudeResponse.ContentBlock.Id,
Type: "function", Type: "function",
Function: dto.FunctionResponse{ Function: dto.FunctionResponse{
@@ -325,17 +336,18 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
}) })
} }
} else { } else {
return nil, nil return nil
} }
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index choice.Delta.Content = claudeResponse.Delta.Text
choice.Delta.SetContentString(claudeResponse.Delta.Text)
switch claudeResponse.Delta.Type { switch claudeResponse.Delta.Type {
case "input_json_delta": case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{ tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{ Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson, Arguments: *claudeResponse.Delta.PartialJson,
}, },
}) })
case "signature_delta": case "signature_delta":
@@ -352,26 +364,23 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if finishReason != "null" { if finishReason != "null" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
claudeUsage = &claudeResponse.Usage //claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" { } else if claudeResponse.Type == "message_stop" {
return nil, nil return nil
} else { } else {
return nil, nil return nil
} }
} }
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
if len(tools) > 0 { if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools choice.Delta.ToolCalls = tools
} }
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
return &response, claudeUsage return &response
} }
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0) choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -379,8 +388,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
} }
var responseText string var responseText string
var responseThinking string
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].GetText()
responseThinking = claudeResponse.Content[0].Thinking
} }
tools := make([]dto.ToolCallResponse, 0) tools := make([]dto.ToolCallResponse, 0)
thinkingContent := "" thinkingContent := ""
@@ -415,7 +426,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
// 加密的不管, 只输出明文的推理过程 // 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking thinkingContent = message.Thinking
case "text": case "text":
responseText = message.Text responseText = message.GetText()
} }
} }
} }
@@ -427,6 +438,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
choice.SetStringContent(responseText) choice.SetStringContent(responseText)
if len(responseThinking) > 0 {
choice.ReasoningContent = responseThinking
}
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.SetToolCalls(tools) choice.Message.SetToolCalls(tools)
} }
@@ -437,126 +451,228 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse return &fullTextResponse
} }
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { type ClaudeResponseInfo struct {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) ResponseId string
var usage *dto.Usage Created int64
usage = &dto.Usage{} Model string
responseText := "" ResponseText strings.Builder
createdTime := common.GetTimestamp() Usage *dto.Usage
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
} }
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if response == nil {
return true
}
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
responseId = response.Id
} else { } else {
if claudeResponse.Type == "message_start" { if claudeResponse.Type == "message_start" {
// message_start, 获取usage // message_start, 获取usage
responseId = claudeResponse.Message.Id claudeInfo.ResponseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model claudeInfo.Model = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
} else if claudeResponse.Type == "message_delta" { } else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens if claudeResponse.Usage.InputTokens > 0 {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
return true
} else { } else {
return false
}
}
if oaiResponse != nil {
oaiResponse.Id = claudeInfo.ResponseId
oaiResponse.Created = claudeInfo.Created
oaiResponse.Model = claudeInfo.Model
}
return true return true
} }
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Code: "stream_response_error",
Type: claudeResponse.Error.Type,
Message: claudeResponse.Error.Message,
},
StatusCode: http.StatusInternalServerError,
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return nil
} }
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error()) common.LogError(c, "send_stream_response_failed: "+err.Error())
} }
return true
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = info.PromptTokens
} }
if usage.CompletionTokens == 0 { return nil
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) }
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 说明流模式建立失败,可能为官方出错
if claudeInfo.Usage.PromptTokens == 0 {
//usage.PromptTokens = info.PromptTokens
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
} }
} }
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
helper.Done(c) helper.Done(c)
//resp.Body.Close() }
return nil, usage
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) claudeInfo := &ClaudeResponseInfo{
if err != nil { ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
} }
err = resp.Body.Close() var err *dto.OpenAIErrorWithStatusCode
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return false
} }
var claudeResponse ClaudeResponse return true
err = json.Unmarshal(responseBody, &claudeResponse) })
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return err, nil
} }
if claudeResponse.Error.Type != "" {
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
return nil, claudeInfo.Usage
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJson(data, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{ Error: dto.OpenAIError{
Message: claudeResponse.Error.Message, Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type, Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type, Code: claudeResponse.Error.Type,
}, },
StatusCode: resp.StatusCode, StatusCode: http.StatusInternalServerError,
}, nil
} }
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) }
if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
} }
usage := dto.Usage{} claudeInfo.Usage.PromptTokens = info.PromptTokens
if requestMode == RequestModeCompletion { claudeInfo.Usage.CompletionTokens = completionTokens
usage.PromptTokens = info.PromptTokens claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens
} else { } else {
usage.PromptTokens = claudeResponse.Usage.InputTokens claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
} }
fullTextResponse.Usage = usage var responseData []byte
jsonResponse, err := json.Marshal(fullTextResponse) switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
case relaycommon.RelayFormatClaude:
responseData = data
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(responseData)
return nil, &usage return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
if handleErr != nil {
return handleErr, nil
}
return nil, claudeInfo.Usage
} }

View File

@@ -17,6 +17,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil return requestOpenAI2Cohere(*request), nil
} }
@@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info) err, usage = cohereRerankHandler(c, resp, info)

View File

@@ -1,6 +1,7 @@
package cohere package cohere
var ModelList = []string{ var ModelList = []string{
"command-a-03-2025",
"command-r", "command-r-plus", "command-r", "command-r-plus",
"command-r-08-2024", "command-r-plus-08-2024", "command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b", "c4ai-aya-23-35b", "c4ai-aya-23-8b",

View File

@@ -40,7 +40,7 @@ type CohereRerankRequest struct {
} }
type CohereRerankResponseResult struct { type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"` Results []dto.RerankResponseResult `json:"results"`
Meta CohereMeta `json:"meta"` Meta CohereMeta `json:"meta"`
} }

View File

@@ -11,11 +11,18 @@ import (
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"strings"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -30,9 +37,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.BaseUrl
if !strings.HasSuffix(info.BaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeCompletions: case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default: default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} }
@@ -44,7 +55,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -68,7 +79,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -9,7 +9,6 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings"
) )
const ( const (
@@ -23,6 +22,12 @@ type Adaptor struct {
BotType int BotType int
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -34,15 +39,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") { //if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent // a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") { //} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow // a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") { //} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion // a.BotType = BotTypeCompletion
} else { //} else {
//}
a.BotType = BotTypeChatFlow a.BotType = BotTypeChatFlow
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -64,11 +70,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return requestOpenAI2Dify(*request), nil return requestOpenAI2Dify(c, info, *request), nil
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -8,6 +8,14 @@ type DifyChatRequest struct {
ResponseMode string `json:"response_mode"` ResponseMode string `json:"response_mode"`
User string `json:"user"` User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"` AutoGenerateName bool `json:"auto_generate_name"`
Files []DifyFile `json:"files"`
}
type DifyFile struct {
Type string `json:"type"`
TransferMode string `json:"transfer_mode"`
URL string `json:"url,omitempty"`
UploadFileId string `json:"upload_file_id,omitempty"`
} }
type DifyMetaData struct { type DifyMetaData struct {
@@ -17,6 +25,8 @@ type DifyMetaData struct {
type DifyData struct { type DifyData struct {
WorkflowId string `json:"workflow_id"` WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"` NodeId string `json:"node_id"`
NodeType string `json:"node_type"`
Status string `json:"status"`
} }
type DifyChatCompletionResponse struct { type DifyChatCompletionResponse struct {

View File

@@ -1,10 +1,12 @@
package dify package dify
import ( import (
"bufio" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@@ -12,35 +14,163 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"os"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest { func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
content := "" uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
imageMedia := media.GetImageMedia()
base64Data := imageMedia.Url
// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
if idx := strings.Index(base64Data, ","); idx != -1 {
base64Data = base64Data[idx+1:]
}
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
common.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
common.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error())
return nil
}
// Create multipart form
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add user field
if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error())
return nil
}
// Create form file with proper mime type
mimeType := imageMedia.MimeType
if mimeType == "" {
mimeType = "image/jpeg" // default mime type
}
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
common.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
common.SysError("failed to create request: " + err.Error())
return nil
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
client := service.GetImpatientHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
// Parse response
var result struct {
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error())
return nil
}
return &DifyFile{
UploadFileId: result.Id,
Type: "image",
TransferMode: "local_file",
}
}
return nil
}
func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
difyReq := DifyChatRequest{
Inputs: make(map[string]interface{}),
AutoGenerateName: false,
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
difyReq.User = user
files := make([]DifyFile, 0)
var content strings.Builder
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n" content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
} else if message.Role == "assistant" { } else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n" content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
} else { } else {
content += "USER: \n" + message.StringContent() + "\n" parseContent := message.ParseContent()
for _, mediaContent := range parseContent {
switch mediaContent.Type {
case dto.ContentTypeText:
content.WriteString("USER: \n" + mediaContent.Text + "\n")
case dto.ContentTypeImageURL:
media := mediaContent.GetImageMedia()
var file *DifyFile
if media.IsRemoteImage() {
file.Type = media.MimeType
file.TransferMode = "remote_url"
file.URL = media.Url
} else {
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
if file != nil {
files = append(files, *file)
} }
} }
}
}
}
difyReq.Query = content.String()
difyReq.Files = files
mode := "blocking" mode := "blocking"
if request.Stream { if request.Stream {
mode = "streaming" mode = "streaming"
} }
user := request.User difyReq.ResponseMode = mode
if user == "" { return &difyReq
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
} }
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse { func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
@@ -50,11 +180,29 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
Model: "dify", Model: "dify",
} }
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" { if strings.HasPrefix(difyResponse.Event, "workflow_") {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") if constant.DifyDebug {
} else if constant.DifyDebug && difyResponse.Event == "node_started" { text := "Workflow: " + difyResponse.Data.WorkflowId
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") if difyResponse.Event == "workflow_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if strings.HasPrefix(difyResponse.Event, "node_") {
if constant.DifyDebug {
text := "Node: " + difyResponse.Data.NodeType
if difyResponse.Event == "node_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" { } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
if difyResponse.Answer == "<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>\n" {
difyResponse.Answer = "<think>"
} else if difyResponse.Answer == "</details>" {
difyResponse.Answer = "</think>"
}
choice.Delta.SetContentString(difyResponse.Answer) choice.Delta.SetContentString(difyResponse.Answer)
} }
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
@@ -64,43 +212,36 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string var responseText string
usage := &dto.Usage{} usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body) var nodeToken int
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse) err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
continue return true
} }
var openaiResponse dto.ChatCompletionsStreamResponse var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" { if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage usage = &difyResponse.MetaData.Usage
break return false
} else if difyResponse.Event == "error" { } else if difyResponse.Event == "error" {
break return false
} else { } else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse) openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 { if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString() responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
} }
} }
err = helper.ObjectData(c, openaiResponse) err = helper.ObjectData(c, openaiResponse)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysError(err.Error())
} }
} return true
if err := scanner.Err(); err != nil { })
common.SysError("error reading stream: " + err.Error())
}
helper.Done(c) helper.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
@@ -112,6 +253,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
usage.CompletionTokens += nodeToken
return nil, usage return nil, usage
} }

View File

@@ -12,7 +12,6 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting/model_setting" "one-api/setting/model_setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -21,6 +20,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -64,12 +69,28 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") { if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
} }
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent" action := "generateContent"
if info.IsStream { if info.IsStream {
action = "streamGenerateContent?alt=sse" action = "streamGenerateContent?alt=sse"
@@ -83,15 +104,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
ai, err := CovertGemini2OpenAI(*request)
geminiRequest, err := CovertGemini2OpenAI(*request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ai, nil
return geminiRequest, nil
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -99,8 +122,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me if request.Input == nil {
return nil, errors.New("not implemented") return nil, errors.New("input is required")
}
inputs := request.ParseInput()
if len(inputs) == 0 {
return nil, errors.New("input is empty")
}
// only process the first input
geminiRequest := GeminiEmbeddingRequest{
Content: GeminiChatContent{
Parts: []GeminiPart{
{
Text: inputs[0],
},
},
},
}
// set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName {
case "text-embedding-004":
// except embedding-001 supports setting `OutputDimensionality`
if request.Dimensions > 0 {
geminiRequest.OutputDimensionality = request.Dimensions
}
}
return geminiRequest, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -112,11 +164,30 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return GeminiImageHandler(c, resp, info) return GeminiImageHandler(c, resp, info)
} }
// check if the model is an embedding model
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return GeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream { if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info) err, usage = GeminiChatStreamHandler(c, resp, info)
} else { } else {
err, usage = GeminiChatHandler(c, resp, info) err, usage = GeminiChatHandler(c, resp, info)
} }
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
// // 没有请求-thinking的情况下产生思考token则按照思考模型计费
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
// thinkingModelName := info.OriginModelName + "-thinking"
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
// info.OriginModelName = thinkingModelName
// }
// }
//}
return return
} }

View File

@@ -16,8 +16,14 @@ var ModelList = []string{
"gemini-2.0-pro-exp", "gemini-2.0-pro-exp",
// thinking exp // thinking exp
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models // imagen models
"imagen-3.0-generate-002", "imagen-3.0-generate-002",
// embedding models
"gemini-embedding-exp-03-07",
"text-embedding-004",
"embedding-001",
} }
var SafetySettingList = []string{ var SafetySettingList = []string{

View File

@@ -8,6 +8,15 @@ type GeminiChatRequest struct {
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"` SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
} }
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
}
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
c.ThinkingBudget = &budget
}
type GeminiInlineData struct { type GeminiInlineData struct {
MimeType string `json:"mimeType"` MimeType string `json:"mimeType"`
Data string `json:"data"` Data string `json:"data"`
@@ -80,6 +89,8 @@ type GeminiChatGenerationConfig struct {
ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"` Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
} }
type GeminiChatCandidate struct { type GeminiChatCandidate struct {
@@ -108,6 +119,7 @@ type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"` PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"` CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"` TotalTokenCount int `json:"totalTokenCount"`
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
} }
// Imagen related structs // Imagen related structs
@@ -136,3 +148,19 @@ type GeminiImagePrediction struct {
RaiFilteredReason string `json:"raiFilteredReason,omitempty"` RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"` SafetyAttributes any `json:"safetyAttributes,omitempty"`
} }
// Embedding related structs
type GeminiEmbeddingRequest struct {
Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"`
}
type ContentEmbedding struct {
Values []float64 `json:"values"`
}

View File

@@ -19,11 +19,10 @@ import (
) )
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) { func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{ geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{ GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
@@ -32,6 +31,30 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
geminiRequest.GenerationConfig.ResponseModalities = []string{
"TEXT",
"IMAGE",
}
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
if strings.HasSuffix(info.OriginModelName, "-thinking") {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
if budgetTokens == 0 || budgetTokens > 24576 {
budgetTokens = 24576
}
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(int(budgetTokens)),
IncludeThoughts: true,
}
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
}
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList { for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{ safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -56,6 +79,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
continue continue
} }
if tool.Function.Parameters != nil { if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{}) params, ok := tool.Function.Parameters.(map[string]interface{})
if ok { if ok {
if props, hasProps := params["properties"].(map[string]interface{}); hasProps { if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
@@ -65,6 +89,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
} }
} }
} }
// Clean the parameters before appending
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function) functions = append(functions, tool.Function)
} }
if codeExecution { if codeExecution {
@@ -86,11 +113,11 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// json_data, _ := json.Marshal(geminiRequest.Tools) // json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data)) // common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil { } else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTool{ //geminiRequest.Tools = []GeminiChatTool{
{ // {
FunctionDeclarations: textRequest.Functions, // FunctionDeclarations: textRequest.Functions,
}, // },
} //}
} }
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -180,9 +207,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
} }
// 判断是否是url // 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
if err != nil { if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
} }
@@ -193,7 +220,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
}) })
} else { } else {
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url) format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
if err != nil { if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
} }
@@ -204,6 +231,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
}) })
} }
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} }
} }
@@ -229,6 +284,102 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return &geminiRequest, nil return &geminiRequest, nil
} }
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
paramMap, ok := params.(map[string]interface{})
if !ok {
// Not a map, return as is (e.g., could be an array or primitive)
return params
}
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, v := range paramMap {
cleanedMap[k] = v
}
// Remove unsupported root-level fields
delete(cleanedMap, "default")
delete(cleanedMap, "exclusiveMaximum")
delete(cleanedMap, "exclusiveMinimum")
delete(cleanedMap, "$schema")
delete(cleanedMap, "additionalProperties")
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
propMap, ok := propValue.(map[string]interface{})
if !ok {
cleanedProps[propName] = propValue // Keep non-map properties
continue
}
// Create a copy of the property map
cleanedPropMap := make(map[string]interface{})
for k, v := range propMap {
cleanedPropMap[k] = v
}
// Remove unsupported fields
delete(cleanedPropMap, "default")
delete(cleanedPropMap, "exclusiveMaximum")
delete(cleanedPropMap, "exclusiveMinimum")
delete(cleanedPropMap, "$schema")
delete(cleanedPropMap, "additionalProperties")
// Check and clean 'format' for string types
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedPropMap, "format")
}
}
}
// Recursively clean nested properties within this property if it's an object/array
// Check the type before recursing
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
} else {
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
}
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
}
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords if necessary
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
}
}
return cleanedMap
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} { func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 { if depth >= 5 {
return schema return schema
@@ -427,9 +578,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
return &fullTextResponse return &fullTextResponse
} }
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates { for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true isStop = true
@@ -455,7 +607,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
} }
} }
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true isTools = true
if call := getResponseToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls)) call.SetIndex(len(choice.Delta.ToolCalls))
@@ -483,7 +641,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Choices = choices response.Choices = choices
return &response, isStop return &response, isStop, hasImage
} }
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -491,23 +649,27 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
var usage = &dto.Usage{} var usage = &dto.Usage{}
var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse) err := common.DecodeJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
} }
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse) response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
if hasImage {
imageCount++
}
response.Id = id response.Id = id
response.Created = createAt response.Created = createAt
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 { if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
} }
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
@@ -522,9 +684,15 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var response *dto.ChatCompletionsStreamResponse var response *dto.ChatCompletionsStreamResponse
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens //usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
@@ -570,6 +738,9 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
} }
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
@@ -580,3 +751,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &usage return nil, &usage
} }
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiEmbeddingResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: []dto.OpenAIEmbeddingResponseItem{
{
Object: "embedding",
Embedding: geminiResponse.Embedding.Values,
Index: 0,
},
},
Model: info.UpstreamModelName,
}
// calculate usage
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
// Google has not yet clarified how embedding models will be billed
// refer to openai billing method to use input tokens billing
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
openAIResponse.Usage = *usage.(*dto.Usage)
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return usage, nil
}

View File

@@ -8,13 +8,21 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant" "one-api/relay/constant"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -43,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }
@@ -61,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = JinaRerankHandler(c, resp) err, usage = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -1,60 +1 @@
package jina package jina
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.OpenAIEmbeddingResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -14,6 +14,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -61,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -10,7 +10,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
mediaMessages := message.ParseContent() mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages { for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL { if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) imageUrl := mediaMessage.GetImageMedia()
mediaMessage.ImageUrl = imageUrl.Url mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage mediaMessages[j] = mediaMessage
} }

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -51,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -43,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -69,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == relayconstant.RelayModeEmbeddings { if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
} }
return return

View File

@@ -19,7 +19,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
mediaMessages := message.ParseContent() mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages { for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL { if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) imageUrl := mediaMessage.GetImageMedia()
// check if not base64 // check if not base64
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@@ -14,13 +13,18 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/ai360" "one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax" "one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type Adaptor struct { type Adaptor struct {
@@ -28,11 +32,39 @@ type Adaptor struct {
ResponseFormat string ResponseFormat string
} }
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if !strings.Contains(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
HasSentThinkingContent: false,
}
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime { if info.RelayMode == constant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") { if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -101,28 +133,26 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
} else { } else {
header.Set("Authorization", "Bearer "+info.ApiKey) header.Set("Authorization", "Bearer "+info.ApiKey)
} }
//if info.ChannelType == common.ChannelTypeOpenRouter { if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
// req.Header.Set("X-Title", "One API") header.Set("X-Title", "New API")
//} }
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure { if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil request.StreamOptions = nil
} }
if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") { if strings.HasPrefix(request.Model, "o") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 { if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0 request.MaxTokens = 0
} }
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") { if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high" request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high") request.Model = strings.TrimSuffix(request.Model, "-high")
@@ -135,13 +165,15 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
} }
info.ReasoningEffort = request.ReasoningEffort info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model info.UpstreamModelName = request.Model
}
if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") { // o系列模型developer适配o1-mini除外
if !strings.HasPrefix(request.Model, "o1-mini") {
//修改第一个Message的内容将system改为developer //修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" { if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer" request.Messages[0].Role = "developer"
} }
} }
}
return request, nil return request, nil
} }
@@ -230,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info) err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank: case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp) err, usage = common_handler.RerankHandler(c, info, resp)
default: default:
if info.IsStream { if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info) err, usage = OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = OpenaiHandler(c, resp, info)
} }
} }
return return
@@ -251,6 +283,10 @@ func (a *Adaptor) GetModelList() []string {
return lingyiwanwu.ModelList return lingyiwanwu.ModelList
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
return minimax.ModelList return minimax.ModelList
case common.ChannelTypeXinference:
return xinference.ModelList
case common.ChannelTypeOpenRouter:
return openrouter.ModelList
default: default:
return ModelList return ModelList
} }
@@ -266,6 +302,10 @@ func (a *Adaptor) GetChannelName() string {
return lingyiwanwu.ChannelName return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
return minimax.ChannelName return minimax.ChannelName
case common.ChannelTypeXinference:
return xinference.ChannelName
case common.ChannelTypeOpenRouter:
return openrouter.ChannelName
default: default:
return ChannelName return ChannelName
} }

View File

@@ -0,0 +1,189 @@
package openai
import (
"encoding/json"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
return nil
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
}
}

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"os" "os"
@@ -34,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
} }
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
return err return err
} }
@@ -66,6 +65,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
response.Choices[i].Delta.Reasoning = nil response.Choices[i].Delta.Reasoning = nil
} }
info.ThinkingContentInfo.IsFirstThinkingContent = false info.ThinkingContentInfo.IsFirstThinkingContent = false
info.ThinkingContentInfo.HasSentThinkingContent = true
return helper.ObjectData(c, response) return helper.ObjectData(c, response)
} }
} }
@@ -77,7 +77,8 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
// Process each choice // Process each choice
for i, choice := range lastStreamResponse.Choices { for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content // Handle transition from thinking to content
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent { // only send `</think>` tag when previous thinking content has been sent
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
response := lastStreamResponse.Copy() response := lastStreamResponse.Copy()
for j := range response.Choices { for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>\n") response.Choices[j].Delta.SetContentString("\n</think>\n")
@@ -88,7 +89,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
helper.ObjectData(c, response) helper.ObjectData(c, response)
} }
// Convert reasoning content to regular content // Convert reasoning content to regular content if any
if len(choice.Delta.GetReasoningContent()) > 0 { if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent()) lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
@@ -116,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
model := info.UpstreamModelName model := info.UpstreamModelName
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{} var usage = &dto.Usage{}
var streamItems []string // store stream items var streamItems []string // store stream items
var forceFormat bool var forceFormat bool
@@ -129,17 +131,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
thinkToContent = think2Content thinkToContent = think2Content
} }
toolCount := 0
var ( var (
lastStreamData string lastStreamData string
) )
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" { if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil { if err != nil {
common.LogError(c, "streaming error: "+err.Error()) common.SysError("error handling stream format: " + err.Error())
} }
} }
lastStreamData = data lastStreamData = data
@@ -149,7 +149,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
shouldSendLastResp := true shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
if err == nil { if err == nil {
responseId = lastStreamResponse.Id responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created createAt = lastStreamResponse.Created
@@ -168,87 +168,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
} }
if shouldSendLastResp { if shouldSendLastResp {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
} }
// 计算token // 处理token计算
streamResp := "[" + strings.Join(streamItems, ",") + "]" if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
switch info.RelayMode { common.SysError("error processing tokens: " + err.Error())
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
// handle both reasoning_content and reasoning
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} }
if !containStreamUsage { if !containStreamUsage {
@@ -262,20 +190,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
if info.ShouldIncludeUsage && !containStreamUsage { handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
//resp.Body.Close()
return nil, usage return nil, usage
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var simpleResponse dto.SimpleResponse var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -284,16 +205,29 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &simpleResponse) err = common.DecodeJson(responseBody, &simpleResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if simpleResponse.Error.Type != "" { if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: simpleResponse.Error, Error: *simpleResponse.Error,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil
} }
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
break
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := json.Marshal(claudeResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
responseBody = claudeRespStr
}
// Reset response body // Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -306,19 +240,20 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil //return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
common.SysError("error copying response body: " + err.Error())
} }
resp.Body.Close() resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model) ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{
PromptTokens: promptTokens, PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens, TotalTokens: info.PromptTokens + completionTokens,
} }
} }
return nil, &simpleResponse.Usage return nil, &simpleResponse.Usage

View File

@@ -1,74 +0,0 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@@ -66,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -48,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }
@@ -68,20 +74,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRerank: case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp) err, usage = siliconflowRerankHandler(c, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions: case constant.RelayModeChatCompletions:
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -12,6 +12,6 @@ type SFMeta struct {
} }
type SFRerankResponse struct { type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"` Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"` Meta SFMeta `json:"meta"`
} }

View File

@@ -23,6 +23,12 @@ type Adaptor struct {
Timestamp int64 Timestamp int64
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -52,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -38,6 +38,16 @@ type Adaptor struct {
AccountCredentials Credentials AccountCredentials Credentials
} }
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
} else {
c.Set("request_model", request.Model)
}
vertexClaudeReq := copyRequest(request, anthropicVersion)
return vertexClaudeReq, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -119,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -133,7 +143,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
info.UpstreamModelName = claudeReq.Model info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini { } else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request) geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -175,7 +185,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case RequestModeGemini: case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp, info) err, usage = gemini.GeminiChatHandler(c, resp, info)
case RequestModeLlama: case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
} }
return return

View File

@@ -1,12 +1,12 @@
package vertex package vertex
import ( import (
"one-api/relay/channel/claude" "one-api/dto"
) )
type VertexAIClaudeRequest struct { type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"` AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"` Messages []dto.ClaudeMessage `json:"messages"`
System any `json:"system,omitempty"` System any `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
@@ -16,10 +16,10 @@ type VertexAIClaudeRequest struct {
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Tools any `json:"tools,omitempty"` Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"` Thinking *dto.Thinking `json:"thinking,omitempty"`
} }
func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest { func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
return &VertexAIClaudeRequest{ return &VertexAIClaudeRequest{
AnthropicVersion: version, AnthropicVersion: version,
System: req.System, System: req.System,

View File

@@ -17,6 +17,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -50,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -75,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -0,0 +1,104 @@
package xai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
//panic("implement me")
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
request.Size = ""
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if strings.HasPrefix(request.Model, "grok-3-mini") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = xAIStreamHandler(c, resp, info)
} else {
err, usage = xAIHandler(c, resp, info)
}
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
//}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,18 @@
package xai
var ModelList = []string{
// grok-3
"grok-3-beta", "grok-3-mini-beta",
// grok-3 mini
"grok-3-fast-beta", "grok-3-mini-fast-beta",
// extend grok-3-mini reasoning
"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
// image model
"grok-2-image",
// legacy models
"grok-2", "grok-2-vision",
"grok-beta", "grok-vision-beta",
}
var ChannelName = "xai"

14
relay/channel/xai/dto.go Normal file
View File

@@ -0,0 +1,14 @@
package xai
import "one-api/dto"
// ChatCompletionResponse represents the response from XAI chat completion API
type ChatCompletionResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []dto.ChatCompletionsStreamResponseChoice
Usage *dto.Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
}

119
relay/channel/xai/text.go Normal file
View File

@@ -0,0 +1,119 @@
package xai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
)
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
if xAIResp == nil {
return nil
}
if xAIResp.Usage != nil {
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
}
openAIResp := &dto.ChatCompletionsStreamResponse{
Id: xAIResp.Id,
Object: xAIResp.Object,
Created: xAIResp.Created,
Model: xAIResp.Model,
Choices: xAIResp.Choices,
Usage: xAIResp.Usage,
}
return openAIResp
}
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
usage := &dto.Usage{}
var responseTextBuilder strings.Builder
var toolCount int
var containStreamUsage bool
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
// 把 xAI 的usage转换为 OpenAI 的usage
if xAIResp.Usage != nil {
containStreamUsage = true
usage.PromptTokens = xAIResp.Usage.PromptTokens
usage.TotalTokens = xAIResp.Usage.TotalTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
return true
})
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
return nil, usage
}
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
var response *dto.TextResponse
err = common.DecodeJson(responseBody, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil
}
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
// new body
encodeJson, err := common.EncodeJson(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return nil, nil
}
// set new body
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &response.Usage
}

View File

@@ -0,0 +1,8 @@
package xinference
var ModelList = []string{
"bge-reranker-v2-m3",
"jina-reranker-v2",
}
var ChannelName = "xinference"

View File

@@ -0,0 +1,11 @@
package xinference
type XinRerankResponseDocument struct {
Document string `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type XinRerankResponse struct {
Results []XinRerankResponseDocument `json:"results"`
}

View File

@@ -16,6 +16,12 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest request *dto.GeneralOpenAIRequest
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here // xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{} dummyResp := &http.Response{}

View File

@@ -14,6 +14,12 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -10,11 +10,18 @@ import (
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
@@ -29,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return fmt.Sprintf("%s/embeddings", baseUrl), nil
default:
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -39,7 +52,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
@@ -54,11 +67,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return request, nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@@ -67,7 +78,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info)
} }
return return
} }

View File

@@ -1,17 +1,9 @@
package zhipu_4v package zhipu_4v
import ( import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -79,7 +71,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages := message.ParseContent() mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages { for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL { if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) imageUrl := mediaMessage.GetImageMedia()
// check if base64 // check if base64
if strings.HasPrefix(imageUrl.Url, "data:image/") { if strings.HasPrefix(imageUrl.Url, "data:image/") {
// 去除base64数据的URL前缀如果有 // 去除base64数据的URL前缀如果有
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
ToolChoice: request.ToolChoice, ToolChoice: request.ToolChoice,
} }
} }
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
// fullTextResponse := dto.OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
// Usage: response.Usage,
// }
// for i, choice := range response.TextResponseChoices {
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
// openaiChoice := dto.OpenAITextResponseChoice{
// Index: i,
// Message: dto.Message{
// Role: choice.Role,
// Content: content,
// },
// FinishReason: "",
// }
// if i == len(response.TextResponseChoices)-1 {
// openaiChoice.FinishReason = "stop"
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// }
// return &fullTextResponse
//}
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
choice.Index = zhipuResponse.Choices[0].Index
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.Id,
Object: "chat.completion.chunk",
Created: zhipuResponse.Created,
Model: "glm-4v",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
response := streamResponseZhipu2OpenAI(zhipuResponse)
return response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var streamResponse ZhipuV4StreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
}
var response *dto.ChatCompletionsStreamResponse
if strings.Contains(data, "prompt_tokens") {
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
} else {
response = streamResponseZhipu2OpenAI(&streamResponse)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse ZhipuV4Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}

163
relay/claude_handler.go Normal file
View File

@@ -0,0 +1,163 @@
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
)
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
if err != nil {
return nil, err
}
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
if textRequest.Model == "" {
return nil, errors.New("field model is required")
}
return textRequest, nil
}
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoClaude(c)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
}
if textRequest.Stream {
relayInfo.IsStream = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return service.OpenAIErrorToClaudeError(openaiErr)
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
if textRequest.Thinking == nil {
// 因为BudgetTokens 必须大于1024
if textRequest.MaxTokens < 1280 {
textRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
textRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
textRequest.TopP = 0
textRequest.Temperature = common.GetPointer[float64](1.0)
}
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
relayInfo.UpstreamModelName = textRequest.Model
}
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
//log.Printf("usage: %v", usage)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
}
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
switch info.RelayMode {
default:
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
}
info.PromptTokens = promptTokens
return promptTokens, err
}

View File

@@ -15,6 +15,32 @@ import (
type ThinkingContentInfo struct { type ThinkingContentInfo struct {
IsFirstThinkingContent bool IsFirstThinkingContent bool
SendLastThinkingContent bool SendLastThinkingContent bool
HasSentThinkingContent bool
}
const (
LastMessageTypeNone = "none"
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeThinking = "thinking"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
Usage *dto.Usage
FinishReason string
Done bool
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
)
type RerankerInfo struct {
Documents []any
ReturnDocuments bool
} }
type RelayInfo struct { type RelayInfo struct {
@@ -55,10 +81,15 @@ type RelayInfo struct {
AudioUsage bool AudioUsage bool
ReasoningEffort string ReasoningEffort string
ChannelSetting map[string]interface{} ChannelSetting map[string]interface{}
ParamOverride map[string]interface{}
UserSetting map[string]interface{} UserSetting map[string]interface{}
UserEmail string UserEmail string
UserQuota int UserQuota int
RelayFormat string
SendResponseCount int
ThinkingContentInfo ThinkingContentInfo
*ClaudeConvertInfo
*RerankerInfo
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -71,6 +102,7 @@ var streamSupportedChannels = map[int]bool{
common.ChannelTypeAzure: true, common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true, common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true, common.ChannelTypeOllama: true,
common.ChannelTypeXai: true,
} }
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -82,10 +114,31 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
return info return info
} }
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
return info
}
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
}
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
channelSetting := c.GetStringMap("channel_setting") channelSetting := c.GetStringMap("channel_setting")
paramOverride := c.GetStringMap("param_override")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
tokenKey := c.GetString("token_key") tokenKey := c.GetString("token_key")
@@ -123,6 +176,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting, ChannelSetting: channelSetting,
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{ ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true, IsFirstThinkingContent: true,
SendLastThinkingContent: false, SendLastThinkingContent: false,
@@ -163,6 +218,10 @@ func (info *RelayInfo) SetFirstResponseTime() {
} }
} }
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct { type TaskRelayInfo struct {
*RelayInfo *RelayInfo
Action string Action string

View File

@@ -0,0 +1,68 @@
package common_handler
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
)
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse
if info.ChannelType == common.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
err = common.DecodeJson(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
respResult := dto.RerankResponseResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
}
if info.ReturnDocuments {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
respResult.Document = document
}
jinaRespResults[i] = respResult
}
jinaResp = dto.RerankResponse{
Results: jinaRespResults,
Usage: dto.Usage{
PromptTokens: info.PromptTokens,
TotalTokens: info.PromptTokens,
},
}
} else {
err = common.DecodeJson(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}
c.Writer.Header().Set("Content-Type", "application/json")
c.JSON(http.StatusOK, jinaResp)
return nil, &jinaResp.Usage
}

View File

@@ -31,6 +31,8 @@ const (
APITypeVolcEngine APITypeVolcEngine
APITypeBaiduV2 APITypeBaiduV2
APITypeOpenRouter APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -89,6 +91,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeBaiduV2 apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter: case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
case common.ChannelTypeXai:
apiType = APITypeXai
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

View File

@@ -19,6 +19,30 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Header().Set("X-Accel-Buffering", "no")
} }
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
jsonData, err := json.Marshal(resp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
} else {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func StringData(c *gin.Context, str string) error { func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ") //str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r") //str = strings.TrimSuffix(str, "\r")
@@ -31,7 +55,20 @@ func StringData(c *gin.Context, str string) error {
return nil return nil
} }
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error { func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")
}
jsonData, err := json.Marshal(object) jsonData, err := json.Marshal(object)
if err != nil { if err != nil {
return fmt.Errorf("error marshalling object: %w", err) return fmt.Errorf("error marshalling object: %w", err)

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting" "one-api/setting"
"one-api/setting/operation_setting" "one-api/setting/operation_setting"
@@ -16,9 +17,14 @@ type PriceData struct {
CacheRatio float64 CacheRatio float64
GroupRatio float64 GroupRatio float64
UsePrice bool UsePrice bool
CacheCreationRatio float64
ShouldPreConsumedQuota int ShouldPreConsumedQuota int
} }
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false) modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group) groupRatio := setting.GetGroupRatio(info.Group)
@@ -26,6 +32,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var modelRatio float64 var modelRatio float64
var completionRatio float64 var completionRatio float64
var cacheRatio float64 var cacheRatio float64
var cacheCreationRatio float64
if !usePrice { if !usePrice {
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 { if maxTokens != 0 {
@@ -34,26 +41,52 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var success bool var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName) modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
if !success { if !success {
if info.UserId == 1 { acceptUnsetRatio := false
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
} else { b, ok := accept.(bool)
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) if ok {
acceptUnsetRatio = b
}
}
if !acceptUnsetRatio {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} }
} }
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName) completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio) preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else { } else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
} }
return PriceData{
priceData := PriceData{
ModelPrice: modelPrice, ModelPrice: modelPrice,
ModelRatio: modelRatio, ModelRatio: modelRatio,
CompletionRatio: completionRatio, CompletionRatio: completionRatio,
GroupRatio: groupRatio, GroupRatio: groupRatio,
UsePrice: usePrice, UsePrice: usePrice,
CacheRatio: cacheRatio, CacheRatio: cacheRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota, ShouldPreConsumedQuota: preConsumedQuota,
}, nil }
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
return priceData, nil
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := operation_setting.GetModelPrice(modelName, false)
if ok {
return true
}
_, ok = operation_setting.GetModelRatio(modelName)
if ok {
return true
}
return false
} }

Some files were not shown because too many files have changed in this diff Show More