Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
282dcf05f0 | ||
|
|
f2b1fc0ace | ||
|
|
13b95049c3 | ||
|
|
d274c8cb14 | ||
|
|
e27c1acf79 | ||
|
|
f7a5cee262 | ||
|
|
106d8d8e57 | ||
|
|
b57334b82c | ||
|
|
e4db851b31 |
420
GIT_GUIDE.md
Normal file
420
GIT_GUIDE.md
Normal file
@@ -0,0 +1,420 @@
|
||||
# Sub2API 双 Remote Git 配置指南
|
||||
|
||||
## 📋 Git 仓库配置
|
||||
|
||||
### Remote 配置结构
|
||||
|
||||
```
|
||||
upstream (官方仓库)
|
||||
└── https://github.com/Wei-Shaw/sub2api.git
|
||||
用途: 拉取官方更新
|
||||
|
||||
origin (你的仓库)
|
||||
└── https://git.586vip.cn/oadmin/sub2api.git
|
||||
用途: 保存你的二次开发代码
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速部署(一键完成)
|
||||
|
||||
### 方式 1:使用自动化脚本(推荐)
|
||||
|
||||
**在服务器上执行以下命令:**
|
||||
|
||||
```bash
|
||||
# 下载部署脚本
|
||||
curl -o /tmp/deploy-sub2api.sh https://你的脚本地址/deploy-complete.sh
|
||||
|
||||
# 或者直接创建脚本
|
||||
cat > /tmp/deploy-sub2api.sh << 'SCRIPT_END'
|
||||
# [这里粘贴 deploy-complete.sh 的全部内容]
|
||||
SCRIPT_END
|
||||
|
||||
# 赋予执行权限
|
||||
chmod +x /tmp/deploy-sub2api.sh
|
||||
|
||||
# 运行脚本
|
||||
bash /tmp/deploy-sub2api.sh
|
||||
```
|
||||
|
||||
脚本会自动完成:
|
||||
- ✅ 克隆官方仓库
|
||||
- ✅ 配置双 remote
|
||||
- ✅ 推送到你的仓库
|
||||
- ✅ 创建部署配置
|
||||
- ✅ 启动 Docker 服务
|
||||
|
||||
---
|
||||
|
||||
### 方式 2:手动分步执行
|
||||
|
||||
如果自动脚本有问题,可以手动执行:
|
||||
|
||||
```bash
|
||||
# 1. 克隆官方仓库
|
||||
cd /opt
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git sub2api-dev
|
||||
cd sub2api-dev
|
||||
|
||||
# 2. 配置 Git Remote
|
||||
git remote rename origin upstream
|
||||
git remote add origin https://git.586vip.cn/oadmin/sub2api.git
|
||||
|
||||
# 3. 查看配置
|
||||
git remote -v
|
||||
# 应该看到:
|
||||
# origin https://git.586vip.cn/oadmin/sub2api.git (fetch)
|
||||
# origin https://git.586vip.cn/oadmin/sub2api.git (push)
|
||||
# upstream https://github.com/Wei-Shaw/sub2api.git (fetch)
|
||||
# upstream https://github.com/Wei-Shaw/sub2api.git (push)
|
||||
|
||||
# 4. 创建 main 分支
|
||||
git checkout -b main
|
||||
|
||||
# 5. 推送到你的仓库
|
||||
git push -u origin main
|
||||
|
||||
# 6. 配置部署文件
|
||||
cd deploy
|
||||
```
|
||||
|
||||
创建 `docker-compose.prod.yml`:
|
||||
|
||||
```bash
|
||||
cat > docker-compose.prod.yml << 'EOF'
|
||||
services:
|
||||
sub2api:
|
||||
image: weishaw/sub2api:latest
|
||||
container_name: sub2api
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "2080:8080"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
environment:
|
||||
- AUTO_SETUP=true
|
||||
- SERVER_HOST=0.0.0.0
|
||||
- SERVER_PORT=8080
|
||||
- SERVER_MODE=release
|
||||
- DATABASE_HOST=postgres
|
||||
- DATABASE_PORT=5432
|
||||
- DATABASE_USER=sub2api
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- DATABASE_DBNAME=sub2api
|
||||
- DATABASE_SSLMODE=disable
|
||||
- REDIS_HOST=host.docker.internal
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=redis_bJFKDk
|
||||
- REDIS_DB=1
|
||||
- ADMIN_EMAIL=${ADMIN_EMAIL}
|
||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD}
|
||||
- JWT_SECRET=${JWT_SECRET}
|
||||
- TZ=Asia/Shanghai
|
||||
depends_on:
|
||||
- postgres
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
- POSTGRES_USER=sub2api
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- POSTGRES_DB=sub2api
|
||||
- TZ=Asia/Shanghai
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U sub2api"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
EOF
|
||||
```
|
||||
|
||||
创建 `.env` 配置:
|
||||
|
||||
```bash
|
||||
cat > .env << EOF
|
||||
POSTGRES_USER=sub2api
|
||||
POSTGRES_PASSWORD=$(openssl rand -base64 24)
|
||||
POSTGRES_DB=sub2api
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=admin123
|
||||
JWT_SECRET=$(openssl rand -hex 32)
|
||||
JWT_EXPIRE_HOUR=24
|
||||
TZ=Asia/Shanghai
|
||||
SERVER_MODE=release
|
||||
RUN_MODE=standard
|
||||
EOF
|
||||
|
||||
# 显示生成的密码
|
||||
echo "===== 配置信息 ====="
|
||||
cat .env
|
||||
echo "===================="
|
||||
```
|
||||
|
||||
启动服务:
|
||||
|
||||
```bash
|
||||
# 拉取镜像
|
||||
docker-compose -f docker-compose.prod.yml pull
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.prod.yml up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.prod.yml logs -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 日常工作流程
|
||||
|
||||
### 开发流程
|
||||
|
||||
```bash
|
||||
cd /opt/sub2api-dev
|
||||
|
||||
# 1. 创建功能分支
|
||||
git checkout -b feature/new-feature
|
||||
|
||||
# 2. 开发并提交
|
||||
git add .
|
||||
git commit -m "feat: 添加新功能"
|
||||
|
||||
# 3. 切回主分支
|
||||
git checkout main
|
||||
|
||||
# 4. 合并功能分支
|
||||
git merge feature/new-feature
|
||||
|
||||
# 5. 推送到你的仓库
|
||||
git push origin main
|
||||
|
||||
# 6. 删除功能分支(可选)
|
||||
git branch -d feature/new-feature
|
||||
```
|
||||
|
||||
### 同步官方更新
|
||||
|
||||
```bash
|
||||
cd /opt/sub2api-dev
|
||||
|
||||
# 1. 查看官方更新
|
||||
git fetch upstream
|
||||
git log HEAD..upstream/main --oneline
|
||||
|
||||
# 2. 查看详细差异
|
||||
git diff HEAD..upstream/main
|
||||
|
||||
# 3. 合并官方更新
|
||||
git merge upstream/main
|
||||
|
||||
# 4. 如果有冲突,解决后提交
|
||||
git add .
|
||||
git commit -m "merge: 合并官方更新 v1.x.x"
|
||||
|
||||
# 5. 推送到你的仓库
|
||||
git push origin main
|
||||
|
||||
# 6. 重新部署(如果需要)
|
||||
cd deploy
|
||||
docker-compose -f docker-compose.prod.yml pull
|
||||
docker-compose -f docker-compose.prod.yml up -d
|
||||
```
|
||||
|
||||
### 查看分支和远程信息
|
||||
|
||||
```bash
|
||||
# 查看所有分支
|
||||
git branch -a
|
||||
|
||||
# 查看远程仓库
|
||||
git remote -v
|
||||
|
||||
# 查看当前状态
|
||||
git status
|
||||
|
||||
# 查看提交历史
|
||||
git log --oneline --graph --all -10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 常见操作
|
||||
|
||||
### 回滚到某个版本
|
||||
|
||||
```bash
|
||||
# 查看提交历史
|
||||
git log --oneline
|
||||
|
||||
# 回滚到指定提交(软回滚,保留修改)
|
||||
git reset --soft <commit-hash>
|
||||
|
||||
# 回滚到指定提交(硬回滚,丢弃修改)
|
||||
git reset --hard <commit-hash>
|
||||
|
||||
# 推送到远程(需要强制推送)
|
||||
git push -f origin main
|
||||
```
|
||||
|
||||
### 对比官方版本
|
||||
|
||||
```bash
|
||||
# 对比特定文件
|
||||
git diff upstream/main -- backend/internal/service/gateway_service.go
|
||||
|
||||
# 对比整个目录
|
||||
git diff upstream/main -- backend/internal/service/
|
||||
|
||||
# 生成 patch 文件
|
||||
git diff upstream/main > my-changes.patch
|
||||
|
||||
# 查看改动的文件列表
|
||||
git diff --name-only upstream/main
|
||||
```
|
||||
|
||||
### 从官方仓库拉取特定分支/标签
|
||||
|
||||
```bash
|
||||
# 拉取官方的所有标签
|
||||
git fetch upstream --tags
|
||||
|
||||
# 查看所有标签
|
||||
git tag -l
|
||||
|
||||
# 基于某个标签创建分支
|
||||
git checkout -b v1.0.0 tags/v1.0.0
|
||||
|
||||
# 切回主分支
|
||||
git checkout main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📦 Docker 管理
|
||||
|
||||
### 服务管理
|
||||
|
||||
```bash
|
||||
cd /opt/sub2api-dev/deploy
|
||||
|
||||
# 查看状态
|
||||
docker-compose -f docker-compose.prod.yml ps
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.prod.yml logs -f
|
||||
|
||||
# 只看 sub2api 日志
|
||||
docker-compose -f docker-compose.prod.yml logs -f sub2api
|
||||
|
||||
# 重启服务
|
||||
docker-compose -f docker-compose.prod.yml restart
|
||||
|
||||
# 停止服务
|
||||
docker-compose -f docker-compose.prod.yml down
|
||||
|
||||
# 完全清理(包括数据)
|
||||
docker-compose -f docker-compose.prod.yml down -v
|
||||
```
|
||||
|
||||
### 更新镜像
|
||||
|
||||
```bash
|
||||
# 拉取最新镜像
|
||||
docker-compose -f docker-compose.prod.yml pull
|
||||
|
||||
# 重新创建容器
|
||||
docker-compose -f docker-compose.prod.yml up -d
|
||||
|
||||
# 查看镜像信息
|
||||
docker images | grep sub2api
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
1. **定期备份**
|
||||
- 备份 `.env` 文件
|
||||
- 导出 PostgreSQL 数据库
|
||||
- 提交代码到远程仓库
|
||||
|
||||
2. **合并冲突处理**
|
||||
- 遇到冲突时,优先保留官方核心逻辑
|
||||
- 调整你的二开代码以适配官方更新
|
||||
|
||||
3. **测试环境**
|
||||
- 建议先在测试环境验证更新
|
||||
- 确认无误后再应用到生产环境
|
||||
|
||||
4. **版本管理**
|
||||
- 重要功能单独开分支
|
||||
- 使用有意义的提交信息
|
||||
- 定期推送到远程仓库
|
||||
|
||||
---
|
||||
|
||||
## 🔍 故障排查
|
||||
|
||||
### Git 问题
|
||||
|
||||
```bash
|
||||
# 如果 push 被拒绝
|
||||
git pull origin main --rebase
|
||||
git push origin main
|
||||
|
||||
# 如果需要强制推送(危险!)
|
||||
git push -f origin main
|
||||
|
||||
# 查看 Git 配置
|
||||
git config --list
|
||||
|
||||
# 重置 remote
|
||||
git remote remove origin
|
||||
git remote add origin https://git.586vip.cn/oadmin/sub2api.git
|
||||
```
|
||||
|
||||
### Docker 问题
|
||||
|
||||
```bash
|
||||
# Redis 连接失败
|
||||
# 修改 docker-compose.prod.yml 中的 REDIS_HOST
|
||||
# 从 host.docker.internal 改为 172.17.0.1
|
||||
|
||||
# 查看容器详细信息
|
||||
docker inspect sub2api
|
||||
|
||||
# 进入容器调试
|
||||
docker exec -it sub2api sh
|
||||
|
||||
# 查看网络
|
||||
docker network ls
|
||||
docker network inspect bridge
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📞 快速命令速查
|
||||
|
||||
```bash
|
||||
# Git
|
||||
git status # 状态
|
||||
git fetch upstream # 拉取官方更新
|
||||
git merge upstream/main # 合并更新
|
||||
git push origin main # 推送到你的仓库
|
||||
|
||||
# Docker
|
||||
docker-compose -f docker-compose.prod.yml ps # 状态
|
||||
docker-compose -f docker-compose.prod.yml logs -f # 日志
|
||||
docker-compose -f docker-compose.prod.yml restart # 重启
|
||||
docker-compose -f docker-compose.prod.yml down # 停止
|
||||
```
|
||||
276
GIT_WORKFLOW.md
Normal file
276
GIT_WORKFLOW.md
Normal file
@@ -0,0 +1,276 @@
|
||||
# Sub2API 二次开发指南
|
||||
|
||||
## 🔧 Git 工作流程
|
||||
|
||||
### 初始设置
|
||||
|
||||
```bash
|
||||
# 1. 克隆官方仓库
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git sub2api-dev
|
||||
cd sub2api-dev
|
||||
|
||||
# 2. 配置远程仓库
|
||||
git remote rename origin upstream # 官方仓库改名为 upstream
|
||||
git remote add origin https://your-git.com/your-repo.git # 添加你的仓库
|
||||
|
||||
# 3. 创建开发分支
|
||||
git checkout -b dev
|
||||
|
||||
# 4. 推送到你的仓库
|
||||
git push -u origin dev
|
||||
```
|
||||
|
||||
### 远程仓库配置结果
|
||||
|
||||
```bash
|
||||
git remote -v
|
||||
# upstream https://github.com/Wei-Shaw/sub2api.git (fetch)
|
||||
# upstream https://github.com/Wei-Shaw/sub2api.git (push)
|
||||
# origin https://your-git.com/your-repo.git (fetch)
|
||||
# origin https://your-git.com/your-repo.git (push)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 同步官方更新
|
||||
|
||||
### 方式 1:查看官方更新
|
||||
|
||||
```bash
|
||||
# 拉取官方最新代码
|
||||
git fetch upstream
|
||||
|
||||
# 查看官方更新内容
|
||||
git log HEAD..upstream/main --oneline
|
||||
|
||||
# 查看详细差异
|
||||
git diff HEAD..upstream/main
|
||||
|
||||
# 查看某个文件的差异
|
||||
git diff HEAD..upstream/main -- backend/internal/service/gateway_service.go
|
||||
```
|
||||
|
||||
### 方式 2:合并官方更新
|
||||
|
||||
```bash
|
||||
# 确保工作区干净
|
||||
git status
|
||||
|
||||
# 提交你的修改
|
||||
git add .
|
||||
git commit -m "feat: 我的二开功能"
|
||||
|
||||
# 拉取并合并官方更新
|
||||
git fetch upstream
|
||||
git merge upstream/main
|
||||
|
||||
# 如果有冲突,解决后:
|
||||
git add .
|
||||
git commit -m "merge: 合并官方更新"
|
||||
|
||||
# 推送到你的仓库
|
||||
git push origin dev
|
||||
```
|
||||
|
||||
### 方式 3:使用 Rebase(更清晰的历史)
|
||||
|
||||
```bash
|
||||
# 拉取官方更新
|
||||
git fetch upstream
|
||||
|
||||
# 将你的提交重放到官方最新代码之上
|
||||
git rebase upstream/main
|
||||
|
||||
# 如果有冲突,解决后:
|
||||
git add .
|
||||
git rebase --continue
|
||||
|
||||
# 强制推送到你的仓库(注意:rebase 会改变历史)
|
||||
git push -f origin dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📦 部署流程
|
||||
|
||||
### 生产环境部署
|
||||
|
||||
```bash
|
||||
cd /opt/sub2api-dev/deploy
|
||||
|
||||
# 运行自动部署脚本
|
||||
chmod +x deploy-dev.sh
|
||||
bash deploy-dev.sh
|
||||
```
|
||||
|
||||
### 手动部署
|
||||
|
||||
```bash
|
||||
# 1. 进入部署目录
|
||||
cd /opt/sub2api-dev/deploy
|
||||
|
||||
# 2. 创建配置(如果没有)
|
||||
cp docker-compose.yml docker-compose.prod.yml
|
||||
cp .env.example .env
|
||||
|
||||
# 3. 修改配置
|
||||
nano docker-compose.prod.yml # 改端口为 2080,配置 Redis
|
||||
nano .env # 设置密码等
|
||||
|
||||
# 4. 启动服务
|
||||
docker-compose -f docker-compose.prod.yml pull
|
||||
docker-compose -f docker-compose.prod.yml up -d
|
||||
|
||||
# 5. 查看日志
|
||||
docker-compose -f docker-compose.prod.yml logs -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 开发建议
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
sub2api-dev/
|
||||
├── .git/ # Git 仓库
|
||||
├── backend/ # 后端代码(Go)
|
||||
│ ├── cmd/server/ # 入口
|
||||
│ ├── internal/ # 核心逻辑
|
||||
│ └── ent/ # 数据模型
|
||||
├── frontend/ # 前端代码(Vue)
|
||||
│ └── src/
|
||||
└── deploy/ # 部署配置
|
||||
├── docker-compose.prod.yml
|
||||
├── .env
|
||||
└── deploy-dev.sh
|
||||
```
|
||||
|
||||
### 常见二开场景
|
||||
|
||||
#### 1. 修改后端逻辑
|
||||
|
||||
```bash
|
||||
# 修改代码
|
||||
vim backend/internal/service/gateway_service.go
|
||||
|
||||
# 提交
|
||||
git add backend/
|
||||
git commit -m "feat: 修改网关逻辑"
|
||||
|
||||
# 重新构建镜像(如果需要)
|
||||
cd backend
|
||||
docker build -t sub2api:custom .
|
||||
|
||||
# 修改 docker-compose.prod.yml 使用自定义镜像
|
||||
# image: sub2api:custom
|
||||
|
||||
# 重启服务
|
||||
cd ../deploy
|
||||
docker-compose -f docker-compose.prod.yml up -d
|
||||
```
|
||||
|
||||
#### 2. 修改前端界面
|
||||
|
||||
```bash
|
||||
# 修改代码
|
||||
vim frontend/src/views/HomeView.vue
|
||||
|
||||
# 本地测试
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
|
||||
# 构建
|
||||
npm run build
|
||||
|
||||
# 提交
|
||||
git add frontend/
|
||||
git commit -m "feat: 修改首页界面"
|
||||
|
||||
# 后端重新构建(前端会被嵌入)
|
||||
cd ../backend
|
||||
go build -tags embed -o sub2api ./cmd/server
|
||||
```
|
||||
|
||||
#### 3. 添加新功能
|
||||
|
||||
```bash
|
||||
# 创建功能分支
|
||||
git checkout -b feature/new-feature
|
||||
|
||||
# 开发...
|
||||
# 提交
|
||||
git add .
|
||||
git commit -m "feat: 添加新功能"
|
||||
|
||||
# 合并到开发分支
|
||||
git checkout dev
|
||||
git merge feature/new-feature
|
||||
|
||||
# 推送
|
||||
git push origin dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔍 对比工具
|
||||
|
||||
### 使用 VS Code 对比
|
||||
|
||||
```bash
|
||||
# 安装 VS Code Remote SSH 插件
|
||||
# 然后在 VS Code 中:
|
||||
# 1. 连接到服务器
|
||||
# 2. 打开 /opt/sub2api-dev
|
||||
# 3. 使用 Git 功能对比差异
|
||||
```
|
||||
|
||||
### 命令行对比
|
||||
|
||||
```bash
|
||||
# 对比某个文件
|
||||
git diff upstream/main -- backend/internal/service/gateway_service.go
|
||||
|
||||
# 对比某个目录
|
||||
git diff upstream/main -- backend/internal/service/
|
||||
|
||||
# 生成对比报告
|
||||
git diff upstream/main > changes.patch
|
||||
|
||||
# 查看修改的文件列表
|
||||
git diff --name-only upstream/main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
1. **定期同步**: 建议每周拉取一次官方更新
|
||||
2. **冲突处理**: 遇到冲突时,优先保留官方的核心逻辑,调整你的二开代码
|
||||
3. **分支管理**: 重要功能单独开分支,测试通过后再合并
|
||||
4. **备份数据**: 更新前备份数据库和配置文件
|
||||
5. **测试环境**: 建议先在测试环境验证更新,再应用到生产
|
||||
|
||||
---
|
||||
|
||||
## 📋 常用命令速查
|
||||
|
||||
```bash
|
||||
# Git 管理
|
||||
git status # 查看状态
|
||||
git log --oneline -10 # 查看提交历史
|
||||
git fetch upstream # 拉取官方更新
|
||||
git diff upstream/main # 对比差异
|
||||
git merge upstream/main # 合并更新
|
||||
|
||||
# Docker 管理
|
||||
docker-compose -f docker-compose.prod.yml ps # 查看状态
|
||||
docker-compose -f docker-compose.prod.yml logs -f # 查看日志
|
||||
docker-compose -f docker-compose.prod.yml restart # 重启
|
||||
docker-compose -f docker-compose.prod.yml down # 停止
|
||||
|
||||
# 服务管理
|
||||
docker-compose -f docker-compose.prod.yml pull # 更新镜像
|
||||
docker-compose -f docker-compose.prod.yml up -d # 启动
|
||||
```
|
||||
@@ -1 +1 @@
|
||||
0.1.1
|
||||
0.1.1
|
||||
|
||||
@@ -1,154 +1,154 @@
|
||||
package main
|
||||
|
||||
//go:generate go run github.com/google/wire/cmd/wire
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
var embeddedVersion string
|
||||
|
||||
// Build-time variables (can be set by ldflags)
|
||||
var (
|
||||
Version = ""
|
||||
Commit = "unknown"
|
||||
Date = "unknown"
|
||||
BuildType = "source" // "source" for manual builds, "release" for CI builds (set by ldflags)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Read version from embedded VERSION file
|
||||
Version = strings.TrimSpace(embeddedVersion)
|
||||
if Version == "" {
|
||||
Version = "0.0.0-dev"
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
|
||||
showVersion := flag.Bool("version", false, "Show version information")
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
log.Printf("Sub2API %s (commit: %s, built: %s)\n", Version, Commit, Date)
|
||||
return
|
||||
}
|
||||
|
||||
// CLI setup mode
|
||||
if *setupMode {
|
||||
if err := setup.RunCLI(); err != nil {
|
||||
log.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check if setup is needed
|
||||
if setup.NeedsSetup() {
|
||||
// Check if auto-setup is enabled (for Docker deployment)
|
||||
if setup.AutoSetupEnabled() {
|
||||
log.Println("Auto setup mode enabled...")
|
||||
if err := setup.AutoSetupFromEnv(); err != nil {
|
||||
log.Fatalf("Auto setup failed: %v", err)
|
||||
}
|
||||
// Continue to main server after auto-setup
|
||||
} else {
|
||||
log.Println("First run detected, starting setup wizard...")
|
||||
runSetupServer()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Normal server mode
|
||||
runMainServer()
|
||||
}
|
||||
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// Register setup routes
|
||||
setup.RegisterRoutes(r)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
|
||||
// This allows users to run setup on a different address if needed
|
||||
addr := config.GetServerAddress()
|
||||
log.Printf("Setup wizard available at http://%s", addr)
|
||||
log.Println("Complete the setup wizard to configure Sub2API")
|
||||
|
||||
if err := r.Run(addr); err != nil {
|
||||
log.Fatalf("Failed to start setup server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runMainServer() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
|
||||
}
|
||||
|
||||
buildInfo := handler.BuildInfo{
|
||||
Version: Version,
|
||||
BuildType: BuildType,
|
||||
}
|
||||
|
||||
app, err := initializeApplication(buildInfo)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize application: %v", err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Server started on %s", app.Server.Addr)
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := app.Server.Shutdown(ctx); err != nil {
|
||||
log.Fatalf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
package main
|
||||
|
||||
//go:generate go run github.com/google/wire/cmd/wire
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
var embeddedVersion string
|
||||
|
||||
// Build-time variables (can be set by ldflags)
|
||||
var (
|
||||
Version = ""
|
||||
Commit = "unknown"
|
||||
Date = "unknown"
|
||||
BuildType = "source" // "source" for manual builds, "release" for CI builds (set by ldflags)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Read version from embedded VERSION file
|
||||
Version = strings.TrimSpace(embeddedVersion)
|
||||
if Version == "" {
|
||||
Version = "0.0.0-dev"
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
|
||||
showVersion := flag.Bool("version", false, "Show version information")
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
log.Printf("Sub2API %s (commit: %s, built: %s)\n", Version, Commit, Date)
|
||||
return
|
||||
}
|
||||
|
||||
// CLI setup mode
|
||||
if *setupMode {
|
||||
if err := setup.RunCLI(); err != nil {
|
||||
log.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check if setup is needed
|
||||
if setup.NeedsSetup() {
|
||||
// Check if auto-setup is enabled (for Docker deployment)
|
||||
if setup.AutoSetupEnabled() {
|
||||
log.Println("Auto setup mode enabled...")
|
||||
if err := setup.AutoSetupFromEnv(); err != nil {
|
||||
log.Fatalf("Auto setup failed: %v", err)
|
||||
}
|
||||
// Continue to main server after auto-setup
|
||||
} else {
|
||||
log.Println("First run detected, starting setup wizard...")
|
||||
runSetupServer()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Normal server mode
|
||||
runMainServer()
|
||||
}
|
||||
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// Register setup routes
|
||||
setup.RegisterRoutes(r)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
|
||||
// This allows users to run setup on a different address if needed
|
||||
addr := config.GetServerAddress()
|
||||
log.Printf("Setup wizard available at http://%s", addr)
|
||||
log.Println("Complete the setup wizard to configure Sub2API")
|
||||
|
||||
if err := r.Run(addr); err != nil {
|
||||
log.Fatalf("Failed to start setup server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runMainServer() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
|
||||
}
|
||||
|
||||
buildInfo := handler.BuildInfo{
|
||||
Version: Version,
|
||||
BuildType: BuildType,
|
||||
}
|
||||
|
||||
app, err := initializeApplication(buildInfo)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize application: %v", err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Server started on %s", app.Server.Addr)
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := app.Server.Shutdown(ctx); err != nil {
|
||||
log.Fatalf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
|
||||
@@ -1,140 +1,140 @@
|
||||
//go:build wireinject
|
||||
// +build wireinject
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
wire.Build(
|
||||
// Infrastructure layer ProviderSets
|
||||
config.ProviderSet,
|
||||
|
||||
// Business layer ProviderSets
|
||||
repository.ProviderSet,
|
||||
service.ProviderSet,
|
||||
middleware.ProviderSet,
|
||||
handler.ProviderSet,
|
||||
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
// Cleanup function provider
|
||||
provideCleanup,
|
||||
|
||||
// Application struct
|
||||
wire.Struct(new(Application), "Server", "Cleanup"),
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Cleanup steps in reverse dependency order
|
||||
cleanupSteps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"BillingCacheService", func() error {
|
||||
billingCache.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"GeminiOAuthService", func() error {
|
||||
geminiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AntigravityOAuthService", func() error {
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
{"Ent", func() error {
|
||||
return entClient.Close()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, step := range cleanupSteps {
|
||||
if err := step.fn(); err != nil {
|
||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||
// Continue with remaining cleanup steps even if one fails
|
||||
} else {
|
||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if context timed out
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||
default:
|
||||
log.Printf("[Cleanup] All cleanup steps completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
//go:build wireinject
|
||||
// +build wireinject
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
wire.Build(
|
||||
// Infrastructure layer ProviderSets
|
||||
config.ProviderSet,
|
||||
|
||||
// Business layer ProviderSets
|
||||
repository.ProviderSet,
|
||||
service.ProviderSet,
|
||||
middleware.ProviderSet,
|
||||
handler.ProviderSet,
|
||||
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
// Cleanup function provider
|
||||
provideCleanup,
|
||||
|
||||
// Application struct
|
||||
wire.Struct(new(Application), "Server", "Cleanup"),
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Cleanup steps in reverse dependency order
|
||||
cleanupSteps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"BillingCacheService", func() error {
|
||||
billingCache.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"GeminiOAuthService", func() error {
|
||||
geminiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AntigravityOAuthService", func() error {
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
{"Ent", func() error {
|
||||
return entClient.Close()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, step := range cleanupSteps {
|
||||
if err := step.fn(); err != nil {
|
||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||
// Continue with remaining cleanup steps even if one fails
|
||||
} else {
|
||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if context timed out
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||
default:
|
||||
log.Printf("[Cleanup] All cleanup steps completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,70 +1,70 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestNormalizeRunMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"SIMPLE", "simple"},
|
||||
{"standard", "standard"},
|
||||
{"invalid", "standard"},
|
||||
{"", "standard"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NormalizeRunMode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
||||
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
|
||||
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
||||
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
|
||||
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
|
||||
}
|
||||
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||
viper.Reset()
|
||||
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
|
||||
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||
}
|
||||
}
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestNormalizeRunMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"SIMPLE", "simple"},
|
||||
{"standard", "standard"},
|
||||
{"invalid", "standard"},
|
||||
{"", "standard"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NormalizeRunMode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
||||
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
|
||||
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
||||
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
|
||||
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
|
||||
}
|
||||
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||
viper.Reset()
|
||||
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
|
||||
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package config
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
// ProviderSet 提供配置层的依赖
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideConfig,
|
||||
)
|
||||
|
||||
// ProvideConfig 提供应用配置
|
||||
func ProvideConfig() (*Config, error) {
|
||||
return Load()
|
||||
}
|
||||
package config
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
// ProviderSet 提供配置层的依赖
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideConfig,
|
||||
)
|
||||
|
||||
// ProvideConfig 提供应用配置
|
||||
func ProvideConfig() (*Config, error) {
|
||||
return Load()
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,67 +1,67 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AntigravityOAuthHandler struct {
|
||||
antigravityOAuthService *service.AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAuthService) *AntigravityOAuthHandler {
|
||||
return &AntigravityOAuthHandler{antigravityOAuthService: antigravityOAuthService}
|
||||
}
|
||||
|
||||
type AntigravityGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL
|
||||
// POST /api/v1/admin/antigravity/oauth/auth-url
|
||||
func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req AntigravityGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "生成授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type AntigravityExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
// POST /api/v1/admin/antigravity/oauth/exchange-code
|
||||
func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req AntigravityExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.antigravityOAuthService.ExchangeCode(c.Request.Context(), &service.AntigravityExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Token 交换失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AntigravityOAuthHandler struct {
|
||||
antigravityOAuthService *service.AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAuthService) *AntigravityOAuthHandler {
|
||||
return &AntigravityOAuthHandler{antigravityOAuthService: antigravityOAuthService}
|
||||
}
|
||||
|
||||
type AntigravityGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL
|
||||
// POST /api/v1/admin/antigravity/oauth/auth-url
|
||||
func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req AntigravityGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "生成授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type AntigravityExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
// POST /api/v1/admin/antigravity/oauth/exchange-code
|
||||
func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req AntigravityExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.antigravityOAuthService.ExchangeCode(c.Request.Context(), &service.AntigravityExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Token 交换失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
@@ -1,302 +1,302 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
dashboardService *service.DashboardService
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
dashboardService: dashboardService,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// parseTimeRange parses start_date, end_date query parameters
|
||||
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// GetStats handles getting dashboard statistics
|
||||
// GET /api/v1/admin/dashboard/stats
|
||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate uptime in seconds
|
||||
uptime := int64(time.Since(h.startTime).Seconds())
|
||||
|
||||
response.Success(c, gin.H{
|
||||
// 用户统计
|
||||
"total_users": stats.TotalUsers,
|
||||
"today_new_users": stats.TodayNewUsers,
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
"normal_accounts": stats.NormalAccounts,
|
||||
"error_accounts": stats.ErrorAccounts,
|
||||
"ratelimit_accounts": stats.RateLimitAccounts,
|
||||
"overload_accounts": stats.OverloadAccounts,
|
||||
|
||||
// 累计 Token 使用统计
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
|
||||
"total_cache_read_tokens": stats.TotalCacheReadTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost, // 标准计费
|
||||
"total_actual_cost": stats.TotalActualCost, // 实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
"today_requests": stats.TodayRequests,
|
||||
"today_input_tokens": stats.TodayInputTokens,
|
||||
"today_output_tokens": stats.TodayOutputTokens,
|
||||
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
|
||||
"today_cache_read_tokens": stats.TodayCacheReadTokens,
|
||||
"today_tokens": stats.TodayTokens,
|
||||
"today_cost": stats.TodayCost, // 今日标准计费
|
||||
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"uptime": uptime,
|
||||
|
||||
// 性能指标
|
||||
"rpm": stats.Rpm,
|
||||
"tpm": stats.Tpm,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRealtimeMetrics handles getting real-time system metrics
|
||||
// GET /api/v1/admin/dashboard/realtime
|
||||
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"active_requests": 0,
|
||||
"requests_per_minute": 0,
|
||||
"average_response_time": 0,
|
||||
"error_rate": 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserUsageTrend handles getting user usage trend data
|
||||
// GET /api/v1/admin/dashboard/users-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
|
||||
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "12")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUsersUsageRequest represents the request body for batch user usage stats
|
||||
type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
var req BatchUsersUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
dashboardService *service.DashboardService
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
dashboardService: dashboardService,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// parseTimeRange parses start_date, end_date query parameters
|
||||
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// GetStats handles getting dashboard statistics
|
||||
// GET /api/v1/admin/dashboard/stats
|
||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate uptime in seconds
|
||||
uptime := int64(time.Since(h.startTime).Seconds())
|
||||
|
||||
response.Success(c, gin.H{
|
||||
// 用户统计
|
||||
"total_users": stats.TotalUsers,
|
||||
"today_new_users": stats.TodayNewUsers,
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
"normal_accounts": stats.NormalAccounts,
|
||||
"error_accounts": stats.ErrorAccounts,
|
||||
"ratelimit_accounts": stats.RateLimitAccounts,
|
||||
"overload_accounts": stats.OverloadAccounts,
|
||||
|
||||
// 累计 Token 使用统计
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
|
||||
"total_cache_read_tokens": stats.TotalCacheReadTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost, // 标准计费
|
||||
"total_actual_cost": stats.TotalActualCost, // 实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
"today_requests": stats.TodayRequests,
|
||||
"today_input_tokens": stats.TodayInputTokens,
|
||||
"today_output_tokens": stats.TodayOutputTokens,
|
||||
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
|
||||
"today_cache_read_tokens": stats.TodayCacheReadTokens,
|
||||
"today_tokens": stats.TodayTokens,
|
||||
"today_cost": stats.TodayCost, // 今日标准计费
|
||||
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"uptime": uptime,
|
||||
|
||||
// 性能指标
|
||||
"rpm": stats.Rpm,
|
||||
"tpm": stats.Tpm,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRealtimeMetrics handles getting real-time system metrics
|
||||
// GET /api/v1/admin/dashboard/realtime
|
||||
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"active_requests": 0,
|
||||
"requests_per_minute": 0,
|
||||
"average_response_time": 0,
|
||||
"error_rate": 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserUsageTrend handles getting user usage trend data
|
||||
// GET /api/v1/admin/dashboard/users-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
|
||||
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "12")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUsersUsageRequest represents the request body for batch user usage stats
|
||||
type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
var req BatchUsersUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
@@ -1,135 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiOAuthHandler struct {
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
type GeminiGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
|
||||
// 默认为 "code_assist" 以保持向后兼容
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
|
||||
// POST /api/v1/admin/gemini/oauth/auth-url
|
||||
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req GeminiGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
|
||||
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
|
||||
redirectURI := deriveGeminiRedirectURI(c)
|
||||
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
response.InternalError(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type GeminiExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens.
|
||||
// POST /api/v1/admin/gemini/oauth/exchange-code
|
||||
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req GeminiExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
OAuthType: oauthType,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
func deriveGeminiRedirectURI(c *gin.Context) string {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin != "" {
|
||||
return strings.TrimRight(origin, "/") + "/auth/callback"
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
|
||||
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(c.Request.Host)
|
||||
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
|
||||
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiOAuthHandler struct {
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
type GeminiGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
|
||||
// 默认为 "code_assist" 以保持向后兼容
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
|
||||
// POST /api/v1/admin/gemini/oauth/auth-url
|
||||
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req GeminiGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
|
||||
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
|
||||
redirectURI := deriveGeminiRedirectURI(c)
|
||||
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
response.InternalError(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type GeminiExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens.
|
||||
// POST /api/v1/admin/gemini/oauth/exchange-code
|
||||
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req GeminiExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
OAuthType: oauthType,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
func deriveGeminiRedirectURI(c *gin.Context) string {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin != "" {
|
||||
return strings.TrimRight(origin, "/") + "/auth/callback"
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
|
||||
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(c.Request.Host)
|
||||
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
|
||||
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
|
||||
}
|
||||
|
||||
@@ -1,245 +1,245 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
// GET /api/v1/admin/groups
|
||||
func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
if isExclusiveStr != "" {
|
||||
val := isExclusiveStr == "true"
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active groups without pagination
|
||||
// GET /api/v1/admin/groups/all
|
||||
func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
platform := c.Query("platform")
|
||||
|
||||
var groups []service.Group
|
||||
var err error
|
||||
|
||||
if platform != "" {
|
||||
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
|
||||
} else {
|
||||
groups, err = h.adminService.GetAllGroups(c.Request.Context())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
|
||||
// GetByID handles getting a group by ID
|
||||
// GET /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
// POST /api/v1/admin/groups
|
||||
func (h *GroupHandler) Create(c *gin.Context) {
|
||||
var req CreateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
// PUT /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Update(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
// DELETE /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Delete(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Group deleted successfully"})
|
||||
}
|
||||
|
||||
// GetStats handles getting group statistics
|
||||
// GET /api/v1/admin/groups/:id/stats
|
||||
func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_api_keys": 0,
|
||||
"active_api_keys": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
})
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupHandler handles admin group management
|
||||
type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
// GET /api/v1/admin/groups
|
||||
func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
if isExclusiveStr != "" {
|
||||
val := isExclusiveStr == "true"
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active groups without pagination
|
||||
// GET /api/v1/admin/groups/all
|
||||
func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
platform := c.Query("platform")
|
||||
|
||||
var groups []service.Group
|
||||
var err error
|
||||
|
||||
if platform != "" {
|
||||
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
|
||||
} else {
|
||||
groups, err = h.adminService.GetAllGroups(c.Request.Context())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
|
||||
// GetByID handles getting a group by ID
|
||||
// GET /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
// POST /api/v1/admin/groups
|
||||
func (h *GroupHandler) Create(c *gin.Context) {
|
||||
var req CreateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
// PUT /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Update(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
// DELETE /api/v1/admin/groups/:id
|
||||
func (h *GroupHandler) Delete(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Group deleted successfully"})
|
||||
}
|
||||
|
||||
// GetStats handles getting group statistics
|
||||
// GET /api/v1/admin/groups/:id/stats
|
||||
func (h *GroupHandler) GetStats(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_api_keys": 0,
|
||||
"active_api_keys": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
})
|
||||
_ = groupID // TODO: implement actual stats
|
||||
}
|
||||
|
||||
// GetGroupAPIKeys handles getting API keys in a group
|
||||
// GET /api/v1/admin/groups/:id/api-keys
|
||||
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -1,229 +1,229 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||
type OpenAIOAuthHandler struct {
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||
type OpenAIGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||
// POST /api/v1/admin/openai/generate-auth-url
|
||||
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req OpenAIGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||
// POST /api/v1/admin/openai/exchange-code
|
||||
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req OpenAIExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||
type OpenAIOAuthHandler struct {
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||
type OpenAIGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||
// POST /api/v1/admin/openai/generate-auth-url
|
||||
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req OpenAIGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||
// POST /api/v1/admin/openai/exchange-code
|
||||
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req OpenAIExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
@@ -1,238 +1,238 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles admin redeem code management
|
||||
type RedeemHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new admin redeem handler
|
||||
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
// GET /api/v1/admin/redeem-codes
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a redeem code by ID
|
||||
// GET /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/generate
|
||||
func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
var req GenerateRedeemCodesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
// DELETE /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
|
||||
}
|
||||
|
||||
// BatchDelete handles batch deleting redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/batch-delete
|
||||
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"deleted": deleted,
|
||||
"message": "Redeem codes deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Expire handles expiring a redeem code
|
||||
// POST /api/v1/admin/redeem-codes/:id/expire
|
||||
func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
// GET /api/v1/admin/redeem-codes/stats
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
"concurrency": 0,
|
||||
"trial": 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Export handles exporting redeem codes to CSV
|
||||
// GET /api/v1/admin/redeem-codes/export
|
||||
func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
|
||||
// Get all codes without pagination (use large page size)
|
||||
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create CSV buffer
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
usedBy := ""
|
||||
if code.UsedBy != nil {
|
||||
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
||||
}
|
||||
usedAt := ""
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
fmt.Sprintf("%.2f", code.Value),
|
||||
code.Status,
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
c.Data(200, "text/csv", buf.Bytes())
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles admin redeem code management
|
||||
type RedeemHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new admin redeem handler
|
||||
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
// GET /api/v1/admin/redeem-codes
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a redeem code by ID
|
||||
// GET /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/generate
|
||||
func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
var req GenerateRedeemCodesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
// DELETE /api/v1/admin/redeem-codes/:id
|
||||
func (h *RedeemHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
|
||||
}
|
||||
|
||||
// BatchDelete handles batch deleting redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/batch-delete
|
||||
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"deleted": deleted,
|
||||
"message": "Redeem codes deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Expire handles expiring a redeem code
|
||||
// POST /api/v1/admin/redeem-codes/:id/expire
|
||||
func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid redeem code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
// GET /api/v1/admin/redeem-codes/stats
|
||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||
// Return mock data for now
|
||||
response.Success(c, gin.H{
|
||||
"total_codes": 0,
|
||||
"active_codes": 0,
|
||||
"used_codes": 0,
|
||||
"expired_codes": 0,
|
||||
"total_value_distributed": 0.0,
|
||||
"by_type": gin.H{
|
||||
"balance": 0,
|
||||
"concurrency": 0,
|
||||
"trial": 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Export handles exporting redeem codes to CSV
|
||||
// GET /api/v1/admin/redeem-codes/export
|
||||
func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
|
||||
// Get all codes without pagination (use large page size)
|
||||
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create CSV buffer
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Write data rows
|
||||
for _, code := range codes {
|
||||
usedBy := ""
|
||||
if code.UsedBy != nil {
|
||||
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
||||
}
|
||||
usedAt := ""
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
code.Type,
|
||||
fmt.Sprintf("%.2f", code.Value),
|
||||
code.Status,
|
||||
usedBy,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||
c.Data(200, "text/csv", buf.Bytes())
|
||||
}
|
||||
|
||||
@@ -1,374 +1,374 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings 获取所有系统设置
|
||||
// GET /api/v1/admin/settings
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
// PUT /api/v1/admin/settings
|
||||
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
var req UpdateSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if req.DefaultConcurrency < 1 {
|
||||
req.DefaultConcurrency = 1
|
||||
}
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
// 检查必填字段
|
||||
if req.TurnstileSiteKey == "" {
|
||||
response.BadRequest(c, "Turnstile Site Key is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.TurnstileSecretKey == "" {
|
||||
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前设置,检查参数是否有变化
|
||||
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
|
||||
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
|
||||
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
|
||||
if siteKeyChanged || secretKeyChanged {
|
||||
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "SMTP connection successful"})
|
||||
}
|
||||
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
// POST /api/v1/admin/settings/send-test-email
|
||||
func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
var req SendTestEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
subject := "[" + siteName + "] Test Email"
|
||||
body := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; text-align: center; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.success { color: #10b981; font-size: 48px; margin-bottom: 20px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>` + siteName + `</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<div class="success">✓</div>
|
||||
<h2>Email Configuration Successful!</h2>
|
||||
<p>This is a test email to verify your SMTP settings are working correctly.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated test message.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"exists": exists,
|
||||
"masked_key": maskedKey,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"key": key, // 完整 key 只在生成时返回一次
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings 获取所有系统设置
|
||||
// GET /api/v1/admin/settings
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
// PUT /api/v1/admin/settings
|
||||
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
var req UpdateSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if req.DefaultConcurrency < 1 {
|
||||
req.DefaultConcurrency = 1
|
||||
}
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
// 检查必填字段
|
||||
if req.TurnstileSiteKey == "" {
|
||||
response.BadRequest(c, "Turnstile Site Key is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.TurnstileSecretKey == "" {
|
||||
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前设置,检查参数是否有变化
|
||||
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
|
||||
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
|
||||
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
|
||||
if siteKeyChanged || secretKeyChanged {
|
||||
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "SMTP connection successful"})
|
||||
}
|
||||
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
// POST /api/v1/admin/settings/send-test-email
|
||||
func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
var req SendTestEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
subject := "[" + siteName + "] Test Email"
|
||||
body := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; text-align: center; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.success { color: #10b981; font-size: 48px; margin-bottom: 20px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>` + siteName + `</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<div class="success">✓</div>
|
||||
<h2>Email Configuration Successful!</h2>
|
||||
<p>This is a test email to verify your SMTP settings are working correctly.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated test message.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"exists": exists,
|
||||
"masked_key": maskedKey,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"key": key, // 完整 key 只在生成时返回一次
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
|
||||
@@ -1,278 +1,278 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &response.PaginationResult{
|
||||
Total: p.Total,
|
||||
Page: p.Page,
|
||||
PageSize: p.PageSize,
|
||||
Pages: p.Pages,
|
||||
}
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles admin subscription management
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new admin subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// AssignSubscriptionRequest represents assign subscription request
|
||||
type AssignSubscriptionRequest struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionRequest represents bulk assign subscription request
|
||||
type BulkAssignSubscriptionRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
// GET /api/v1/admin/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse optional filters
|
||||
var userID, groupID *int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = &id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = &id
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// GetByID handles getting a subscription by ID
|
||||
// GET /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
// GET /api/v1/admin/subscriptions/:id/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, progress)
|
||||
}
|
||||
|
||||
// Assign handles assigning a subscription to a user
|
||||
// POST /api/v1/admin/subscriptions/assign
|
||||
func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
var req AssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
|
||||
UserID: req.UserID,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
// POST /api/v1/admin/subscriptions/bulk-assign
|
||||
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
var req BulkAssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
|
||||
UserIDs: req.UserIDs,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
|
||||
}
|
||||
|
||||
// ListByGroup handles listing subscriptions for a specific group
|
||||
// GET /api/v1/admin/groups/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// ListByUser handles listing subscriptions for a specific user
|
||||
// GET /api/v1/admin/users/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Helper function to get admin ID from context
|
||||
func getAdminIDFromContext(c *gin.Context) int64 {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return subject.UserID
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &response.PaginationResult{
|
||||
Total: p.Total,
|
||||
Page: p.Page,
|
||||
PageSize: p.PageSize,
|
||||
Pages: p.Pages,
|
||||
}
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles admin subscription management
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new admin subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// AssignSubscriptionRequest represents assign subscription request
|
||||
type AssignSubscriptionRequest struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionRequest represents bulk assign subscription request
|
||||
type BulkAssignSubscriptionRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
// GET /api/v1/admin/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse optional filters
|
||||
var userID, groupID *int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = &id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = &id
|
||||
}
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// GetByID handles getting a subscription by ID
|
||||
// GET /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
// GET /api/v1/admin/subscriptions/:id/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, progress)
|
||||
}
|
||||
|
||||
// Assign handles assigning a subscription to a user
|
||||
// POST /api/v1/admin/subscriptions/assign
|
||||
func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
var req AssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
|
||||
UserID: req.UserID,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
// POST /api/v1/admin/subscriptions/bulk-assign
|
||||
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
var req BulkAssignSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get admin user ID from context
|
||||
adminID := getAdminIDFromContext(c)
|
||||
|
||||
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
|
||||
UserIDs: req.UserIDs,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
AssignedBy: adminID,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
|
||||
}
|
||||
|
||||
// ListByGroup handles listing subscriptions for a specific group
|
||||
// GET /api/v1/admin/groups/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// ListByUser handles listing subscriptions for a specific user
|
||||
// GET /api/v1/admin/users/:id/subscriptions
|
||||
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Helper function to get admin ID from context
|
||||
func getAdminIDFromContext(c *gin.Context) int64 {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
@@ -1,87 +1,87 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: updateSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// GetVersion returns the current version
|
||||
// GET /api/v1/admin/system/version
|
||||
func (h *SystemHandler) GetVersion(c *gin.Context) {
|
||||
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
|
||||
response.Success(c, gin.H{
|
||||
"version": info.CurrentVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUpdates checks for available updates
|
||||
// GET /api/v1/admin/system/check-updates
|
||||
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
force := c.Query("force") == "true"
|
||||
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, info)
|
||||
}
|
||||
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: updateSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// GetVersion returns the current version
|
||||
// GET /api/v1/admin/system/version
|
||||
func (h *SystemHandler) GetVersion(c *gin.Context) {
|
||||
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
|
||||
response.Success(c, gin.H{
|
||||
"version": info.CurrentVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUpdates checks for available updates
|
||||
// GET /api/v1/admin/system/check-updates
|
||||
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
force := c.Query("force") == "true"
|
||||
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, info)
|
||||
}
|
||||
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,311 +1,311 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
accountID = id
|
||||
}
|
||||
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
groupID = id
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics with filters
|
||||
// GET /api/v1/admin/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// SearchUsers handles searching users by email keyword
|
||||
// GET /api/v1/admin/usage/search-users
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []any{})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified user list (only id and email)
|
||||
type SimpleUser struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
result := make([]SimpleUser, len(users))
|
||||
for i, u := range users {
|
||||
result[i] = SimpleUser{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
var userID int64
|
||||
if userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleApiKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
accountID = id
|
||||
}
|
||||
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
groupID = id
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics with filters
|
||||
// GET /api/v1/admin/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// SearchUsers handles searching users by email keyword
|
||||
// GET /api/v1/admin/usage/search-users
|
||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []any{})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified user list (only id and email)
|
||||
type SimpleUser struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
result := make([]SimpleUser, len(users))
|
||||
for i, u := range users {
|
||||
result[i] = SimpleUser{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
var userID int64
|
||||
if userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleApiKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
@@ -1,342 +1,342 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserAttributeHandler handles user attribute management
|
||||
type UserAttributeHandler struct {
|
||||
attrService *service.UserAttributeService
|
||||
}
|
||||
|
||||
// NewUserAttributeHandler creates a new handler
|
||||
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
|
||||
return &UserAttributeHandler{attrService: attrService}
|
||||
}
|
||||
|
||||
// --- Request/Response DTOs ---
|
||||
|
||||
// CreateAttributeDefinitionRequest represents create attribute definition request
|
||||
type CreateAttributeDefinitionRequest struct {
|
||||
Key string `json:"key" binding:"required,min=1,max=100"`
|
||||
Name string `json:"name" binding:"required,min=1,max=255"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type" binding:"required"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateAttributeDefinitionRequest represents update attribute definition request
|
||||
type UpdateAttributeDefinitionRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
Type *string `json:"type"`
|
||||
Options *[]service.UserAttributeOption `json:"options"`
|
||||
Required *bool `json:"required"`
|
||||
Validation *service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder *string `json:"placeholder"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// ReorderRequest represents reorder attribute definitions request
|
||||
type ReorderRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateUserAttributesRequest represents update user attributes request
|
||||
type UpdateUserAttributesRequest struct {
|
||||
Values map[int64]string `json:"values" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchGetUserAttributesRequest represents batch get user attributes request
|
||||
type BatchGetUserAttributesRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchUserAttributesResponse represents batch user attributes response
|
||||
type BatchUserAttributesResponse struct {
|
||||
// Map of userID -> map of attributeID -> value
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
DisplayOrder int `json:"display_order"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AttributeValueResponse represents attribute value response
|
||||
type AttributeValueResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AttributeID int64 `json:"attribute_id"`
|
||||
Value string `json:"value"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
|
||||
return &AttributeDefinitionResponse{
|
||||
ID: def.ID,
|
||||
Key: def.Key,
|
||||
Name: def.Name,
|
||||
Description: def.Description,
|
||||
Type: string(def.Type),
|
||||
Options: def.Options,
|
||||
Required: def.Required,
|
||||
Validation: def.Validation,
|
||||
Placeholder: def.Placeholder,
|
||||
DisplayOrder: def.DisplayOrder,
|
||||
Enabled: def.Enabled,
|
||||
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
|
||||
return &AttributeValueResponse{
|
||||
ID: val.ID,
|
||||
UserID: val.UserID,
|
||||
AttributeID: val.AttributeID,
|
||||
Value: val.Value,
|
||||
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// ListDefinitions lists all attribute definitions
|
||||
// GET /admin/user-attributes
|
||||
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
|
||||
enabledOnly := c.Query("enabled") == "true"
|
||||
|
||||
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeDefinitionResponse, 0, len(defs))
|
||||
for i := range defs {
|
||||
out = append(out, defToResponse(&defs[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// CreateDefinition creates a new attribute definition
|
||||
// POST /admin/user-attributes
|
||||
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
|
||||
var req CreateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
|
||||
Key: req.Key,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Type: service.UserAttributeType(req.Type),
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// UpdateDefinition updates an attribute definition
|
||||
// PUT /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := service.UpdateAttributeDefinitionInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
}
|
||||
if req.Type != nil {
|
||||
t := service.UserAttributeType(*req.Type)
|
||||
input.Type = &t
|
||||
}
|
||||
|
||||
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// DeleteDefinition deletes an attribute definition
|
||||
// DELETE /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
|
||||
}
|
||||
|
||||
// ReorderDefinitions reorders attribute definitions
|
||||
// PUT /admin/user-attributes/reorder
|
||||
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
|
||||
var req ReorderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert IDs array to orders map (position in array = display_order)
|
||||
orders := make(map[int64]int, len(req.IDs))
|
||||
for i, id := range req.IDs {
|
||||
orders[id] = i
|
||||
}
|
||||
|
||||
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Reorder successful"})
|
||||
}
|
||||
|
||||
// GetUserAttributes gets a user's attribute values
|
||||
// GET /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// UpdateUserAttributes updates a user's attribute values
|
||||
// PUT /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
|
||||
for attrID, value := range req.Values {
|
||||
inputs = append(inputs, service.UpdateUserAttributeInput{
|
||||
AttributeID: attrID,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return updated values
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetBatchUserAttributes gets attribute values for multiple users
|
||||
// POST /admin/user-attributes/batch
|
||||
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
var req BatchGetUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserAttributeHandler handles user attribute management
|
||||
type UserAttributeHandler struct {
|
||||
attrService *service.UserAttributeService
|
||||
}
|
||||
|
||||
// NewUserAttributeHandler creates a new handler
|
||||
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
|
||||
return &UserAttributeHandler{attrService: attrService}
|
||||
}
|
||||
|
||||
// --- Request/Response DTOs ---
|
||||
|
||||
// CreateAttributeDefinitionRequest represents create attribute definition request
|
||||
type CreateAttributeDefinitionRequest struct {
|
||||
Key string `json:"key" binding:"required,min=1,max=100"`
|
||||
Name string `json:"name" binding:"required,min=1,max=255"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type" binding:"required"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateAttributeDefinitionRequest represents update attribute definition request
|
||||
type UpdateAttributeDefinitionRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
Type *string `json:"type"`
|
||||
Options *[]service.UserAttributeOption `json:"options"`
|
||||
Required *bool `json:"required"`
|
||||
Validation *service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder *string `json:"placeholder"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// ReorderRequest represents reorder attribute definitions request
|
||||
type ReorderRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateUserAttributesRequest represents update user attributes request
|
||||
type UpdateUserAttributesRequest struct {
|
||||
Values map[int64]string `json:"values" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchGetUserAttributesRequest represents batch get user attributes request
|
||||
type BatchGetUserAttributesRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchUserAttributesResponse represents batch user attributes response
|
||||
type BatchUserAttributesResponse struct {
|
||||
// Map of userID -> map of attributeID -> value
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Options []service.UserAttributeOption `json:"options"`
|
||||
Required bool `json:"required"`
|
||||
Validation service.UserAttributeValidation `json:"validation"`
|
||||
Placeholder string `json:"placeholder"`
|
||||
DisplayOrder int `json:"display_order"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AttributeValueResponse represents attribute value response
|
||||
type AttributeValueResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AttributeID int64 `json:"attribute_id"`
|
||||
Value string `json:"value"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
|
||||
return &AttributeDefinitionResponse{
|
||||
ID: def.ID,
|
||||
Key: def.Key,
|
||||
Name: def.Name,
|
||||
Description: def.Description,
|
||||
Type: string(def.Type),
|
||||
Options: def.Options,
|
||||
Required: def.Required,
|
||||
Validation: def.Validation,
|
||||
Placeholder: def.Placeholder,
|
||||
DisplayOrder: def.DisplayOrder,
|
||||
Enabled: def.Enabled,
|
||||
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
|
||||
return &AttributeValueResponse{
|
||||
ID: val.ID,
|
||||
UserID: val.UserID,
|
||||
AttributeID: val.AttributeID,
|
||||
Value: val.Value,
|
||||
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// ListDefinitions lists all attribute definitions
|
||||
// GET /admin/user-attributes
|
||||
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
|
||||
enabledOnly := c.Query("enabled") == "true"
|
||||
|
||||
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeDefinitionResponse, 0, len(defs))
|
||||
for i := range defs {
|
||||
out = append(out, defToResponse(&defs[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// CreateDefinition creates a new attribute definition
|
||||
// POST /admin/user-attributes
|
||||
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
|
||||
var req CreateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
|
||||
Key: req.Key,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Type: service.UserAttributeType(req.Type),
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// UpdateDefinition updates an attribute definition
|
||||
// PUT /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAttributeDefinitionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := service.UpdateAttributeDefinitionInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Options: req.Options,
|
||||
Required: req.Required,
|
||||
Validation: req.Validation,
|
||||
Placeholder: req.Placeholder,
|
||||
Enabled: req.Enabled,
|
||||
}
|
||||
if req.Type != nil {
|
||||
t := service.UserAttributeType(*req.Type)
|
||||
input.Type = &t
|
||||
}
|
||||
|
||||
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, defToResponse(def))
|
||||
}
|
||||
|
||||
// DeleteDefinition deletes an attribute definition
|
||||
// DELETE /admin/user-attributes/:id
|
||||
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid attribute ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
|
||||
}
|
||||
|
||||
// ReorderDefinitions reorders attribute definitions
|
||||
// PUT /admin/user-attributes/reorder
|
||||
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
|
||||
var req ReorderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert IDs array to orders map (position in array = display_order)
|
||||
orders := make(map[int64]int, len(req.IDs))
|
||||
for i, id := range req.IDs {
|
||||
orders[id] = i
|
||||
}
|
||||
|
||||
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Reorder successful"})
|
||||
}
|
||||
|
||||
// GetUserAttributes gets a user's attribute values
|
||||
// GET /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// UpdateUserAttributes updates a user's attribute values
|
||||
// PUT /admin/users/:id/attributes
|
||||
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
|
||||
for attrID, value := range req.Values {
|
||||
inputs = append(inputs, service.UpdateUserAttributeInput{
|
||||
AttributeID: attrID,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return updated values
|
||||
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*AttributeValueResponse, 0, len(values))
|
||||
for i := range values {
|
||||
out = append(out, valueToResponse(&values[i]))
|
||||
}
|
||||
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetBatchUserAttributes gets attribute values for multiple users
|
||||
// POST /admin/user-attributes/batch
|
||||
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
var req BatchGetUserAttributesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
}
|
||||
|
||||
@@ -1,271 +1,271 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
// GET /api/v1/admin/users
|
||||
// Query params:
|
||||
// - status: filter by user status
|
||||
// - role: filter by user role
|
||||
// - search: search in email, username
|
||||
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// parseAttributeFilters extracts attribute filters from query params
|
||||
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
|
||||
func parseAttributeFilters(c *gin.Context) map[int64]string {
|
||||
result := make(map[int64]string)
|
||||
|
||||
// Get all query params and look for attr[*] pattern
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) == 0 || values[0] == "" {
|
||||
continue
|
||||
}
|
||||
// Check if key matches pattern attr[{id}]
|
||||
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
|
||||
idStr := key[5 : len(key)-1]
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err == nil && id > 0 {
|
||||
result[id] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetByID handles getting a user by ID
|
||||
// GET /api/v1/admin/users/:id
|
||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.GetUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
// POST /api/v1/admin/users
|
||||
func (h *UserHandler) Create(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
// PUT /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Update(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
// DELETE /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Delete(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "User deleted successfully"})
|
||||
}
|
||||
|
||||
// UpdateBalance handles updating user balance
|
||||
// POST /api/v1/admin/users/:id/balance
|
||||
func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateBalanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
// GET /api/v1/admin/users/:id/api-keys
|
||||
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUserUsage handles getting user's usage statistics
|
||||
// GET /api/v1/admin/users/:id/usage
|
||||
func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
period := c.DefaultQuery("period", "month")
|
||||
|
||||
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
// 使用指针类型来区分"未提供"和"设置为0"
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
// GET /api/v1/admin/users
|
||||
// Query params:
|
||||
// - status: filter by user status
|
||||
// - role: filter by user role
|
||||
// - search: search in email, username
|
||||
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// parseAttributeFilters extracts attribute filters from query params
|
||||
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
|
||||
func parseAttributeFilters(c *gin.Context) map[int64]string {
|
||||
result := make(map[int64]string)
|
||||
|
||||
// Get all query params and look for attr[*] pattern
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) == 0 || values[0] == "" {
|
||||
continue
|
||||
}
|
||||
// Check if key matches pattern attr[{id}]
|
||||
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
|
||||
idStr := key[5 : len(key)-1]
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err == nil && id > 0 {
|
||||
result[id] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetByID handles getting a user by ID
|
||||
// GET /api/v1/admin/users/:id
|
||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.GetUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
// POST /api/v1/admin/users
|
||||
func (h *UserHandler) Create(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
// PUT /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Update(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
// DELETE /api/v1/admin/users/:id
|
||||
func (h *UserHandler) Delete(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.adminService.DeleteUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "User deleted successfully"})
|
||||
}
|
||||
|
||||
// UpdateBalance handles updating user balance
|
||||
// POST /api/v1/admin/users/:id/balance
|
||||
func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateBalanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
// GET /api/v1/admin/users/:id/api-keys
|
||||
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUserUsage handles getting user's usage statistics
|
||||
// GET /api/v1/admin/users/:id/usage
|
||||
func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
period := c.DefaultQuery("period", "month")
|
||||
|
||||
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
@@ -1,208 +1,208 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest represents the create API key request payload
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
// GET /api/v1/api-keys
|
||||
func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single API key
|
||||
// GET /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if key.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this key")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
// POST /api/v1/api-keys
|
||||
func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
// PUT /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
svcReq.GroupID = req.GroupID
|
||||
if req.Status != "" {
|
||||
svcReq.Status = &req.Status
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
// DELETE /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "API key deleted successfully"})
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户可以绑定的分组列表
|
||||
// GET /api/v1/groups/available
|
||||
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
out = append(out, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest represents the create API key request payload
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
// GET /api/v1/api-keys
|
||||
func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single API key
|
||||
// GET /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if key.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this key")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
// POST /api/v1/api-keys
|
||||
func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
// PUT /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAPIKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
svcReq.GroupID = req.GroupID
|
||||
if req.Status != "" {
|
||||
svcReq.Status = &req.Status
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
// DELETE /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "API key deleted successfully"})
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户可以绑定的分组列表
|
||||
// GET /api/v1/groups/available
|
||||
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
out = append(out, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
@@ -1,174 +1,174 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRequest represents the registration request payload
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
type SendVerifyCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeResponse 发送验证码响应
|
||||
type SendVerifyCodeResponse struct {
|
||||
Message string `json:"message"`
|
||||
Countdown int `json:"countdown"` // 倒计时秒数
|
||||
}
|
||||
|
||||
// LoginRequest represents the login request payload
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
// POST /api/v1/auth/register
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码
|
||||
// POST /api/v1/auth/send-verify-code
|
||||
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
var req SendVerifyCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, SendVerifyCodeResponse{
|
||||
Message: "Verification code sent successfully",
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// Login handles user login
|
||||
// POST /api/v1/auth/login
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
// GET /api/v1/auth/me
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
type UserResponse struct {
|
||||
*dto.User
|
||||
RunMode string `json:"run_mode"`
|
||||
}
|
||||
|
||||
runMode := config.RunModeStandard
|
||||
if h.cfg != nil {
|
||||
runMode = h.cfg.RunMode
|
||||
}
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRequest represents the registration request payload
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
type SendVerifyCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// SendVerifyCodeResponse 发送验证码响应
|
||||
type SendVerifyCodeResponse struct {
|
||||
Message string `json:"message"`
|
||||
Countdown int `json:"countdown"` // 倒计时秒数
|
||||
}
|
||||
|
||||
// LoginRequest represents the login request payload
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
// POST /api/v1/auth/register
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码
|
||||
// POST /api/v1/auth/send-verify-code
|
||||
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
var req SendVerifyCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, SendVerifyCodeResponse{
|
||||
Message: "Verification code sent successfully",
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// Login handles user login
|
||||
// POST /api/v1/auth/login
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
// GET /api/v1/auth/me
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
type UserResponse struct {
|
||||
*dto.User
|
||||
RunMode string `json:"run_mode"`
|
||||
}
|
||||
|
||||
runMode := config.RunModeStandard
|
||||
if h.cfg != nil {
|
||||
runMode = h.cfg.RunMode
|
||||
}
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
@@ -1,309 +1,309 @@
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
func UserFromServiceShallow(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: u.AllowedGroups,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserFromService(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
out := UserFromServiceShallow(u)
|
||||
if len(u.ApiKeys) > 0 {
|
||||
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
|
||||
for i := range u.ApiKeys {
|
||||
k := u.ApiKeys[i]
|
||||
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
|
||||
}
|
||||
}
|
||||
if len(u.Subscriptions) > 0 {
|
||||
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
|
||||
for i := range u.Subscriptions {
|
||||
s := u.Subscriptions[i]
|
||||
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &ApiKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
ag := g.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromService(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
out := AccountFromServiceShallow(a)
|
||||
out.Proxy = ProxyFromService(a.Proxy)
|
||||
if len(a.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
|
||||
for i := range a.AccountGroups {
|
||||
ag := a.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
if len(a.Groups) > 0 {
|
||||
out.Groups = make([]*Group, 0, len(a.Groups))
|
||||
for _, g := range a.Groups {
|
||||
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
|
||||
if ag == nil {
|
||||
return nil
|
||||
}
|
||||
return &AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
Account: AccountFromServiceShallow(ag.Account),
|
||||
Group: GroupFromServiceShallow(ag.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyFromService(p *service.Proxy) *Proxy {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &Proxy{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &ProxyWithAccountCount{
|
||||
Proxy: *ProxyFromService(&p.Proxy),
|
||||
AccountCount: p.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
Value: rc.Value,
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
ApiKeyID: l.ApiKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
OutputTokens: l.OutputTokens,
|
||||
CacheCreationTokens: l.CacheCreationTokens,
|
||||
CacheReadTokens: l.CacheReadTokens,
|
||||
CacheCreation5mTokens: l.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: l.CacheCreation1hTokens,
|
||||
InputCost: l.InputCost,
|
||||
OutputCost: l.OutputCost,
|
||||
CacheCreationCost: l.CacheCreationCost,
|
||||
CacheReadCost: l.CacheReadCost,
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
ApiKey: ApiKeyFromService(l.ApiKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &Setting{
|
||||
ID: s.ID,
|
||||
Key: s.Key,
|
||||
Value: s.Value,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
StartsAt: sub.StartsAt,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
Status: sub.Status,
|
||||
DailyWindowStart: sub.DailyWindowStart,
|
||||
WeeklyWindowStart: sub.WeeklyWindowStart,
|
||||
MonthlyWindowStart: sub.MonthlyWindowStart,
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
func UserFromServiceShallow(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: u.AllowedGroups,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserFromService(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
out := UserFromServiceShallow(u)
|
||||
if len(u.ApiKeys) > 0 {
|
||||
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
|
||||
for i := range u.ApiKeys {
|
||||
k := u.ApiKeys[i]
|
||||
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
|
||||
}
|
||||
}
|
||||
if len(u.Subscriptions) > 0 {
|
||||
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
|
||||
for i := range u.Subscriptions {
|
||||
s := u.Subscriptions[i]
|
||||
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &ApiKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
ag := g.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromService(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
out := AccountFromServiceShallow(a)
|
||||
out.Proxy = ProxyFromService(a.Proxy)
|
||||
if len(a.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
|
||||
for i := range a.AccountGroups {
|
||||
ag := a.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
if len(a.Groups) > 0 {
|
||||
out.Groups = make([]*Group, 0, len(a.Groups))
|
||||
for _, g := range a.Groups {
|
||||
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
|
||||
if ag == nil {
|
||||
return nil
|
||||
}
|
||||
return &AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
Account: AccountFromServiceShallow(ag.Account),
|
||||
Group: GroupFromServiceShallow(ag.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyFromService(p *service.Proxy) *Proxy {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &Proxy{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &ProxyWithAccountCount{
|
||||
Proxy: *ProxyFromService(&p.Proxy),
|
||||
AccountCount: p.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
Value: rc.Value,
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
ApiKeyID: l.ApiKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
OutputTokens: l.OutputTokens,
|
||||
CacheCreationTokens: l.CacheCreationTokens,
|
||||
CacheReadTokens: l.CacheReadTokens,
|
||||
CacheCreation5mTokens: l.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: l.CacheCreation1hTokens,
|
||||
InputCost: l.InputCost,
|
||||
OutputCost: l.OutputCost,
|
||||
CacheCreationCost: l.CacheCreationCost,
|
||||
CacheReadCost: l.CacheReadCost,
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
ApiKey: ApiKeyFromService(l.ApiKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &Setting{
|
||||
ID: s.ID,
|
||||
Key: s.Key,
|
||||
Value: s.Value,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
StartsAt: sub.StartsAt,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
Status: sub.Status,
|
||||
DailyWindowStart: sub.DailyWindowStart,
|
||||
WeeklyWindowStart: sub.WeeklyWindowStart,
|
||||
MonthlyWindowStart: sub.MonthlyWindowStart,
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,43 +1,43 @@
|
||||
package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -1,218 +1,218 @@
|
||||
package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Status string `json:"status"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
ApiKeys []ApiKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
Status string `json:"status"`
|
||||
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
RateLimitedAt *time.Time `json:"rate_limited_at"`
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `json:"session_window_status"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
GroupIDs []int64 `json:"group_ids,omitempty"`
|
||||
Groups []*Group `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Priority int `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"-"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
}
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationTokens int `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `json:"cache_read_tokens"`
|
||||
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
ApiKey *ApiKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
type Setting struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserSubscription struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
StartsAt time.Time `json:"starts_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
|
||||
DailyWindowStart *time.Time `json:"daily_window_start"`
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
|
||||
|
||||
DailyUsageUSD float64 `json:"daily_usage_usd"`
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Status string `json:"status"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
ApiKeys []ApiKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
Status string `json:"status"`
|
||||
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
RateLimitedAt *time.Time `json:"rate_limited_at"`
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `json:"session_window_status"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
GroupIDs []int64 `json:"group_ids,omitempty"`
|
||||
Groups []*Group `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Priority int `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"-"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
}
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationTokens int `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `json:"cache_read_tokens"`
|
||||
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
ApiKey *ApiKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
type Setting struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserSubscription struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
StartsAt time.Time `json:"starts_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
|
||||
DailyWindowStart *time.Time `json:"daily_window_start"`
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
|
||||
|
||||
DailyUsageUSD float64 `json:"daily_usage_usd"`
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,263 +1,263 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
|
||||
// 1. 高并发时频繁轮询增加 Redis 压力
|
||||
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
|
||||
//
|
||||
// 新实现使用指数退避 + 抖动算法:
|
||||
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
|
||||
// 2. 添加 ±20% 的随机抖动,分散重试时间点
|
||||
// 3. 减少 Redis 压力,避免惊群效应
|
||||
const (
|
||||
// maxConcurrencyWait 等待并发槽位的最大时间
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval 流式响应等待时发送 ping 的间隔
|
||||
pingInterval = 15 * time.Second
|
||||
// initialBackoff 初始退避时间
|
||||
initialBackoff = 100 * time.Millisecond
|
||||
// backoffMultiplier 退避时间乘数(指数退避)
|
||||
backoffMultiplier = 1.5
|
||||
// maxBackoff 最大退避时间
|
||||
maxBackoff = 2 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
type SSEPingFormat string
|
||||
|
||||
const (
|
||||
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||
SSEPingFormatNone SSEPingFormat = ""
|
||||
)
|
||||
|
||||
// ConcurrencyError represents a concurrency limit error with context
|
||||
type ConcurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *ConcurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||
type ConcurrencyHelper struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
pingFormat SSEPingFormat
|
||||
}
|
||||
|
||||
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||
return &ConcurrencyHelper{
|
||||
concurrencyService: concurrencyService,
|
||||
pingFormat: pingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait count for a user
|
||||
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait count for an account
|
||||
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait count for an account
|
||||
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Only create ping ticker if ping is needed
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &ConcurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingCh:
|
||||
// Send ping to keep connection alive
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
}
|
||||
if jittered > maxBackoff {
|
||||
return maxBackoff
|
||||
}
|
||||
return jittered
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
|
||||
// 1. 高并发时频繁轮询增加 Redis 压力
|
||||
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
|
||||
//
|
||||
// 新实现使用指数退避 + 抖动算法:
|
||||
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
|
||||
// 2. 添加 ±20% 的随机抖动,分散重试时间点
|
||||
// 3. 减少 Redis 压力,避免惊群效应
|
||||
const (
|
||||
// maxConcurrencyWait 等待并发槽位的最大时间
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval 流式响应等待时发送 ping 的间隔
|
||||
pingInterval = 15 * time.Second
|
||||
// initialBackoff 初始退避时间
|
||||
initialBackoff = 100 * time.Millisecond
|
||||
// backoffMultiplier 退避时间乘数(指数退避)
|
||||
backoffMultiplier = 1.5
|
||||
// maxBackoff 最大退避时间
|
||||
maxBackoff = 2 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
type SSEPingFormat string
|
||||
|
||||
const (
|
||||
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||
SSEPingFormatNone SSEPingFormat = ""
|
||||
)
|
||||
|
||||
// ConcurrencyError represents a concurrency limit error with context
|
||||
type ConcurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *ConcurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||
type ConcurrencyHelper struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
pingFormat SSEPingFormat
|
||||
}
|
||||
|
||||
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||
return &ConcurrencyHelper{
|
||||
concurrencyService: concurrencyService,
|
||||
pingFormat: pingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait count for a user
|
||||
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait count for an account
|
||||
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait count for an account
|
||||
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Only create ping ticker if ping is needed
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &ConcurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingCh:
|
||||
// Send ping to keep connection alive
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
}
|
||||
if jittered > maxBackoff {
|
||||
return maxBackoff
|
||||
}
|
||||
return jittered
|
||||
}
|
||||
|
||||
@@ -1,407 +1,407 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
// 强制 antigravity 模式:返回 antigravity 支持的模型列表
|
||||
if forcePlatform == service.PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型列表
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaGetModel proxies:
|
||||
// GET /v1beta/models/{model}
|
||||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
googleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 强制 antigravity 模式:返回 antigravity 模型信息
|
||||
if forcePlatform == service.PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName))
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型信息
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
|
||||
// POST /v1beta/models/{model}:generateContent
|
||||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||
if !middleware.HasForcePlatform(c) {
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
||||
if err != nil {
|
||||
googleError(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
googleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
|
||||
|
||||
// 0) wait queue check
|
||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
|
||||
// 1) user concurrency slot
|
||||
streamStarted := false
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
googleError(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
stream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||
rest = strings.TrimSpace(rest)
|
||||
if rest == "" {
|
||||
return "", "", &pathParseError{"missing path"}
|
||||
}
|
||||
|
||||
// Standard: {model}:{action}
|
||||
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
// Fallback: {model}/{action}
|
||||
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
type pathParseError struct{ msg string }
|
||||
|
||||
func (e *pathParseError) Error() string { return e.msg }
|
||||
|
||||
func googleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
|
||||
if res == nil {
|
||||
googleError(c, http.StatusBadGateway, "Empty upstream response")
|
||||
return
|
||||
}
|
||||
for k, vv := range res.Headers {
|
||||
// Avoid overriding content-length and hop-by-hop headers.
|
||||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
c.Writer.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
contentType := res.Headers.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(res.StatusCode, contentType, res.Body)
|
||||
}
|
||||
|
||||
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
if res == nil {
|
||||
return true
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
// 强制 antigravity 模式:返回 antigravity 支持的模型列表
|
||||
if forcePlatform == service.PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型列表
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaGetModel proxies:
|
||||
// GET /v1beta/models/{model}
|
||||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
googleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 强制 antigravity 模式:返回 antigravity 模型信息
|
||||
if forcePlatform == service.PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName))
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型信息
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
|
||||
// POST /v1beta/models/{model}:generateContent
|
||||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||
if !middleware.HasForcePlatform(c) {
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
||||
if err != nil {
|
||||
googleError(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
googleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
|
||||
|
||||
// 0) wait queue check
|
||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
|
||||
// 1) user concurrency slot
|
||||
streamStarted := false
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
googleError(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
stream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||
rest = strings.TrimSpace(rest)
|
||||
if rest == "" {
|
||||
return "", "", &pathParseError{"missing path"}
|
||||
}
|
||||
|
||||
// Standard: {model}:{action}
|
||||
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
// Fallback: {model}/{action}
|
||||
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
type pathParseError struct{ msg string }
|
||||
|
||||
func (e *pathParseError) Error() string { return e.msg }
|
||||
|
||||
func googleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
|
||||
if res == nil {
|
||||
googleError(c, http.StatusBadGateway, "Empty upstream response")
|
||||
return
|
||||
}
|
||||
for k, vv := range res.Headers {
|
||||
// Avoid overriding content-length and hop-by-hop headers.
|
||||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
c.Writer.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
contentType := res.Headers.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(res.StatusCode, contentType, res.Body)
|
||||
}
|
||||
|
||||
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
if res == nil {
|
||||
return true
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,143 +1,143 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
|
||||
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
|
||||
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台使用ForwardNative",
|
||||
platform: service.PlatformGemini,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
description: "Gemini OAuth 账户直接调用 Google API",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台使用ForwardGemini",
|
||||
platform: service.PlatformAntigravity,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
|
||||
var routedService string
|
||||
if tt.platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, routedService,
|
||||
"平台 %s 应该路由到 %s: %s",
|
||||
tt.platform, tt.expectedService, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
|
||||
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
|
||||
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态列表",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_fallback",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_fallback"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
|
||||
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态模型信息",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_model_info",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_model_info"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
|
||||
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
|
||||
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台使用ForwardNative",
|
||||
platform: service.PlatformGemini,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
description: "Gemini OAuth 账户直接调用 Google API",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台使用ForwardGemini",
|
||||
platform: service.PlatformAntigravity,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
|
||||
var routedService string
|
||||
if tt.platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, routedService,
|
||||
"平台 %s 应该路由到 %s: %s",
|
||||
tt.platform, tt.expectedService, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
|
||||
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
|
||||
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态列表",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_fallback",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_fallback"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
|
||||
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态模型信息",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_model_info",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_model_info"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
@@ -1,306 +1,306 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
// For non-Codex CLI requests, set default instructions
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
reqBody["instructions"] = openai.DefaultInstructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
// For non-Codex CLI requests, set default instructions
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
reqBody["instructions"] = openai.DefaultInstructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,85 +1,85 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles redeem code-related requests
|
||||
type RedeemHandler struct {
|
||||
redeemService *service.RedeemService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new RedeemHandler
|
||||
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
redeemService: redeemService,
|
||||
}
|
||||
}
|
||||
|
||||
// RedeemRequest represents the redeem code request payload
|
||||
type RedeemRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// RedeemResponse represents the redeem response
|
||||
type RedeemResponse struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
NewBalance *float64 `json:"new_balance,omitempty"`
|
||||
NewConcurrency *int `json:"new_concurrency,omitempty"`
|
||||
}
|
||||
|
||||
// Redeem handles redeeming a code
|
||||
// POST /api/v1/redeem
|
||||
func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req RedeemRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(result))
|
||||
}
|
||||
|
||||
// GetHistory returns the user's redemption history
|
||||
// GET /api/v1/redeem/history
|
||||
func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
// Default limit is 25
|
||||
limit := 25
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RedeemHandler handles redeem code-related requests
|
||||
type RedeemHandler struct {
|
||||
redeemService *service.RedeemService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new RedeemHandler
|
||||
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
redeemService: redeemService,
|
||||
}
|
||||
}
|
||||
|
||||
// RedeemRequest represents the redeem code request payload
|
||||
type RedeemRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// RedeemResponse represents the redeem response
|
||||
type RedeemResponse struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
NewBalance *float64 `json:"new_balance,omitempty"`
|
||||
NewConcurrency *int `json:"new_concurrency,omitempty"`
|
||||
}
|
||||
|
||||
// Redeem handles redeeming a code
|
||||
// POST /api/v1/redeem
|
||||
func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req RedeemRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(result))
|
||||
}
|
||||
|
||||
// GetHistory returns the user's redemption history
|
||||
// GET /api/v1/redeem/history
|
||||
func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
// Default limit is 25
|
||||
limit := 25
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) {
|
||||
return maxErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func formatBodyLimit(limit int64) string {
|
||||
const mb = 1024 * 1024
|
||||
if limit >= mb {
|
||||
return fmt.Sprintf("%dMB", limit/mb)
|
||||
}
|
||||
return fmt.Sprintf("%dB", limit)
|
||||
}
|
||||
|
||||
func buildBodyTooLargeMessage(limit int64) string {
|
||||
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) {
|
||||
return maxErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func formatBodyLimit(limit int64) string {
|
||||
const mb = 1024 * 1024
|
||||
if limit >= mb {
|
||||
return fmt.Sprintf("%dMB", limit/mb)
|
||||
}
|
||||
return fmt.Sprintf("%dB", limit)
|
||||
}
|
||||
|
||||
func buildBodyTooLargeMessage(limit int64) string {
|
||||
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
|
||||
}
|
||||
|
||||
@@ -1,45 +1,45 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestBodyLimitTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := int64(16)
|
||||
router := gin.New()
|
||||
router.Use(middleware.RequestBodyLimit(limit))
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": buildBodyTooLargeMessage(maxErr.Limit),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "read_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), int(limit+1))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
|
||||
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestBodyLimitTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := int64(16)
|
||||
router := gin.New()
|
||||
router.Use(middleware.RequestBodyLimit(limit))
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": buildBodyTooLargeMessage(maxErr.Limit),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "read_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), int(limit+1))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
|
||||
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
|
||||
}
|
||||
|
||||
@@ -1,47 +1,47 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 公开设置处理器(无需认证)
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
version string
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建公开设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置
|
||||
// GET /api/v1/settings/public
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SettingHandler 公开设置处理器(无需认证)
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
version string
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建公开设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置
|
||||
// GET /api/v1/settings/public
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,188 +1,188 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SubscriptionSummaryItem represents a subscription item in summary
|
||||
type SubscriptionSummaryItem struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Status string `json:"status"`
|
||||
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
|
||||
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
|
||||
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
|
||||
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionProgressInfo represents subscription with progress info
|
||||
type SubscriptionProgressInfo struct {
|
||||
Subscription *dto.UserSubscription `json:"subscription"`
|
||||
Progress *service.SubscriptionProgress `json:"progress"`
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles user subscription operations
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new user subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing current user's subscriptions
|
||||
// GET /api/v1/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetActive handles getting current user's active subscriptions
|
||||
// GET /api/v1/subscriptions/active
|
||||
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription progress for current user
|
||||
// GET /api/v1/subscriptions/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
sub := &subscriptions[i]
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
|
||||
if err != nil {
|
||||
// Skip subscriptions with errors
|
||||
continue
|
||||
}
|
||||
result = append(result, SubscriptionProgressInfo{
|
||||
Subscription: dto.UserSubscriptionFromService(sub),
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetSummary handles getting a summary of current user's subscription status
|
||||
// GET /api/v1/subscriptions/summary
|
||||
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var totalUsed float64
|
||||
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
item := SubscriptionSummaryItem{
|
||||
ID: sub.ID,
|
||||
GroupID: sub.GroupID,
|
||||
Status: sub.Status,
|
||||
DailyUsedUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsedUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsedUSD: sub.MonthlyUsageUSD,
|
||||
}
|
||||
|
||||
// Add group info if preloaded
|
||||
if sub.Group != nil {
|
||||
item.GroupName = sub.Group.Name
|
||||
if sub.Group.DailyLimitUSD != nil {
|
||||
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
|
||||
}
|
||||
if sub.Group.WeeklyLimitUSD != nil {
|
||||
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
|
||||
}
|
||||
if sub.Group.MonthlyLimitUSD != nil {
|
||||
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
|
||||
}
|
||||
}
|
||||
|
||||
// Format expiration time
|
||||
if !sub.ExpiresAt.IsZero() {
|
||||
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
item.ExpiresAt = &formatted
|
||||
}
|
||||
|
||||
// Track total usage (use monthly as the most comprehensive)
|
||||
totalUsed += sub.MonthlyUsageUSD
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
summary := struct {
|
||||
ActiveCount int `json:"active_count"`
|
||||
TotalUsedUSD float64 `json:"total_used_usd"`
|
||||
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
|
||||
}{
|
||||
ActiveCount: len(subscriptions),
|
||||
TotalUsedUSD: totalUsed,
|
||||
Subscriptions: items,
|
||||
}
|
||||
|
||||
response.Success(c, summary)
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SubscriptionSummaryItem represents a subscription item in summary
|
||||
type SubscriptionSummaryItem struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Status string `json:"status"`
|
||||
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
|
||||
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
|
||||
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
|
||||
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionProgressInfo represents subscription with progress info
|
||||
type SubscriptionProgressInfo struct {
|
||||
Subscription *dto.UserSubscription `json:"subscription"`
|
||||
Progress *service.SubscriptionProgress `json:"progress"`
|
||||
}
|
||||
|
||||
// SubscriptionHandler handles user subscription operations
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *service.SubscriptionService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new user subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing current user's subscriptions
|
||||
// GET /api/v1/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetActive handles getting current user's active subscriptions
|
||||
// GET /api/v1/subscriptions/active
|
||||
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription progress for current user
|
||||
// GET /api/v1/subscriptions/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
sub := &subscriptions[i]
|
||||
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
|
||||
if err != nil {
|
||||
// Skip subscriptions with errors
|
||||
continue
|
||||
}
|
||||
result = append(result, SubscriptionProgressInfo{
|
||||
Subscription: dto.UserSubscriptionFromService(sub),
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetSummary handles getting a summary of current user's subscription status
|
||||
// GET /api/v1/subscriptions/summary
|
||||
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var totalUsed float64
|
||||
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
item := SubscriptionSummaryItem{
|
||||
ID: sub.ID,
|
||||
GroupID: sub.GroupID,
|
||||
Status: sub.Status,
|
||||
DailyUsedUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsedUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsedUSD: sub.MonthlyUsageUSD,
|
||||
}
|
||||
|
||||
// Add group info if preloaded
|
||||
if sub.Group != nil {
|
||||
item.GroupName = sub.Group.Name
|
||||
if sub.Group.DailyLimitUSD != nil {
|
||||
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
|
||||
}
|
||||
if sub.Group.WeeklyLimitUSD != nil {
|
||||
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
|
||||
}
|
||||
if sub.Group.MonthlyLimitUSD != nil {
|
||||
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
|
||||
}
|
||||
}
|
||||
|
||||
// Format expiration time
|
||||
if !sub.ExpiresAt.IsZero() {
|
||||
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
item.ExpiresAt = &formatted
|
||||
}
|
||||
|
||||
// Track total usage (use monthly as the most comprehensive)
|
||||
totalUsed += sub.MonthlyUsageUSD
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
summary := struct {
|
||||
ActiveCount int `json:"active_count"`
|
||||
TotalUsedUSD float64 `json:"total_used_usd"`
|
||||
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
|
||||
}{
|
||||
ActiveCount: len(subscriptions),
|
||||
TotalUsedUSD: totalUsed,
|
||||
Subscriptions: items,
|
||||
}
|
||||
|
||||
response.Success(c, summary)
|
||||
}
|
||||
|
||||
@@ -1,398 +1,398 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing usage records with pagination
|
||||
// GET /api/v1/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's usage records")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
ApiKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single usage record
|
||||
// GET /api/v1/usage/:id
|
||||
func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid usage ID")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if record.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this record")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UsageLogFromService(record))
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics
|
||||
// GET /api/v1/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's statistics")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// 获取时间范围参数
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
// 优先使用 start_date 和 end_date 参数
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
// 使用自定义日期范围
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
|
||||
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// DashboardStats handles getting user dashboard statistics
|
||||
// GET /api/v1/usage/dashboard/stats
|
||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// DashboardTrend handles getting user usage trend data
|
||||
// GET /api/v1/usage/dashboard/trend
|
||||
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// DashboardModels handles getting user model usage statistics
|
||||
// GET /api/v1/usage/dashboard/models
|
||||
func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing usage records with pagination
|
||||
// GET /api/v1/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's usage records")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
ApiKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single usage record
|
||||
// GET /api/v1/usage/:id
|
||||
func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid usage ID")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if record.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this record")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UsageLogFromService(record))
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics
|
||||
// GET /api/v1/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeyID int64
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid api_key_id")
|
||||
return
|
||||
}
|
||||
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's statistics")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
// 获取时间范围参数
|
||||
now := timezone.Now()
|
||||
var startTime, endTime time.Time
|
||||
|
||||
// 优先使用 start_date 和 end_date 参数
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
// 使用自定义日期范围
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
|
||||
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// DashboardStats handles getting user dashboard statistics
|
||||
// GET /api/v1/usage/dashboard/stats
|
||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// DashboardTrend handles getting user usage trend data
|
||||
// GET /api/v1/usage/dashboard/trend
|
||||
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
"granularity": granularity,
|
||||
})
|
||||
}
|
||||
|
||||
// DashboardModels handles getting user model usage statistics
|
||||
// GET /api/v1/usage/dashboard/models
|
||||
func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
@@ -1,112 +1,112 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
// ChangePassword handles changing user password
|
||||
// POST /api/v1/users/me/password
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.ChangePasswordRequest{
|
||||
CurrentPassword: req.OldPassword,
|
||||
NewPassword: req.NewPassword,
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
// ChangePassword handles changing user password
|
||||
// POST /api/v1/users/me/password
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.ChangePasswordRequest{
|
||||
CurrentPassword: req.OldPassword,
|
||||
NewPassword: req.NewPassword,
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
@@ -1,117 +1,117 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
func ProvideAdminHandlers(
|
||||
dashboardHandler *admin.DashboardHandler,
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
||||
return NewSettingHandler(settingService, buildInfo.Version)
|
||||
}
|
||||
|
||||
// ProvideHandlers creates the Handlers struct
|
||||
func ProvideHandlers(
|
||||
authHandler *AuthHandler,
|
||||
userHandler *UserHandler,
|
||||
apiKeyHandler *APIKeyHandler,
|
||||
usageHandler *UsageHandler,
|
||||
redeemHandler *RedeemHandler,
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all handlers
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Top-level handlers
|
||||
NewAuthHandler,
|
||||
NewUserHandler,
|
||||
NewAPIKeyHandler,
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
ProvideHandlers,
|
||||
)
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
func ProvideAdminHandlers(
|
||||
dashboardHandler *admin.DashboardHandler,
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
||||
return NewSettingHandler(settingService, buildInfo.Version)
|
||||
}
|
||||
|
||||
// ProvideHandlers creates the Handlers struct
|
||||
func ProvideHandlers(
|
||||
authHandler *AuthHandler,
|
||||
userHandler *UserHandler,
|
||||
apiKeyHandler *APIKeyHandler,
|
||||
usageHandler *UsageHandler,
|
||||
redeemHandler *RedeemHandler,
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all handlers
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Top-level handlers
|
||||
NewAuthHandler,
|
||||
NewUserHandler,
|
||||
NewAPIKeyHandler,
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
ProvideHandlers,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,228 +1,228 @@
|
||||
package antigravity
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Claude 请求/响应类型定义
|
||||
|
||||
// ClaudeRequest Claude Messages API 请求
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ClaudeMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Tools []ClaudeTool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeMessage Claude 消息
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"` // user, assistant
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// ThinkingConfig Thinking 配置
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
||||
}
|
||||
|
||||
// ClaudeMetadata 请求元数据
|
||||
type ClaudeMetadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeTool Claude 工具定义
|
||||
// 支持两种格式:
|
||||
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
|
||||
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
|
||||
type ClaudeTool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"` // 标准格式使用
|
||||
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
|
||||
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
|
||||
}
|
||||
|
||||
// CustomToolSpec MCP custom 工具规格
|
||||
type CustomToolSpec struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
|
||||
type ClaudeCustomToolSpec = CustomToolSpec
|
||||
|
||||
// SystemBlock system prompt 数组形式的元素
|
||||
type SystemBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ContentBlock Claude 消息内容块(解析后)
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
// image
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource Claude 图片来源
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// ClaudeResponse Claude Messages API 响应
|
||||
type ClaudeResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ClaudeContentItem `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
|
||||
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// ClaudeContentItem Claude 响应内容项
|
||||
type ClaudeContentItem struct {
|
||||
Type string `json:"type"` // text, thinking, tool_use
|
||||
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsage Claude 用量统计
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeError Claude 错误响应
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// modelDef Antigravity 模型定义(内部使用)
|
||||
type modelDef struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
CreatedAt string // 仅 Claude API 格式使用
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Claude 模型
|
||||
var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
// ========== Claude API 格式 (/v1/models) ==========
|
||||
|
||||
// ClaudeModel Claude API 模型格式
|
||||
type ClaudeModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini)
|
||||
func DefaultModels() []ClaudeModel {
|
||||
all := append(claudeModels, geminiModels...)
|
||||
result := make([]ClaudeModel, len(all))
|
||||
for i, m := range all {
|
||||
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
|
||||
|
||||
// GeminiModel Gemini v1beta 模型格式
|
||||
type GeminiModel struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModelsListResponse Gemini v1beta 模型列表响应
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModel `json:"models"`
|
||||
}
|
||||
|
||||
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
|
||||
|
||||
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
|
||||
func DefaultGeminiModels() []GeminiModel {
|
||||
result := make([]GeminiModel, len(geminiModels))
|
||||
for i, m := range geminiModels {
|
||||
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
|
||||
func FallbackGeminiModelsList() GeminiModelsListResponse {
|
||||
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
|
||||
}
|
||||
|
||||
// FallbackGeminiModel 返回单个模型信息(v1beta 格式)
|
||||
func FallbackGeminiModel(model string) GeminiModel {
|
||||
if model == "" {
|
||||
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
name := model
|
||||
if len(model) < 7 || model[:7] != "models/" {
|
||||
name = "models/" + model
|
||||
}
|
||||
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Claude 请求/响应类型定义
|
||||
|
||||
// ClaudeRequest Claude Messages API 请求
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ClaudeMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Tools []ClaudeTool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeMessage Claude 消息
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"` // user, assistant
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// ThinkingConfig Thinking 配置
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
||||
}
|
||||
|
||||
// ClaudeMetadata 请求元数据
|
||||
type ClaudeMetadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeTool Claude 工具定义
|
||||
// 支持两种格式:
|
||||
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
|
||||
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
|
||||
type ClaudeTool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"` // 标准格式使用
|
||||
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
|
||||
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
|
||||
}
|
||||
|
||||
// CustomToolSpec MCP custom 工具规格
|
||||
type CustomToolSpec struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
|
||||
type ClaudeCustomToolSpec = CustomToolSpec
|
||||
|
||||
// SystemBlock system prompt 数组形式的元素
|
||||
type SystemBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ContentBlock Claude 消息内容块(解析后)
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
// image
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource Claude 图片来源
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// ClaudeResponse Claude Messages API 响应
|
||||
type ClaudeResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ClaudeContentItem `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
|
||||
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// ClaudeContentItem Claude 响应内容项
|
||||
type ClaudeContentItem struct {
|
||||
Type string `json:"type"` // text, thinking, tool_use
|
||||
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsage Claude 用量统计
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeError Claude 错误响应
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// modelDef Antigravity 模型定义(内部使用)
|
||||
type modelDef struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
CreatedAt string // 仅 Claude API 格式使用
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Claude 模型
|
||||
var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
// ========== Claude API 格式 (/v1/models) ==========
|
||||
|
||||
// ClaudeModel Claude API 模型格式
|
||||
type ClaudeModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini)
|
||||
func DefaultModels() []ClaudeModel {
|
||||
all := append(claudeModels, geminiModels...)
|
||||
result := make([]ClaudeModel, len(all))
|
||||
for i, m := range all {
|
||||
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
|
||||
|
||||
// GeminiModel Gemini v1beta 模型格式
|
||||
type GeminiModel struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModelsListResponse Gemini v1beta 模型列表响应
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModel `json:"models"`
|
||||
}
|
||||
|
||||
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
|
||||
|
||||
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
|
||||
func DefaultGeminiModels() []GeminiModel {
|
||||
result := make([]GeminiModel, len(geminiModels))
|
||||
for i, m := range geminiModels {
|
||||
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
|
||||
func FallbackGeminiModelsList() GeminiModelsListResponse {
|
||||
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
|
||||
}
|
||||
|
||||
// FallbackGeminiModel 返回单个模型信息(v1beta 格式)
|
||||
func FallbackGeminiModel(model string) GeminiModel {
|
||||
if model == "" {
|
||||
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
name := model
|
||||
if len(model) < 7 || model[:7] != "models/" {
|
||||
name = "models/" + model
|
||||
}
|
||||
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
|
||||
@@ -1,327 +1,327 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// TokenResponse Google OAuth token 响应
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo Google 用户信息
|
||||
type UserInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistRequest loadCodeAssist 请求
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
// TierInfo 账户类型信息
|
||||
type TierInfo struct {
|
||||
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
|
||||
Name string `json:"name"` // 显示名称
|
||||
Description string `json:"description"` // 描述
|
||||
}
|
||||
|
||||
// IneligibleTier 不符合条件的层级信息
|
||||
type IneligibleTier struct {
|
||||
Tier *TierInfo `json:"tier,omitempty"`
|
||||
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
|
||||
ReasonCode string `json:"reasonCode,omitempty"`
|
||||
ReasonMessage string `json:"reasonMessage,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistResponse loadCodeAssist 响应
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
if r.PaidTier != nil && r.PaidTier.ID != "" {
|
||||
return r.PaidTier.ID
|
||||
}
|
||||
if r.CurrentTier != nil {
|
||||
return r.CurrentTier.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(proxyURL string) *Client {
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
params.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("用户信息请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("用户信息解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := BaseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
ResetTime string `json:"resetTime,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
type FetchAvailableModelsRequest struct {
|
||||
Project string `json:"project"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// TokenResponse Google OAuth token 响应
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo Google 用户信息
|
||||
type UserInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistRequest loadCodeAssist 请求
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
// TierInfo 账户类型信息
|
||||
type TierInfo struct {
|
||||
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
|
||||
Name string `json:"name"` // 显示名称
|
||||
Description string `json:"description"` // 描述
|
||||
}
|
||||
|
||||
// IneligibleTier 不符合条件的层级信息
|
||||
type IneligibleTier struct {
|
||||
Tier *TierInfo `json:"tier,omitempty"`
|
||||
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
|
||||
ReasonCode string `json:"reasonCode,omitempty"`
|
||||
ReasonMessage string `json:"reasonMessage,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistResponse loadCodeAssist 响应
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
if r.PaidTier != nil && r.PaidTier.ID != "" {
|
||||
return r.PaidTier.ID
|
||||
}
|
||||
if r.CurrentTier != nil {
|
||||
return r.CurrentTier.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(proxyURL string) *Client {
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
params.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("用户信息请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("用户信息解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := BaseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
ResetTime string `json:"resetTime,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
type FetchAvailableModelsRequest struct {
|
||||
Project string `json:"project"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
@@ -1,168 +1,168 @@
|
||||
package antigravity
|
||||
|
||||
// Gemini v1internal 请求/响应类型定义
|
||||
|
||||
// V1InternalRequest v1internal 请求包装
|
||||
type V1InternalRequest struct {
|
||||
Project string `json:"project"`
|
||||
RequestID string `json:"requestId"`
|
||||
UserAgent string `json:"userAgent"`
|
||||
RequestType string `json:"requestType,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Request GeminiRequest `json:"request"`
|
||||
}
|
||||
|
||||
// GeminiRequest Gemini 请求内容
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
|
||||
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
|
||||
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiContent Gemini 内容
|
||||
type GeminiContent struct {
|
||||
Role string `json:"role"` // user, model
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
// GeminiPart Gemini 内容部分
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiInlineData Gemini 内联数据(图片等)
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCall Gemini 函数调用
|
||||
type GeminiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args any `json:"args,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionResponse Gemini 函数响应
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]any `json:"response"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGenerationConfig Gemini 生成配置
|
||||
type GeminiGenerationConfig struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiThinkingConfig Gemini thinking 配置
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts"`
|
||||
ThinkingBudget int `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolDeclaration Gemini 工具声明
|
||||
type GeminiToolDeclaration struct {
|
||||
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
|
||||
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionDecl Gemini 函数声明
|
||||
type GeminiFunctionDecl struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGoogleSearch Gemini Google 搜索工具
|
||||
type GeminiGoogleSearch struct {
|
||||
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiEnhancedContent 增强内容配置
|
||||
type GeminiEnhancedContent struct {
|
||||
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageSearch 图片搜索配置
|
||||
type GeminiImageSearch struct {
|
||||
MaxResultCount int `json:"maxResultCount,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolConfig Gemini 工具配置
|
||||
type GeminiToolConfig struct {
|
||||
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCallingConfig 函数调用配置
|
||||
type GeminiFunctionCallingConfig struct {
|
||||
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
|
||||
}
|
||||
|
||||
// GeminiSafetySetting Gemini 安全设置
|
||||
type GeminiSafetySetting struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
// V1InternalResponse v1internal 响应包装
|
||||
type V1InternalResponse struct {
|
||||
Response GeminiResponse `json:"response"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiResponse Gemini 响应
|
||||
type GeminiResponse struct {
|
||||
Candidates []GeminiCandidate `json:"candidates,omitempty"`
|
||||
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiCandidate Gemini 候选响应
|
||||
type GeminiCandidate struct {
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
|
||||
}
|
||||
|
||||
// DefaultStopSequences 默认停止序列
|
||||
var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
package antigravity
|
||||
|
||||
// Gemini v1internal 请求/响应类型定义
|
||||
|
||||
// V1InternalRequest v1internal 请求包装
|
||||
type V1InternalRequest struct {
|
||||
Project string `json:"project"`
|
||||
RequestID string `json:"requestId"`
|
||||
UserAgent string `json:"userAgent"`
|
||||
RequestType string `json:"requestType,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Request GeminiRequest `json:"request"`
|
||||
}
|
||||
|
||||
// GeminiRequest Gemini 请求内容
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
|
||||
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
|
||||
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiContent Gemini 内容
|
||||
type GeminiContent struct {
|
||||
Role string `json:"role"` // user, model
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
// GeminiPart Gemini 内容部分
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiInlineData Gemini 内联数据(图片等)
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCall Gemini 函数调用
|
||||
type GeminiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args any `json:"args,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionResponse Gemini 函数响应
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]any `json:"response"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGenerationConfig Gemini 生成配置
|
||||
type GeminiGenerationConfig struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiThinkingConfig Gemini thinking 配置
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts"`
|
||||
ThinkingBudget int `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolDeclaration Gemini 工具声明
|
||||
type GeminiToolDeclaration struct {
|
||||
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
|
||||
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionDecl Gemini 函数声明
|
||||
type GeminiFunctionDecl struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGoogleSearch Gemini Google 搜索工具
|
||||
type GeminiGoogleSearch struct {
|
||||
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiEnhancedContent 增强内容配置
|
||||
type GeminiEnhancedContent struct {
|
||||
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageSearch 图片搜索配置
|
||||
type GeminiImageSearch struct {
|
||||
MaxResultCount int `json:"maxResultCount,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolConfig Gemini 工具配置
|
||||
type GeminiToolConfig struct {
|
||||
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCallingConfig 函数调用配置
|
||||
type GeminiFunctionCallingConfig struct {
|
||||
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
|
||||
}
|
||||
|
||||
// GeminiSafetySetting Gemini 安全设置
|
||||
type GeminiSafetySetting struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
// V1InternalResponse v1internal 响应包装
|
||||
type V1InternalResponse struct {
|
||||
Response GeminiResponse `json:"response"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiResponse Gemini 响应
|
||||
type GeminiResponse struct {
|
||||
Candidates []GeminiCandidate `json:"candidates,omitempty"`
|
||||
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiCandidate Gemini 候选响应
|
||||
type GeminiCandidate struct {
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
|
||||
}
|
||||
|
||||
// DefaultStopSequences 默认停止序列
|
||||
var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
|
||||
@@ -1,200 +1,200 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Google OAuth 端点
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
|
||||
// OAuth scopes
|
||||
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
|
||||
"https://www.googleapis.com/auth/userinfo.email " +
|
||||
"https://www.googleapis.com/auth/userinfo.profile " +
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// API 端点
|
||||
BaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
// User-Agent
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession 保存 OAuth 授权流程的临时状态
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore OAuth session 存储
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
|
||||
func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
||||
func GenerateMockProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
randBytes, _ := GenerateRandomBytes(7)
|
||||
|
||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
||||
|
||||
// 生成 5 位随机字符(a-z0-9)
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
suffix := make([]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Google OAuth 端点
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
|
||||
// OAuth scopes
|
||||
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
|
||||
"https://www.googleapis.com/auth/userinfo.email " +
|
||||
"https://www.googleapis.com/auth/userinfo.profile " +
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// API 端点
|
||||
BaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
// User-Agent
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession 保存 OAuth 授权流程的临时状态
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore OAuth session 存储
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
|
||||
func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
||||
func GenerateMockProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
randBytes, _ := GenerateRandomBytes(7)
|
||||
|
||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
||||
|
||||
// 生成 5 位随机字符(a-z0-9)
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
suffix := make([]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,179 +1,179 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
allowDummyThought bool
|
||||
expectedParts int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Claude model - skip thinking block without signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 2, // 只有两个text block
|
||||
description: "Claude模型应该跳过无signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Claude model - keep thinking block with signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3, // 三个block都保留
|
||||
description: "Claude模型应该保留有signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Gemini model - use dummy signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: true,
|
||||
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
|
||||
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
|
||||
if len(parts) != tt.expectedParts {
|
||||
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []ClaudeTool
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tool format",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "mcp_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "MCP tool description",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"param": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从Custom字段读取description和input_schema",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "standard_tool",
|
||||
Description: "Standard tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "custom_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Custom tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
|
||||
description: "混合标准和custom工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil Custom field",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
// Custom 为 nil
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "Custom字段为nil的custom工具应该被跳过",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil InputSchema",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Invalid",
|
||||
// InputSchema 为 nil
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "InputSchema为nil的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildTools(tt.tools)
|
||||
|
||||
if len(result) != tt.expectedLen {
|
||||
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
|
||||
}
|
||||
|
||||
// 验证function declarations存在
|
||||
if len(result) > 0 && result[0].FunctionDeclarations != nil {
|
||||
if len(result[0].FunctionDeclarations) != len(tt.tools) {
|
||||
t.Errorf("%s: got %d function declarations, want %d",
|
||||
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
allowDummyThought bool
|
||||
expectedParts int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Claude model - skip thinking block without signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 2, // 只有两个text block
|
||||
description: "Claude模型应该跳过无signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Claude model - keep thinking block with signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3, // 三个block都保留
|
||||
description: "Claude模型应该保留有signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Gemini model - use dummy signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: true,
|
||||
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
|
||||
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
|
||||
if len(parts) != tt.expectedParts {
|
||||
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []ClaudeTool
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tool format",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "mcp_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "MCP tool description",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"param": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从Custom字段读取description和input_schema",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "standard_tool",
|
||||
Description: "Standard tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "custom_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Custom tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
|
||||
description: "混合标准和custom工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil Custom field",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
// Custom 为 nil
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "Custom字段为nil的custom工具应该被跳过",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil InputSchema",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Invalid",
|
||||
// InputSchema 为 nil
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "InputSchema为nil的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildTools(tt.tools)
|
||||
|
||||
if len(result) != tt.expectedLen {
|
||||
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
|
||||
}
|
||||
|
||||
// 验证function declarations存在
|
||||
if len(result) > 0 && result[0].FunctionDeclarations != nil {
|
||||
if len(result[0].FunctionDeclarations) != len(tt.tools) {
|
||||
t.Errorf("%s: got %d function declarations, want %d",
|
||||
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,273 +1,273 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
processor := NewNonStreamingProcessor()
|
||||
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
|
||||
|
||||
// 序列化
|
||||
respBytes, err := json.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
|
||||
}
|
||||
|
||||
return respBytes, &claudeResp.Usage, nil
|
||||
}
|
||||
|
||||
// NonStreamingProcessor 非流式响应处理器
|
||||
type NonStreamingProcessor struct {
|
||||
contentBlocks []ClaudeContentItem
|
||||
textBuilder string
|
||||
thinkingBuilder string
|
||||
thinkingSignature string
|
||||
trailingSignature string
|
||||
hasToolCall bool
|
||||
}
|
||||
|
||||
// NewNonStreamingProcessor 创建非流式响应处理器
|
||||
func NewNonStreamingProcessor() *NonStreamingProcessor {
|
||||
return &NonStreamingProcessor{
|
||||
contentBlocks: make([]ClaudeContentItem, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理 Gemini 响应
|
||||
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
// 获取 parts
|
||||
var parts []GeminiPart
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
parts = geminiResp.Candidates[0].Content.Parts
|
||||
}
|
||||
|
||||
// 处理所有 parts
|
||||
for _, part := range parts {
|
||||
p.processPart(&part)
|
||||
}
|
||||
|
||||
// 刷新剩余内容
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
return p.buildResponse(geminiResp, responseID, originalModel)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.hasToolCall = true
|
||||
|
||||
// 生成 tool_use id
|
||||
toolID := part.FunctionCall.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
|
||||
}
|
||||
|
||||
item := ClaudeContentItem{
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: part.FunctionCall.Args,
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
item.Signature = signature
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, item)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
// Thinking part
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushThinking()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.thinkingBuilder += part.Text
|
||||
if signature != "" {
|
||||
p.thinkingSignature = signature
|
||||
}
|
||||
} else {
|
||||
// 普通 Text
|
||||
if part.Text == "" {
|
||||
// 空 text 带签名 - 暂存
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p.flushThinking()
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
p.flushThinking()
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
p.textBuilder += markdownImg
|
||||
p.flushText()
|
||||
}
|
||||
}
|
||||
|
||||
// flushText 刷新 text builder
|
||||
func (p *NonStreamingProcessor) flushText() {
|
||||
if p.textBuilder == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: p.textBuilder,
|
||||
})
|
||||
p.textBuilder = ""
|
||||
}
|
||||
|
||||
// flushThinking 刷新 thinking builder
|
||||
func (p *NonStreamingProcessor) flushThinking() {
|
||||
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: p.thinkingBuilder,
|
||||
Signature: p.thinkingSignature,
|
||||
})
|
||||
p.thinkingBuilder = ""
|
||||
p.thinkingSignature = ""
|
||||
}
|
||||
|
||||
// buildResponse 构建最终响应
|
||||
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
var finishReason string
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason = geminiResp.Candidates[0].FinishReason
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
if p.hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
usage := ClaudeUsage{}
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
respID := responseID
|
||||
if respID == "" {
|
||||
respID = geminiResp.ResponseID
|
||||
}
|
||||
if respID == "" {
|
||||
respID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
return &ClaudeResponse{
|
||||
ID: respID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: originalModel,
|
||||
Content: p.contentBlocks,
|
||||
StopReason: stopReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
processor := NewNonStreamingProcessor()
|
||||
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
|
||||
|
||||
// 序列化
|
||||
respBytes, err := json.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
|
||||
}
|
||||
|
||||
return respBytes, &claudeResp.Usage, nil
|
||||
}
|
||||
|
||||
// NonStreamingProcessor 非流式响应处理器
|
||||
type NonStreamingProcessor struct {
|
||||
contentBlocks []ClaudeContentItem
|
||||
textBuilder string
|
||||
thinkingBuilder string
|
||||
thinkingSignature string
|
||||
trailingSignature string
|
||||
hasToolCall bool
|
||||
}
|
||||
|
||||
// NewNonStreamingProcessor 创建非流式响应处理器
|
||||
func NewNonStreamingProcessor() *NonStreamingProcessor {
|
||||
return &NonStreamingProcessor{
|
||||
contentBlocks: make([]ClaudeContentItem, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理 Gemini 响应
|
||||
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
// 获取 parts
|
||||
var parts []GeminiPart
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
parts = geminiResp.Candidates[0].Content.Parts
|
||||
}
|
||||
|
||||
// 处理所有 parts
|
||||
for _, part := range parts {
|
||||
p.processPart(&part)
|
||||
}
|
||||
|
||||
// 刷新剩余内容
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
return p.buildResponse(geminiResp, responseID, originalModel)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.hasToolCall = true
|
||||
|
||||
// 生成 tool_use id
|
||||
toolID := part.FunctionCall.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
|
||||
}
|
||||
|
||||
item := ClaudeContentItem{
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: part.FunctionCall.Args,
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
item.Signature = signature
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, item)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
// Thinking part
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushThinking()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.thinkingBuilder += part.Text
|
||||
if signature != "" {
|
||||
p.thinkingSignature = signature
|
||||
}
|
||||
} else {
|
||||
// 普通 Text
|
||||
if part.Text == "" {
|
||||
// 空 text 带签名 - 暂存
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p.flushThinking()
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
p.flushThinking()
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
p.textBuilder += markdownImg
|
||||
p.flushText()
|
||||
}
|
||||
}
|
||||
|
||||
// flushText 刷新 text builder
|
||||
func (p *NonStreamingProcessor) flushText() {
|
||||
if p.textBuilder == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: p.textBuilder,
|
||||
})
|
||||
p.textBuilder = ""
|
||||
}
|
||||
|
||||
// flushThinking 刷新 thinking builder
|
||||
func (p *NonStreamingProcessor) flushThinking() {
|
||||
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: p.thinkingBuilder,
|
||||
Signature: p.thinkingSignature,
|
||||
})
|
||||
p.thinkingBuilder = ""
|
||||
p.thinkingSignature = ""
|
||||
}
|
||||
|
||||
// buildResponse 构建最终响应
|
||||
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
var finishReason string
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason = geminiResp.Candidates[0].FinishReason
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
if p.hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
usage := ClaudeUsage{}
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
respID := responseID
|
||||
if respID == "" {
|
||||
respID = geminiResp.ResponseID
|
||||
}
|
||||
if respID == "" {
|
||||
respID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
return &ClaudeResponse{
|
||||
ID: respID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: originalModel,
|
||||
Content: p.contentBlocks,
|
||||
StopReason: stopReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
@@ -1,464 +1,464 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BlockType 内容块类型
|
||||
type BlockType int
|
||||
|
||||
const (
|
||||
BlockTypeNone BlockType = iota
|
||||
BlockTypeText
|
||||
BlockTypeThinking
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
blockIndex int
|
||||
messageStartSent bool
|
||||
messageStopSent bool
|
||||
usedTool bool
|
||||
pendingSignature string
|
||||
trailingSignature string
|
||||
originalModel string
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
return &StreamingProcessor{
|
||||
blockType: BlockTypeNone,
|
||||
originalModel: originalModel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "" || data == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal([]byte(data), &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal([]byte(data), &directResp); err2 != nil {
|
||||
return nil
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
geminiResp := &v1Resp.Response
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// 发送 message_start
|
||||
if !p.messageStartSent {
|
||||
_, _ = result.Write(p.emitMessageStart(&v1Resp))
|
||||
}
|
||||
|
||||
// 更新 usage
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
_, _ = result.Write(p.processPart(&part))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason != "" {
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{}
|
||||
if v1Resp.Response.UsageMetadata != nil {
|
||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
if responseID == "" {
|
||||
responseID = v1Resp.Response.ResponseID
|
||||
}
|
||||
if responseID == "" {
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
p.messageStartSent = true
|
||||
return p.formatSSE("message_start", event)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
var result bytes.Buffer
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
// 先处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
_, _ = result.Write(p.processThinking(part.Text, signature))
|
||||
} else {
|
||||
_, _ = result.Write(p.processText(part.Text, signature))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
_, _ = result.Write(p.processText(markdownImg, ""))
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processThinking 处理 thinking
|
||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 开始或继续 thinking 块
|
||||
if p.blockType != BlockTypeThinking {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": text,
|
||||
}))
|
||||
}
|
||||
|
||||
// 暂存签名
|
||||
if signature != "" {
|
||||
p.pendingSignature = signature
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processText 处理普通 text
|
||||
func (p *StreamingProcessor) processText(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 空 text 带签名 - 暂存
|
||||
if text == "" {
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 非空 text 带签名 - 特殊处理
|
||||
if signature != "" {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 普通 text (无签名)
|
||||
if p.blockType != BlockTypeText {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processFunctionCall 处理 function call
|
||||
func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
p.usedTool = true
|
||||
|
||||
toolID := fc.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": toolID,
|
||||
"name": fc.Name,
|
||||
"input": map[string]any{},
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
toolUse["signature"] = signature
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
|
||||
|
||||
// 发送 input_json_delta
|
||||
if fc.Args != nil {
|
||||
argsJSON, _ := json.Marshal(fc.Args)
|
||||
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]any{
|
||||
"partial_json": string(argsJSON),
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// startBlock 开始新的内容块
|
||||
func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[string]any) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
if p.blockType != BlockTypeNone {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": p.blockIndex,
|
||||
"content_block": contentBlock,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_start", event))
|
||||
p.blockType = blockType
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// endBlock 结束当前内容块
|
||||
func (p *StreamingProcessor) endBlock() []byte {
|
||||
if p.blockType == BlockTypeNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// Thinking 块结束时发送暂存的签名
|
||||
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": p.pendingSignature,
|
||||
}))
|
||||
p.pendingSignature = ""
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": p.blockIndex,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_stop", event))
|
||||
|
||||
p.blockIndex++
|
||||
p.blockType = BlockTypeNone
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitDelta 发送 delta 事件
|
||||
func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string]any) []byte {
|
||||
delta := map[string]any{
|
||||
"type": deltaType,
|
||||
}
|
||||
for k, v := range deltaContent {
|
||||
delta[k] = v
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": p.blockIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
return p.formatSSE("content_block_delta", event)
|
||||
}
|
||||
|
||||
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
|
||||
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": signature,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitFinish 发送结束事件
|
||||
func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 关闭最后一个块
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
if p.usedTool {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
if !p.messageStopSent {
|
||||
stopEvent := map[string]any{
|
||||
"type": "message_stop",
|
||||
}
|
||||
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
|
||||
p.messageStopSent = true
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// formatSSE 格式化 SSE 事件
|
||||
func (p *StreamingProcessor) formatSSE(eventType string, data any) []byte {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(jsonData)))
|
||||
}
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BlockType 内容块类型
|
||||
type BlockType int
|
||||
|
||||
const (
|
||||
BlockTypeNone BlockType = iota
|
||||
BlockTypeText
|
||||
BlockTypeThinking
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
blockIndex int
|
||||
messageStartSent bool
|
||||
messageStopSent bool
|
||||
usedTool bool
|
||||
pendingSignature string
|
||||
trailingSignature string
|
||||
originalModel string
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
return &StreamingProcessor{
|
||||
blockType: BlockTypeNone,
|
||||
originalModel: originalModel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "" || data == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal([]byte(data), &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal([]byte(data), &directResp); err2 != nil {
|
||||
return nil
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
geminiResp := &v1Resp.Response
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// 发送 message_start
|
||||
if !p.messageStartSent {
|
||||
_, _ = result.Write(p.emitMessageStart(&v1Resp))
|
||||
}
|
||||
|
||||
// 更新 usage
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
_, _ = result.Write(p.processPart(&part))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason != "" {
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{}
|
||||
if v1Resp.Response.UsageMetadata != nil {
|
||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
if responseID == "" {
|
||||
responseID = v1Resp.Response.ResponseID
|
||||
}
|
||||
if responseID == "" {
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
p.messageStartSent = true
|
||||
return p.formatSSE("message_start", event)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
var result bytes.Buffer
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
// 先处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
_, _ = result.Write(p.processThinking(part.Text, signature))
|
||||
} else {
|
||||
_, _ = result.Write(p.processText(part.Text, signature))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
_, _ = result.Write(p.processText(markdownImg, ""))
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processThinking 处理 thinking
|
||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 开始或继续 thinking 块
|
||||
if p.blockType != BlockTypeThinking {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": text,
|
||||
}))
|
||||
}
|
||||
|
||||
// 暂存签名
|
||||
if signature != "" {
|
||||
p.pendingSignature = signature
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processText 处理普通 text
|
||||
func (p *StreamingProcessor) processText(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 空 text 带签名 - 暂存
|
||||
if text == "" {
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 非空 text 带签名 - 特殊处理
|
||||
if signature != "" {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 普通 text (无签名)
|
||||
if p.blockType != BlockTypeText {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processFunctionCall 处理 function call
|
||||
func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
p.usedTool = true
|
||||
|
||||
toolID := fc.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": toolID,
|
||||
"name": fc.Name,
|
||||
"input": map[string]any{},
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
toolUse["signature"] = signature
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
|
||||
|
||||
// 发送 input_json_delta
|
||||
if fc.Args != nil {
|
||||
argsJSON, _ := json.Marshal(fc.Args)
|
||||
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]any{
|
||||
"partial_json": string(argsJSON),
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// startBlock 开始新的内容块
|
||||
func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[string]any) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
if p.blockType != BlockTypeNone {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": p.blockIndex,
|
||||
"content_block": contentBlock,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_start", event))
|
||||
p.blockType = blockType
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// endBlock 结束当前内容块
|
||||
func (p *StreamingProcessor) endBlock() []byte {
|
||||
if p.blockType == BlockTypeNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// Thinking 块结束时发送暂存的签名
|
||||
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": p.pendingSignature,
|
||||
}))
|
||||
p.pendingSignature = ""
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": p.blockIndex,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_stop", event))
|
||||
|
||||
p.blockIndex++
|
||||
p.blockType = BlockTypeNone
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitDelta 发送 delta 事件
|
||||
func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string]any) []byte {
|
||||
delta := map[string]any{
|
||||
"type": deltaType,
|
||||
}
|
||||
for k, v := range deltaContent {
|
||||
delta[k] = v
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": p.blockIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
return p.formatSSE("content_block_delta", event)
|
||||
}
|
||||
|
||||
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
|
||||
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": signature,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitFinish 发送结束事件
|
||||
func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 关闭最后一个块
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
if p.usedTool {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
if !p.messageStopSent {
|
||||
stopEvent := map[string]any{
|
||||
"type": "message_stop",
|
||||
}
|
||||
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
|
||||
p.messageStopSent = true
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// formatSSE 格式化 SSE 事件
|
||||
func (p *StreamingProcessor) formatSSE(eventType string, data any) []byte {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(jsonData)))
|
||||
}
|
||||
|
||||
@@ -1,80 +1,80 @@
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// Beta header 常量
|
||||
const (
|
||||
BetaOAuth = "oauth-2025-04-20"
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.52.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Arch": "x64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "60",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
|
||||
// Model 表示一个 Claude 模型
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels Claude Code 客户端支持的默认模型列表
|
||||
var DefaultModels = []Model{
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Opus 4.5",
|
||||
CreatedAt: "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Sonnet 4.5",
|
||||
CreatedAt: "2025-09-29T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4-5-20251001",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Haiku 4.5",
|
||||
CreatedAt: "2025-10-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
// DefaultModelIDs 返回默认模型的 ID 列表
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel 测试时使用的默认模型
|
||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// Beta header 常量
|
||||
const (
|
||||
BetaOAuth = "oauth-2025-04-20"
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.52.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Arch": "x64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "60",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
|
||||
// Model 表示一个 Claude 模型
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels Claude Code 客户端支持的默认模型列表
|
||||
var DefaultModels = []Model{
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Opus 4.5",
|
||||
CreatedAt: "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Sonnet 4.5",
|
||||
CreatedAt: "2025-09-29T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4-5-20251001",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Haiku 4.5",
|
||||
CreatedAt: "2025-10-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
// DefaultModelIDs 返回默认模型的 ID 列表
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel 测试时使用的默认模型
|
||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Package ctxkey 定义用于 context.Value 的类型安全 key
|
||||
package ctxkey
|
||||
|
||||
// Key 定义 context key 的类型,避免使用内置 string 类型(staticcheck SA1029)
|
||||
type Key string
|
||||
|
||||
const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
)
|
||||
// Package ctxkey 定义用于 context.Value 的类型安全 key
|
||||
package ctxkey
|
||||
|
||||
// Key 定义 context key 的类型,避免使用内置 string 类型(staticcheck SA1029)
|
||||
type Key string
|
||||
|
||||
const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
)
|
||||
|
||||
@@ -1,158 +1,158 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownCode = http.StatusInternalServerError
|
||||
UnknownReason = ""
|
||||
UnknownMessage = "internal error"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Code int32 `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ApplicationError is the standard error type used to control HTTP responses.
|
||||
//
|
||||
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
|
||||
type ApplicationError struct {
|
||||
Status
|
||||
cause error
|
||||
}
|
||||
|
||||
// Error is kept for backwards compatibility within this package.
|
||||
type Error = ApplicationError
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if e == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
|
||||
}
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
|
||||
}
|
||||
|
||||
// Unwrap provides compatibility for Go 1.13 error chains.
|
||||
func (e *ApplicationError) Unwrap() error { return e.cause }
|
||||
|
||||
// Is matches each error in the chain with the target value.
|
||||
func (e *ApplicationError) Is(err error) bool {
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se.Code == e.Code && se.Reason == e.Reason
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithCause attaches the underlying cause of the error.
|
||||
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
|
||||
err := Clone(e)
|
||||
err.cause = cause
|
||||
return err
|
||||
}
|
||||
|
||||
// WithMetadata deep-copies the given metadata map.
|
||||
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
|
||||
err := Clone(e)
|
||||
if md == nil {
|
||||
err.Metadata = nil
|
||||
return err
|
||||
}
|
||||
err.Metadata = make(map[string]string, len(md))
|
||||
for k, v := range md {
|
||||
err.Metadata[k] = v
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// New returns an error object for the code, message.
|
||||
func New(code int, reason, message string) *ApplicationError {
|
||||
return &ApplicationError{
|
||||
Status: Status{
|
||||
Code: int32(code),
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Newf New(code fmt.Sprintf(format, a...))
|
||||
func Newf(code int, reason, format string, a ...any) *ApplicationError {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Errorf returns an error object for the code, message and error info.
|
||||
func Errorf(code int, reason, format string, a ...any) error {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Code returns the http code for an error.
|
||||
// It supports wrapped errors.
|
||||
func Code(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
return int(FromError(err).Code)
|
||||
}
|
||||
|
||||
// Reason returns the reason for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Reason(err error) string {
|
||||
if err == nil {
|
||||
return UnknownReason
|
||||
}
|
||||
return FromError(err).Reason
|
||||
}
|
||||
|
||||
// Message returns the message for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return FromError(err).Message
|
||||
}
|
||||
|
||||
// Clone deep clone error to a new error.
|
||||
func Clone(err *ApplicationError) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var metadata map[string]string
|
||||
if err.Metadata != nil {
|
||||
metadata = make(map[string]string, len(err.Metadata))
|
||||
for k, v := range err.Metadata {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
return &ApplicationError{
|
||||
cause: err.cause,
|
||||
Status: Status{
|
||||
Code: err.Code,
|
||||
Reason: err.Reason,
|
||||
Message: err.Message,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FromError tries to convert an error to *ApplicationError.
|
||||
// It supports wrapped errors.
|
||||
func FromError(err error) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se
|
||||
}
|
||||
|
||||
// Fall back to a generic internal error.
|
||||
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
|
||||
}
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownCode = http.StatusInternalServerError
|
||||
UnknownReason = ""
|
||||
UnknownMessage = "internal error"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Code int32 `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ApplicationError is the standard error type used to control HTTP responses.
|
||||
//
|
||||
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
|
||||
type ApplicationError struct {
|
||||
Status
|
||||
cause error
|
||||
}
|
||||
|
||||
// Error is kept for backwards compatibility within this package.
|
||||
type Error = ApplicationError
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if e == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
|
||||
}
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
|
||||
}
|
||||
|
||||
// Unwrap provides compatibility for Go 1.13 error chains.
|
||||
func (e *ApplicationError) Unwrap() error { return e.cause }
|
||||
|
||||
// Is matches each error in the chain with the target value.
|
||||
func (e *ApplicationError) Is(err error) bool {
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se.Code == e.Code && se.Reason == e.Reason
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithCause attaches the underlying cause of the error.
|
||||
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
|
||||
err := Clone(e)
|
||||
err.cause = cause
|
||||
return err
|
||||
}
|
||||
|
||||
// WithMetadata deep-copies the given metadata map.
|
||||
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
|
||||
err := Clone(e)
|
||||
if md == nil {
|
||||
err.Metadata = nil
|
||||
return err
|
||||
}
|
||||
err.Metadata = make(map[string]string, len(md))
|
||||
for k, v := range md {
|
||||
err.Metadata[k] = v
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// New returns an error object for the code, message.
|
||||
func New(code int, reason, message string) *ApplicationError {
|
||||
return &ApplicationError{
|
||||
Status: Status{
|
||||
Code: int32(code),
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Newf New(code fmt.Sprintf(format, a...))
|
||||
func Newf(code int, reason, format string, a ...any) *ApplicationError {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Errorf returns an error object for the code, message and error info.
|
||||
func Errorf(code int, reason, format string, a ...any) error {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Code returns the http code for an error.
|
||||
// It supports wrapped errors.
|
||||
func Code(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
return int(FromError(err).Code)
|
||||
}
|
||||
|
||||
// Reason returns the reason for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Reason(err error) string {
|
||||
if err == nil {
|
||||
return UnknownReason
|
||||
}
|
||||
return FromError(err).Reason
|
||||
}
|
||||
|
||||
// Message returns the message for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return FromError(err).Message
|
||||
}
|
||||
|
||||
// Clone deep clone error to a new error.
|
||||
func Clone(err *ApplicationError) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var metadata map[string]string
|
||||
if err.Metadata != nil {
|
||||
metadata = make(map[string]string, len(err.Metadata))
|
||||
for k, v := range err.Metadata {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
return &ApplicationError{
|
||||
cause: err.cause,
|
||||
Status: Status{
|
||||
Code: err.Code,
|
||||
Reason: err.Reason,
|
||||
Message: err.Message,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FromError tries to convert an error to *ApplicationError.
|
||||
// It supports wrapped errors.
|
||||
func FromError(err error) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se
|
||||
}
|
||||
|
||||
// Fall back to a generic internal error.
|
||||
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
|
||||
}
|
||||
|
||||
@@ -1,168 +1,168 @@
|
||||
//go:build unit
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplicationError_Basics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *ApplicationError
|
||||
want Status
|
||||
wantIs bool
|
||||
target error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
name: "new",
|
||||
err: New(400, "BAD_REQUEST", "invalid input"),
|
||||
want: Status{
|
||||
Code: 400,
|
||||
Reason: "BAD_REQUEST",
|
||||
Message: "invalid input",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is_matches_code_and_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "UNAUTHORIZED", "ignored message"),
|
||||
wantIs: true,
|
||||
},
|
||||
{
|
||||
name: "is_does_not_match_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "DIFFERENT", "ignored message"),
|
||||
wantIs: false,
|
||||
},
|
||||
{
|
||||
name: "from_error_unwraps_wrapped_application_error",
|
||||
err: New(404, "NOT_FOUND", "missing"),
|
||||
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
|
||||
want: Status{
|
||||
Code: 404,
|
||||
Reason: "NOT_FOUND",
|
||||
Message: "missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err != nil {
|
||||
require.Equal(t, tt.want, tt.err.Status)
|
||||
}
|
||||
|
||||
if tt.target != nil {
|
||||
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
|
||||
}
|
||||
|
||||
if tt.wrapped != nil {
|
||||
got := FromError(tt.wrapped)
|
||||
require.Equal(t, tt.want, got.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
}{
|
||||
{name: "non_nil", md: map[string]string{"a": "1"}},
|
||||
{name: "nil", md: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
|
||||
|
||||
if tt.md == nil {
|
||||
require.Nil(t, appErr.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
tt.md["a"] = "changed"
|
||||
require.Equal(t, "1", appErr.Metadata["a"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromError_Generic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int32
|
||||
wantReason string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
err: stderrors.New("boom"),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
{
|
||||
name: "wrapped_plain_error",
|
||||
err: fmt.Errorf("wrap: %w", io.EOF),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FromError(tt.err)
|
||||
require.Equal(t, tt.wantCode, got.Code)
|
||||
require.Equal(t, tt.wantReason, got.Reason)
|
||||
require.Equal(t, tt.wantMsg, got.Message)
|
||||
require.Equal(t, tt.err, got.Unwrap())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantBody Status
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: Status{Code: int32(http.StatusOK)},
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: Forbidden("FORBIDDEN", "no access"),
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: Status{
|
||||
Code: int32(http.StatusForbidden),
|
||||
Reason: "FORBIDDEN",
|
||||
Message: "no access",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, body := ToHTTP(tt.err)
|
||||
require.Equal(t, tt.wantStatusCode, code)
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplicationError_Basics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *ApplicationError
|
||||
want Status
|
||||
wantIs bool
|
||||
target error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
name: "new",
|
||||
err: New(400, "BAD_REQUEST", "invalid input"),
|
||||
want: Status{
|
||||
Code: 400,
|
||||
Reason: "BAD_REQUEST",
|
||||
Message: "invalid input",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is_matches_code_and_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "UNAUTHORIZED", "ignored message"),
|
||||
wantIs: true,
|
||||
},
|
||||
{
|
||||
name: "is_does_not_match_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "DIFFERENT", "ignored message"),
|
||||
wantIs: false,
|
||||
},
|
||||
{
|
||||
name: "from_error_unwraps_wrapped_application_error",
|
||||
err: New(404, "NOT_FOUND", "missing"),
|
||||
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
|
||||
want: Status{
|
||||
Code: 404,
|
||||
Reason: "NOT_FOUND",
|
||||
Message: "missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err != nil {
|
||||
require.Equal(t, tt.want, tt.err.Status)
|
||||
}
|
||||
|
||||
if tt.target != nil {
|
||||
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
|
||||
}
|
||||
|
||||
if tt.wrapped != nil {
|
||||
got := FromError(tt.wrapped)
|
||||
require.Equal(t, tt.want, got.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
}{
|
||||
{name: "non_nil", md: map[string]string{"a": "1"}},
|
||||
{name: "nil", md: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
|
||||
|
||||
if tt.md == nil {
|
||||
require.Nil(t, appErr.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
tt.md["a"] = "changed"
|
||||
require.Equal(t, "1", appErr.Metadata["a"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromError_Generic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int32
|
||||
wantReason string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
err: stderrors.New("boom"),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
{
|
||||
name: "wrapped_plain_error",
|
||||
err: fmt.Errorf("wrap: %w", io.EOF),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FromError(tt.err)
|
||||
require.Equal(t, tt.wantCode, got.Code)
|
||||
require.Equal(t, tt.wantReason, got.Reason)
|
||||
require.Equal(t, tt.wantMsg, got.Message)
|
||||
require.Equal(t, tt.err, got.Unwrap())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantBody Status
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: Status{Code: int32(http.StatusOK)},
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: Forbidden("FORBIDDEN", "no access"),
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: Status{
|
||||
Code: int32(http.StatusForbidden),
|
||||
Reason: "FORBIDDEN",
|
||||
Message: "no access",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, body := ToHTTP(tt.err)
|
||||
require.Equal(t, tt.wantStatusCode, code)
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
|
||||
//
|
||||
// The returned body matches the project's Status shape:
|
||||
// { code, reason, message, metadata }.
|
||||
func ToHTTP(err error) (statusCode int, body Status) {
|
||||
if err == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
appErr := FromError(err)
|
||||
if appErr == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
}
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
|
||||
//
|
||||
// The returned body matches the project's Status shape:
|
||||
// { code, reason, message, metadata }.
|
||||
func ToHTTP(err error) (statusCode int, body Status) {
|
||||
if err == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
appErr := FromError(err)
|
||||
if appErr == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
}
|
||||
|
||||
@@ -1,114 +1,114 @@
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// BadRequest new BadRequest error that is mapped to a 400 response.
|
||||
func BadRequest(reason, message string) *ApplicationError {
|
||||
return New(http.StatusBadRequest, reason, message)
|
||||
}
|
||||
|
||||
// IsBadRequest determines if err is an error which indicates a BadRequest error.
|
||||
// It supports wrapped errors.
|
||||
func IsBadRequest(err error) bool {
|
||||
return Code(err) == http.StatusBadRequest
|
||||
}
|
||||
|
||||
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
|
||||
func TooManyRequests(reason, message string) *ApplicationError {
|
||||
return New(http.StatusTooManyRequests, reason, message)
|
||||
}
|
||||
|
||||
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
|
||||
// It supports wrapped errors.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
return Code(err) == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// Unauthorized new Unauthorized error that is mapped to a 401 response.
|
||||
func Unauthorized(reason, message string) *ApplicationError {
|
||||
return New(http.StatusUnauthorized, reason, message)
|
||||
}
|
||||
|
||||
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
|
||||
// It supports wrapped errors.
|
||||
func IsUnauthorized(err error) bool {
|
||||
return Code(err) == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// Forbidden new Forbidden error that is mapped to a 403 response.
|
||||
func Forbidden(reason, message string) *ApplicationError {
|
||||
return New(http.StatusForbidden, reason, message)
|
||||
}
|
||||
|
||||
// IsForbidden determines if err is an error which indicates a Forbidden error.
|
||||
// It supports wrapped errors.
|
||||
func IsForbidden(err error) bool {
|
||||
return Code(err) == http.StatusForbidden
|
||||
}
|
||||
|
||||
// NotFound new NotFound error that is mapped to a 404 response.
|
||||
func NotFound(reason, message string) *ApplicationError {
|
||||
return New(http.StatusNotFound, reason, message)
|
||||
}
|
||||
|
||||
// IsNotFound determines if err is an error which indicates an NotFound error.
|
||||
// It supports wrapped errors.
|
||||
func IsNotFound(err error) bool {
|
||||
return Code(err) == http.StatusNotFound
|
||||
}
|
||||
|
||||
// Conflict new Conflict error that is mapped to a 409 response.
|
||||
func Conflict(reason, message string) *ApplicationError {
|
||||
return New(http.StatusConflict, reason, message)
|
||||
}
|
||||
|
||||
// IsConflict determines if err is an error which indicates a Conflict error.
|
||||
// It supports wrapped errors.
|
||||
func IsConflict(err error) bool {
|
||||
return Code(err) == http.StatusConflict
|
||||
}
|
||||
|
||||
// InternalServer new InternalServer error that is mapped to a 500 response.
|
||||
func InternalServer(reason, message string) *ApplicationError {
|
||||
return New(http.StatusInternalServerError, reason, message)
|
||||
}
|
||||
|
||||
// IsInternalServer determines if err is an error which indicates an Internal error.
|
||||
// It supports wrapped errors.
|
||||
func IsInternalServer(err error) bool {
|
||||
return Code(err) == http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
|
||||
func ServiceUnavailable(reason, message string) *ApplicationError {
|
||||
return New(http.StatusServiceUnavailable, reason, message)
|
||||
}
|
||||
|
||||
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
|
||||
// It supports wrapped errors.
|
||||
func IsServiceUnavailable(err error) bool {
|
||||
return Code(err) == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
|
||||
func GatewayTimeout(reason, message string) *ApplicationError {
|
||||
return New(http.StatusGatewayTimeout, reason, message)
|
||||
}
|
||||
|
||||
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
|
||||
// It supports wrapped errors.
|
||||
func IsGatewayTimeout(err error) bool {
|
||||
return Code(err) == http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
|
||||
func ClientClosed(reason, message string) *ApplicationError {
|
||||
return New(499, reason, message)
|
||||
}
|
||||
|
||||
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
|
||||
// It supports wrapped errors.
|
||||
func IsClientClosed(err error) bool {
|
||||
return Code(err) == 499
|
||||
}
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// BadRequest new BadRequest error that is mapped to a 400 response.
|
||||
func BadRequest(reason, message string) *ApplicationError {
|
||||
return New(http.StatusBadRequest, reason, message)
|
||||
}
|
||||
|
||||
// IsBadRequest determines if err is an error which indicates a BadRequest error.
|
||||
// It supports wrapped errors.
|
||||
func IsBadRequest(err error) bool {
|
||||
return Code(err) == http.StatusBadRequest
|
||||
}
|
||||
|
||||
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
|
||||
func TooManyRequests(reason, message string) *ApplicationError {
|
||||
return New(http.StatusTooManyRequests, reason, message)
|
||||
}
|
||||
|
||||
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
|
||||
// It supports wrapped errors.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
return Code(err) == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// Unauthorized new Unauthorized error that is mapped to a 401 response.
|
||||
func Unauthorized(reason, message string) *ApplicationError {
|
||||
return New(http.StatusUnauthorized, reason, message)
|
||||
}
|
||||
|
||||
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
|
||||
// It supports wrapped errors.
|
||||
func IsUnauthorized(err error) bool {
|
||||
return Code(err) == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// Forbidden new Forbidden error that is mapped to a 403 response.
|
||||
func Forbidden(reason, message string) *ApplicationError {
|
||||
return New(http.StatusForbidden, reason, message)
|
||||
}
|
||||
|
||||
// IsForbidden determines if err is an error which indicates a Forbidden error.
|
||||
// It supports wrapped errors.
|
||||
func IsForbidden(err error) bool {
|
||||
return Code(err) == http.StatusForbidden
|
||||
}
|
||||
|
||||
// NotFound new NotFound error that is mapped to a 404 response.
|
||||
func NotFound(reason, message string) *ApplicationError {
|
||||
return New(http.StatusNotFound, reason, message)
|
||||
}
|
||||
|
||||
// IsNotFound determines if err is an error which indicates an NotFound error.
|
||||
// It supports wrapped errors.
|
||||
func IsNotFound(err error) bool {
|
||||
return Code(err) == http.StatusNotFound
|
||||
}
|
||||
|
||||
// Conflict new Conflict error that is mapped to a 409 response.
|
||||
func Conflict(reason, message string) *ApplicationError {
|
||||
return New(http.StatusConflict, reason, message)
|
||||
}
|
||||
|
||||
// IsConflict determines if err is an error which indicates a Conflict error.
|
||||
// It supports wrapped errors.
|
||||
func IsConflict(err error) bool {
|
||||
return Code(err) == http.StatusConflict
|
||||
}
|
||||
|
||||
// InternalServer new InternalServer error that is mapped to a 500 response.
|
||||
func InternalServer(reason, message string) *ApplicationError {
|
||||
return New(http.StatusInternalServerError, reason, message)
|
||||
}
|
||||
|
||||
// IsInternalServer determines if err is an error which indicates an Internal error.
|
||||
// It supports wrapped errors.
|
||||
func IsInternalServer(err error) bool {
|
||||
return Code(err) == http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
|
||||
func ServiceUnavailable(reason, message string) *ApplicationError {
|
||||
return New(http.StatusServiceUnavailable, reason, message)
|
||||
}
|
||||
|
||||
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
|
||||
// It supports wrapped errors.
|
||||
func IsServiceUnavailable(err error) bool {
|
||||
return Code(err) == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
|
||||
func GatewayTimeout(reason, message string) *ApplicationError {
|
||||
return New(http.StatusGatewayTimeout, reason, message)
|
||||
}
|
||||
|
||||
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
|
||||
// It supports wrapped errors.
|
||||
func IsGatewayTimeout(err error) bool {
|
||||
return Code(err) == http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
|
||||
func ClientClosed(reason, message string) *ApplicationError {
|
||||
return New(499, reason, message)
|
||||
}
|
||||
|
||||
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
|
||||
// It supports wrapped errors.
|
||||
func IsClientClosed(err error) bool {
|
||||
return Code(err) == 499
|
||||
}
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
package gemini
|
||||
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
type ModelsListResponse struct {
|
||||
Models []Model `json:"models"`
|
||||
}
|
||||
|
||||
func DefaultModels() []Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
return []Model{
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
func FallbackModelsList() ModelsListResponse {
|
||||
return ModelsListResponse{Models: DefaultModels()}
|
||||
}
|
||||
|
||||
func FallbackModel(model string) Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
if model == "" {
|
||||
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
|
||||
}
|
||||
if len(model) >= 7 && model[:7] == "models/" {
|
||||
return Model{Name: model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
package gemini
|
||||
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
type ModelsListResponse struct {
|
||||
Models []Model `json:"models"`
|
||||
}
|
||||
|
||||
func DefaultModels() []Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
return []Model{
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
func FallbackModelsList() ModelsListResponse {
|
||||
return ModelsListResponse{Models: DefaultModels()}
|
||||
}
|
||||
|
||||
func FallbackModel(model string) Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
if model == "" {
|
||||
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
|
||||
}
|
||||
if len(model) >= 7 && model[:7] == "models/" {
|
||||
return Model{Name: model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
package geminicli
|
||||
|
||||
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistMetadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
Platform string `json:"platform"`
|
||||
PluginType string `json:"pluginType"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistResponse struct {
|
||||
CurrentTier string `json:"currentTier,omitempty"`
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
|
||||
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
|
||||
}
|
||||
|
||||
type AllowedTier struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"isDefault,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type OnboardUserResponse struct {
|
||||
Done bool `json:"done"`
|
||||
Response *OnboardUserResultData `json:"response,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserResultData struct {
|
||||
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
|
||||
}
|
||||
package geminicli
|
||||
|
||||
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistMetadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
Platform string `json:"platform"`
|
||||
PluginType string `json:"pluginType"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistResponse struct {
|
||||
CurrentTier string `json:"currentTier,omitempty"`
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
|
||||
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
|
||||
}
|
||||
|
||||
type AllowedTier struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"isDefault,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type OnboardUserResponse struct {
|
||||
Done bool `json:"done"`
|
||||
Response *OnboardUserResultData `json:"response,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserResultData struct {
|
||||
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
|
||||
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
|
||||
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
|
||||
// Note: You still need to register this redirect URI in your Google OAuth client
|
||||
// unless you use an OAuth client type that permits localhost redirect URIs.
|
||||
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
|
||||
// Required by Google's Code Assist API.
|
||||
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
|
||||
// Reference: https://ai.google.dev/gemini-api/docs/oauth
|
||||
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
|
||||
// Note: Google Auth platform currently documents the OAuth scope as
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
|
||||
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
|
||||
)
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
|
||||
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
|
||||
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
|
||||
// Note: You still need to register this redirect URI in your Google OAuth client
|
||||
// unless you use an OAuth client type that permits localhost redirect URIs.
|
||||
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
|
||||
// Required by Google's Code Assist API.
|
||||
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
|
||||
// Reference: https://ai.google.dev/gemini-api/docs/oauth
|
||||
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
|
||||
// Note: Google Auth platform currently documents the OAuth scope as
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
|
||||
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
|
||||
)
|
||||
|
||||
@@ -1,157 +1,157 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
// DriveStorageInfo represents Google Drive storage quota information
|
||||
type DriveStorageInfo struct {
|
||||
Limit int64 `json:"limit"` // Storage limit in bytes
|
||||
Usage int64 `json:"usage"` // Current usage in bytes
|
||||
}
|
||||
|
||||
// DriveClient interface for Google Drive API operations
|
||||
type DriveClient interface {
|
||||
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
|
||||
}
|
||||
|
||||
type driveClient struct{}
|
||||
|
||||
// NewDriveClient creates a new Drive API client
|
||||
func NewDriveClient() DriveClient {
|
||||
return &driveClient{}
|
||||
}
|
||||
|
||||
// GetStorageQuota fetches storage quota from Google Drive API
|
||||
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
|
||||
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Get HTTP client with proxy support
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
sleepWithContext := func(d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
|
||||
var resp *http.Response
|
||||
maxRetries := 3
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
// Network error retry
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Success
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
break
|
||||
}
|
||||
|
||||
// Retry 429, 500, 502, 503 with exponential backoff + jitter
|
||||
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusInternalServerError ||
|
||||
resp.StatusCode == http.StatusBadGateway ||
|
||||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
|
||||
if err := func() error {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
return sleepWithContext(backoff + jitter)
|
||||
}(); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("request failed: no response received")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
statusText := http.StatusText(resp.StatusCode)
|
||||
if statusText == "" {
|
||||
statusText = resp.Status
|
||||
}
|
||||
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
|
||||
// 只返回通用错误
|
||||
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Parse response
|
||||
var result struct {
|
||||
StorageQuota struct {
|
||||
Limit string `json:"limit"` // Can be string or number
|
||||
Usage string `json:"usage"`
|
||||
} `json:"storageQuota"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Parse limit and usage (handle both string and number formats)
|
||||
var limit, usage int64
|
||||
if result.StorageQuota.Limit != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
|
||||
limit = val
|
||||
}
|
||||
}
|
||||
if result.StorageQuota.Usage != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
|
||||
usage = val
|
||||
}
|
||||
}
|
||||
|
||||
return &DriveStorageInfo{
|
||||
Limit: limit,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
// DriveStorageInfo represents Google Drive storage quota information
|
||||
type DriveStorageInfo struct {
|
||||
Limit int64 `json:"limit"` // Storage limit in bytes
|
||||
Usage int64 `json:"usage"` // Current usage in bytes
|
||||
}
|
||||
|
||||
// DriveClient interface for Google Drive API operations
|
||||
type DriveClient interface {
|
||||
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
|
||||
}
|
||||
|
||||
type driveClient struct{}
|
||||
|
||||
// NewDriveClient creates a new Drive API client
|
||||
func NewDriveClient() DriveClient {
|
||||
return &driveClient{}
|
||||
}
|
||||
|
||||
// GetStorageQuota fetches storage quota from Google Drive API
|
||||
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
|
||||
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Get HTTP client with proxy support
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
sleepWithContext := func(d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
|
||||
var resp *http.Response
|
||||
maxRetries := 3
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
// Network error retry
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Success
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
break
|
||||
}
|
||||
|
||||
// Retry 429, 500, 502, 503 with exponential backoff + jitter
|
||||
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusInternalServerError ||
|
||||
resp.StatusCode == http.StatusBadGateway ||
|
||||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
|
||||
if err := func() error {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
return sleepWithContext(backoff + jitter)
|
||||
}(); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("request failed: no response received")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
statusText := http.StatusText(resp.StatusCode)
|
||||
if statusText == "" {
|
||||
statusText = resp.Status
|
||||
}
|
||||
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
|
||||
// 只返回通用错误
|
||||
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Parse response
|
||||
var result struct {
|
||||
StorageQuota struct {
|
||||
Limit string `json:"limit"` // Can be string or number
|
||||
Usage string `json:"usage"`
|
||||
} `json:"storageQuota"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Parse limit and usage (handle both string and number formats)
|
||||
var limit, usage int64
|
||||
if result.StorageQuota.Limit != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
|
||||
limit = val
|
||||
}
|
||||
}
|
||||
if result.StorageQuota.Usage != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
|
||||
usage = val
|
||||
}
|
||||
}
|
||||
|
||||
return &DriveStorageInfo{
|
||||
Limit: limit,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDriveStorageInfo(t *testing.T) {
|
||||
// 测试 DriveStorageInfo 结构体
|
||||
info := &DriveStorageInfo{
|
||||
Limit: 100 * 1024 * 1024 * 1024, // 100GB
|
||||
Usage: 50 * 1024 * 1024 * 1024, // 50GB
|
||||
}
|
||||
|
||||
if info.Limit != 100*1024*1024*1024 {
|
||||
t.Errorf("Expected limit 100GB, got %d", info.Limit)
|
||||
}
|
||||
if info.Usage != 50*1024*1024*1024 {
|
||||
t.Errorf("Expected usage 50GB, got %d", info.Usage)
|
||||
}
|
||||
}
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDriveStorageInfo(t *testing.T) {
|
||||
// 测试 DriveStorageInfo 结构体
|
||||
info := &DriveStorageInfo{
|
||||
Limit: 100 * 1024 * 1024 * 1024, // 100GB
|
||||
Usage: 50 * 1024 * 1024 * 1024, // 50GB
|
||||
}
|
||||
|
||||
if info.Limit != 100*1024*1024*1024 {
|
||||
t.Errorf("Expected limit 100GB, got %d", info.Limit)
|
||||
}
|
||||
if info.Usage != 50*1024*1024*1024 {
|
||||
t.Errorf("Expected usage 50GB, got %d", info.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package geminicli
|
||||
|
||||
// Model represents a selectable Gemini model for UI/testing purposes.
|
||||
// Keep JSON fields consistent with existing frontend expectations.
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
const DefaultTestModel = "gemini-3-pro-preview"
|
||||
package geminicli
|
||||
|
||||
// Model represents a selectable Gemini model for UI/testing purposes.
|
||||
// Keep JSON fields consistent with existing frontend expectations.
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
const DefaultTestModel = "gemini-3-pro-preview"
|
||||
|
||||
@@ -1,243 +1,243 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// EffectiveOAuthConfig returns the effective OAuth configuration.
|
||||
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
|
||||
//
|
||||
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
|
||||
//
|
||||
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
|
||||
// https://www.googleapis.com/auth/generative-language), which will surface as
|
||||
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
|
||||
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
|
||||
effective := OAuthConfig{
|
||||
ClientID: strings.TrimSpace(cfg.ClientID),
|
||||
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
|
||||
Scopes: strings.TrimSpace(cfg.Scopes),
|
||||
}
|
||||
|
||||
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
|
||||
if effective.Scopes != "" {
|
||||
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
|
||||
}
|
||||
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
effective.ClientID = GeminiCLIOAuthClientID
|
||||
effective.ClientSecret = GeminiCLIOAuthClientSecret
|
||||
} else if effective.ClientID == "" || effective.ClientSecret == "" {
|
||||
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
}
|
||||
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
|
||||
effective.ClientSecret == GeminiCLIOAuthClientSecret
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
if oauthType == "ai_studio" {
|
||||
// Built-in client can't request some AI Studio scopes (notably generative-language).
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
} else {
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
}
|
||||
} else if oauthType == "ai_studio" && isBuiltinClient {
|
||||
// If user overrides scopes while still using the built-in client, strip restricted scopes.
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, s := range parts {
|
||||
if strings.Contains(s, "generative-language") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = strings.Join(filtered, " ")
|
||||
}
|
||||
}
|
||||
|
||||
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
|
||||
if oauthType == "ai_studio" && effective.Scopes != "" {
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
for i := range parts {
|
||||
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
|
||||
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
|
||||
}
|
||||
}
|
||||
effective.Scopes = strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
||||
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
redirectURI = strings.TrimSpace(redirectURI)
|
||||
if redirectURI == "" {
|
||||
return "", fmt.Errorf("redirect_uri is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", effectiveCfg.ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", effectiveCfg.Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
params.Set("project_id", strings.TrimSpace(projectID))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
|
||||
}
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// EffectiveOAuthConfig returns the effective OAuth configuration.
|
||||
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
|
||||
//
|
||||
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
|
||||
//
|
||||
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
|
||||
// https://www.googleapis.com/auth/generative-language), which will surface as
|
||||
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
|
||||
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
|
||||
effective := OAuthConfig{
|
||||
ClientID: strings.TrimSpace(cfg.ClientID),
|
||||
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
|
||||
Scopes: strings.TrimSpace(cfg.Scopes),
|
||||
}
|
||||
|
||||
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
|
||||
if effective.Scopes != "" {
|
||||
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
|
||||
}
|
||||
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
effective.ClientID = GeminiCLIOAuthClientID
|
||||
effective.ClientSecret = GeminiCLIOAuthClientSecret
|
||||
} else if effective.ClientID == "" || effective.ClientSecret == "" {
|
||||
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
}
|
||||
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
|
||||
effective.ClientSecret == GeminiCLIOAuthClientSecret
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
if oauthType == "ai_studio" {
|
||||
// Built-in client can't request some AI Studio scopes (notably generative-language).
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
} else {
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
}
|
||||
} else if oauthType == "ai_studio" && isBuiltinClient {
|
||||
// If user overrides scopes while still using the built-in client, strip restricted scopes.
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, s := range parts {
|
||||
if strings.Contains(s, "generative-language") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = strings.Join(filtered, " ")
|
||||
}
|
||||
}
|
||||
|
||||
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
|
||||
if oauthType == "ai_studio" && effective.Scopes != "" {
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
for i := range parts {
|
||||
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
|
||||
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
|
||||
}
|
||||
}
|
||||
effective.Scopes = strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
||||
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
redirectURI = strings.TrimSpace(redirectURI)
|
||||
if redirectURI == "" {
|
||||
return "", fmt.Errorf("redirect_uri is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", effectiveCfg.ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", effectiveCfg.Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
params.Set("project_id", strings.TrimSpace(projectID))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
@@ -1,46 +1,46 @@
|
||||
package geminicli
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxLogBodyLen = 2048
|
||||
|
||||
func SanitizeBodyForLogs(body string) string {
|
||||
body = truncateBase64InMessage(body)
|
||||
if len(body) > maxLogBodyLen {
|
||||
body = body[:maxLogBodyLen] + "...[truncated]"
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func truncateBase64InMessage(message string) string {
|
||||
const maxBase64Length = 50
|
||||
|
||||
result := message
|
||||
offset := 0
|
||||
for {
|
||||
idx := strings.Index(result[offset:], ";base64,")
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
actualIdx := offset + idx
|
||||
start := actualIdx + len(";base64,")
|
||||
|
||||
end := start
|
||||
for end < len(result) && isBase64Char(result[end]) {
|
||||
end++
|
||||
}
|
||||
|
||||
if end-start > maxBase64Length {
|
||||
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
|
||||
offset = start + maxBase64Length + len("...[truncated]")
|
||||
continue
|
||||
}
|
||||
offset = end
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isBase64Char(c byte) bool {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
|
||||
}
|
||||
package geminicli
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxLogBodyLen = 2048
|
||||
|
||||
func SanitizeBodyForLogs(body string) string {
|
||||
body = truncateBase64InMessage(body)
|
||||
if len(body) > maxLogBodyLen {
|
||||
body = body[:maxLogBodyLen] + "...[truncated]"
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func truncateBase64InMessage(message string) string {
|
||||
const maxBase64Length = 50
|
||||
|
||||
result := message
|
||||
offset := 0
|
||||
for {
|
||||
idx := strings.Index(result[offset:], ";base64,")
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
actualIdx := offset + idx
|
||||
start := actualIdx + len(";base64,")
|
||||
|
||||
end := start
|
||||
for end < len(result) && isBase64Char(result[end]) {
|
||||
end++
|
||||
}
|
||||
|
||||
if end-start > maxBase64Length {
|
||||
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
|
||||
offset = start + maxBase64Length + len("...[truncated]")
|
||||
continue
|
||||
}
|
||||
offset = end
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isBase64Char(c byte) bool {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package geminicli
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
package geminicli
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
|
||||
func HTTPStatusToGoogleStatus(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "INVALID_ARGUMENT"
|
||||
case http.StatusUnauthorized:
|
||||
return "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
return "PERMISSION_DENIED"
|
||||
case http.StatusNotFound:
|
||||
return "NOT_FOUND"
|
||||
case http.StatusTooManyRequests:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "INTERNAL"
|
||||
}
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
|
||||
func HTTPStatusToGoogleStatus(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "INVALID_ARGUMENT"
|
||||
case http.StatusUnauthorized:
|
||||
return "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
return "PERMISSION_DENIED"
|
||||
case http.StatusNotFound:
|
||||
return "NOT_FOUND"
|
||||
case http.StatusTooManyRequests:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "INTERNAL"
|
||||
}
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,236 +1,236 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Scope string `json:"scope"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
// OrgInfo represents organization info from OAuth response
|
||||
type OrgInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
}
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Scope string `json:"scope"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
// OrgInfo represents organization info from OAuth response
|
||||
type OrgInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// Model represents an OpenAI model
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
//
|
||||
//go:embed instructions.txt
|
||||
var DefaultInstructions string
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// Model represents an OpenAI model
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
//
|
||||
//go:embed instructions.txt
|
||||
var DefaultInstructions string
|
||||
|
||||
@@ -1,118 +1,118 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No \"save/copy this file\" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No \"save/copy this file\" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||
|
||||
@@ -1,366 +1,366 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// Default redirect URI (can be customized)
|
||||
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// Scopes
|
||||
DefaultScopes = "openid profile email offline_access"
|
||||
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||
RefreshScopes = "openid profile email"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||
// OpenAI uses hex encoding instead of base64url
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
// Uses base64url encoding as per RFC 7636
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OpenAI OAuth
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||
type IDTokenClaims struct {
|
||||
// Standard claims
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Iss string `json:"iss"`
|
||||
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
|
||||
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
// OrganizationClaim represents an organization in the ID Token
|
||||
type OrganizationClaim struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
Scope: RefreshScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// ToFormData converts TokenRequest to URL-encoded form data
|
||||
func (r *TokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("code", r.Code)
|
||||
params.Set("redirect_uri", r.RedirectURI)
|
||||
params.Set("code_verifier", r.CodeVerifier)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||
func (r *RefreshTokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("refresh_token", r.RefreshToken)
|
||||
params.Set("scope", r.Scope)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if necessary
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try standard encoding
|
||||
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var claims IDTokenClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
}
|
||||
|
||||
// GetUserInfo extracts user info from ID Token claims
|
||||
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
info := &UserInfo{
|
||||
Email: c.Email,
|
||||
}
|
||||
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
// Get default organization ID
|
||||
for _, org := range c.OpenAIAuth.Organizations {
|
||||
if org.IsDefault {
|
||||
info.OrganizationID = org.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no default, use first org
|
||||
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// Default redirect URI (can be customized)
|
||||
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// Scopes
|
||||
DefaultScopes = "openid profile email offline_access"
|
||||
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||
RefreshScopes = "openid profile email"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||
// OpenAI uses hex encoding instead of base64url
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
// Uses base64url encoding as per RFC 7636
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OpenAI OAuth
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||
type IDTokenClaims struct {
|
||||
// Standard claims
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Iss string `json:"iss"`
|
||||
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
|
||||
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
// OrganizationClaim represents an organization in the ID Token
|
||||
type OrganizationClaim struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
Scope: RefreshScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// ToFormData converts TokenRequest to URL-encoded form data
|
||||
func (r *TokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("code", r.Code)
|
||||
params.Set("redirect_uri", r.RedirectURI)
|
||||
params.Set("code_verifier", r.CodeVerifier)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||
func (r *RefreshTokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("refresh_token", r.RefreshToken)
|
||||
params.Set("scope", r.Scope)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if necessary
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try standard encoding
|
||||
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var claims IDTokenClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
}
|
||||
|
||||
// GetUserInfo extracts user info from ID Token claims
|
||||
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
info := &UserInfo{
|
||||
Email: c.Email,
|
||||
}
|
||||
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
// Get default organization ID
|
||||
for _, org := range c.OpenAIAuth.Organizations {
|
||||
if org.IsDefault {
|
||||
info.OrganizationID = org.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no default, use first org
|
||||
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package openai
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_vscode/",
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
package openai
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_vscode/",
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
@@ -1,185 +1,185 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: "",
|
||||
Metadata: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
||||
// optionally providing structured error fields (reason/metadata).
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
||||
// It returns true if an error was written.
|
||||
func ErrorFrom(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized 返回401错误
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden 返回403错误
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound 返回404错误
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// InternalError 返回500错误
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: pagination.Total,
|
||||
Page: pagination.Page,
|
||||
PageSize: pagination.PageSize,
|
||||
Pages: pagination.Pages,
|
||||
})
|
||||
}
|
||||
|
||||
// ParsePagination 解析分页参数
|
||||
func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
|
||||
if p := c.Query("page"); p != "" {
|
||||
if val, err := parseInt(p); err == nil && val > 0 {
|
||||
page = val
|
||||
}
|
||||
}
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
package response
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: "",
|
||||
Metadata: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
||||
// optionally providing structured error fields (reason/metadata).
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
||||
// It returns true if an error was written.
|
||||
func ErrorFrom(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized 返回401错误
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden 返回403错误
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound 返回404错误
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// InternalError 返回500错误
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: pagination.Total,
|
||||
Page: pagination.Page,
|
||||
PageSize: pagination.PageSize,
|
||||
Pages: pagination.Pages,
|
||||
})
|
||||
}
|
||||
|
||||
// ParsePagination 解析分页参数
|
||||
func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
|
||||
if p := c.Query("page"); p != "" {
|
||||
if val, err := parseInt(p); err == nil && val > 0 {
|
||||
page = val
|
||||
}
|
||||
}
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1,171 +1,171 @@
|
||||
//go:build unit
|
||||
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
reason string
|
||||
metadata map[string]string
|
||||
want Response
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "invalid request",
|
||||
want: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "structured_error",
|
||||
statusCode: http.StatusForbidden,
|
||||
message: "no access",
|
||||
reason: "FORBIDDEN",
|
||||
metadata: map[string]string{"k": "v"},
|
||||
want: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"k": "v"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFrom(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantWritten bool
|
||||
wantHTTPCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantWritten: false,
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"scope": "admin"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
Reason: "INVALID_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "unauthorized",
|
||||
Reason: "UNAUTHORIZED",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "not found",
|
||||
Reason: "NOT_FOUND",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
Code: http.StatusConflict,
|
||||
Message: "conflict",
|
||||
Reason: "CONFLICT",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown_error_defaults_to_500",
|
||||
err: errors.New("boom"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: errors2.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
written := ErrorFrom(c, tt.err)
|
||||
require.Equal(t, tt.wantWritten, written)
|
||||
|
||||
if !tt.wantWritten {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Empty(t, w.Body.String())
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
reason string
|
||||
metadata map[string]string
|
||||
want Response
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "invalid request",
|
||||
want: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "structured_error",
|
||||
statusCode: http.StatusForbidden,
|
||||
message: "no access",
|
||||
reason: "FORBIDDEN",
|
||||
metadata: map[string]string{"k": "v"},
|
||||
want: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"k": "v"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFrom(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantWritten bool
|
||||
wantHTTPCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantWritten: false,
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"scope": "admin"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
Reason: "INVALID_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "unauthorized",
|
||||
Reason: "UNAUTHORIZED",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "not found",
|
||||
Reason: "NOT_FOUND",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
Code: http.StatusConflict,
|
||||
Message: "conflict",
|
||||
Reason: "CONFLICT",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown_error_defaults_to_500",
|
||||
err: errors.New("boom"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: errors2.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
written := ErrorFrom(c, tt.err)
|
||||
require.Equal(t, tt.wantWritten, written)
|
||||
|
||||
if !tt.wantWritten {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Empty(t, w.Body.String())
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,47 +1,47 @@
|
||||
package sysutil
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RestartService triggers a service restart by gracefully exiting.
|
||||
//
|
||||
// This relies on systemd's Restart=always configuration to automatically
|
||||
// restart the service after it exits. This is the industry-standard approach:
|
||||
// - Simple and reliable
|
||||
// - No sudo permissions needed
|
||||
// - No complex process management
|
||||
// - Leverages systemd's native restart capability
|
||||
//
|
||||
// Prerequisites:
|
||||
// - Linux OS with systemd
|
||||
// - Service configured with Restart=always in systemd unit file
|
||||
func RestartService() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Println("Service restart via exit only works on Linux with systemd")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Println("Initiating service restart by graceful exit...")
|
||||
log.Println("systemd will automatically restart the service (Restart=always)")
|
||||
|
||||
// Give a moment for logs to flush and response to be sent
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartServiceAsync is a fire-and-forget version of RestartService.
|
||||
// It logs errors instead of returning them, suitable for goroutine usage.
|
||||
func RestartServiceAsync() {
|
||||
if err := RestartService(); err != nil {
|
||||
log.Printf("Service restart failed: %v", err)
|
||||
log.Println("Please restart the service manually: sudo systemctl restart sub2api")
|
||||
}
|
||||
}
|
||||
package sysutil
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RestartService triggers a service restart by gracefully exiting.
|
||||
//
|
||||
// This relies on systemd's Restart=always configuration to automatically
|
||||
// restart the service after it exits. This is the industry-standard approach:
|
||||
// - Simple and reliable
|
||||
// - No sudo permissions needed
|
||||
// - No complex process management
|
||||
// - Leverages systemd's native restart capability
|
||||
//
|
||||
// Prerequisites:
|
||||
// - Linux OS with systemd
|
||||
// - Service configured with Restart=always in systemd unit file
|
||||
func RestartService() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Println("Service restart via exit only works on Linux with systemd")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Println("Initiating service restart by graceful exit...")
|
||||
log.Println("systemd will automatically restart the service (Restart=always)")
|
||||
|
||||
// Give a moment for logs to flush and response to be sent
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartServiceAsync is a fire-and-forget version of RestartService.
|
||||
// It logs errors instead of returning them, suitable for goroutine usage.
|
||||
func RestartServiceAsync() {
|
||||
if err := RestartService(); err != nil {
|
||||
log.Printf("Service restart failed: %v", err)
|
||||
log.Println("Please restart the service manually: sudo systemctl restart sub2api")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,124 +1,124 @@
|
||||
// Package timezone provides global timezone management for the application.
|
||||
// Similar to PHP's date_default_timezone_set, this package allows setting
|
||||
// a global timezone that affects all time.Now() calls.
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// location is the global timezone location
|
||||
location *time.Location
|
||||
// tzName stores the timezone name for logging/debugging
|
||||
tzName string
|
||||
)
|
||||
|
||||
// Init initializes the global timezone setting.
|
||||
// This should be called once at application startup.
|
||||
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
|
||||
func Init(tz string) error {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai" // Default timezone
|
||||
}
|
||||
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid timezone %q: %w", tz, err)
|
||||
}
|
||||
|
||||
// Set the global Go time.Local to our timezone
|
||||
// This affects time.Now() throughout the application
|
||||
time.Local = loc
|
||||
location = loc
|
||||
tzName = tz
|
||||
|
||||
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUTCOffset returns the current UTC offset for a location
|
||||
func getUTCOffset(loc *time.Location) string {
|
||||
_, offset := time.Now().In(loc).Zone()
|
||||
hours := offset / 3600
|
||||
minutes := (offset % 3600) / 60
|
||||
if minutes < 0 {
|
||||
minutes = -minutes
|
||||
}
|
||||
sign := "+"
|
||||
if hours < 0 {
|
||||
sign = "-"
|
||||
hours = -hours
|
||||
}
|
||||
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
|
||||
}
|
||||
|
||||
// Now returns the current time in the configured timezone.
|
||||
// This is equivalent to time.Now() after Init() is called,
|
||||
// but provided for explicit timezone-aware code.
|
||||
func Now() time.Time {
|
||||
if location == nil {
|
||||
return time.Now()
|
||||
}
|
||||
return time.Now().In(location)
|
||||
}
|
||||
|
||||
// Location returns the configured timezone location.
|
||||
func Location() *time.Location {
|
||||
if location == nil {
|
||||
return time.Local
|
||||
}
|
||||
return location
|
||||
}
|
||||
|
||||
// Name returns the configured timezone name.
|
||||
func Name() string {
|
||||
if tzName == "" {
|
||||
return "Local"
|
||||
}
|
||||
return tzName
|
||||
}
|
||||
|
||||
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
|
||||
func StartOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// Today returns the start of today (00:00:00) in the configured timezone.
|
||||
func Today() time.Time {
|
||||
return StartOfDay(Now())
|
||||
}
|
||||
|
||||
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
|
||||
func EndOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
|
||||
}
|
||||
|
||||
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
|
||||
func StartOfWeek(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
weekday := int(t.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7 // Sunday is day 7
|
||||
}
|
||||
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
|
||||
func StartOfMonth(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// ParseInLocation parses a time string in the configured timezone.
|
||||
func ParseInLocation(layout, value string) (time.Time, error) {
|
||||
return time.ParseInLocation(layout, value, Location())
|
||||
}
|
||||
// Package timezone provides global timezone management for the application.
|
||||
// Similar to PHP's date_default_timezone_set, this package allows setting
|
||||
// a global timezone that affects all time.Now() calls.
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// location is the global timezone location
|
||||
location *time.Location
|
||||
// tzName stores the timezone name for logging/debugging
|
||||
tzName string
|
||||
)
|
||||
|
||||
// Init initializes the global timezone setting.
|
||||
// This should be called once at application startup.
|
||||
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
|
||||
func Init(tz string) error {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai" // Default timezone
|
||||
}
|
||||
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid timezone %q: %w", tz, err)
|
||||
}
|
||||
|
||||
// Set the global Go time.Local to our timezone
|
||||
// This affects time.Now() throughout the application
|
||||
time.Local = loc
|
||||
location = loc
|
||||
tzName = tz
|
||||
|
||||
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUTCOffset returns the current UTC offset for a location
|
||||
func getUTCOffset(loc *time.Location) string {
|
||||
_, offset := time.Now().In(loc).Zone()
|
||||
hours := offset / 3600
|
||||
minutes := (offset % 3600) / 60
|
||||
if minutes < 0 {
|
||||
minutes = -minutes
|
||||
}
|
||||
sign := "+"
|
||||
if hours < 0 {
|
||||
sign = "-"
|
||||
hours = -hours
|
||||
}
|
||||
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
|
||||
}
|
||||
|
||||
// Now returns the current time in the configured timezone.
|
||||
// This is equivalent to time.Now() after Init() is called,
|
||||
// but provided for explicit timezone-aware code.
|
||||
func Now() time.Time {
|
||||
if location == nil {
|
||||
return time.Now()
|
||||
}
|
||||
return time.Now().In(location)
|
||||
}
|
||||
|
||||
// Location returns the configured timezone location.
|
||||
func Location() *time.Location {
|
||||
if location == nil {
|
||||
return time.Local
|
||||
}
|
||||
return location
|
||||
}
|
||||
|
||||
// Name returns the configured timezone name.
|
||||
func Name() string {
|
||||
if tzName == "" {
|
||||
return "Local"
|
||||
}
|
||||
return tzName
|
||||
}
|
||||
|
||||
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
|
||||
func StartOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// Today returns the start of today (00:00:00) in the configured timezone.
|
||||
func Today() time.Time {
|
||||
return StartOfDay(Now())
|
||||
}
|
||||
|
||||
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
|
||||
func EndOfDay(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
|
||||
}
|
||||
|
||||
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
|
||||
func StartOfWeek(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
weekday := int(t.Weekday())
|
||||
if weekday == 0 {
|
||||
weekday = 7 // Sunday is day 7
|
||||
}
|
||||
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
|
||||
func StartOfMonth(t time.Time) time.Time {
|
||||
loc := Location()
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// ParseInLocation parses a time string in the configured timezone.
|
||||
func ParseInLocation(layout, value string) (time.Time, error) {
|
||||
return time.ParseInLocation(layout, value, Location())
|
||||
}
|
||||
|
||||
@@ -1,137 +1,137 @@
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
// Test with valid timezone
|
||||
err := Init("Asia/Shanghai")
|
||||
if err != nil {
|
||||
t.Fatalf("Init failed with valid timezone: %v", err)
|
||||
}
|
||||
|
||||
// Verify time.Local was set
|
||||
if time.Local.String() != "Asia/Shanghai" {
|
||||
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
|
||||
}
|
||||
|
||||
// Verify our location variable
|
||||
if Location().String() != "Asia/Shanghai" {
|
||||
t.Errorf("Location() not set correctly, got %s", Location().String())
|
||||
}
|
||||
|
||||
// Test Name()
|
||||
if Name() != "Asia/Shanghai" {
|
||||
t.Errorf("Name() not set correctly, got %s", Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitInvalidTimezone(t *testing.T) {
|
||||
err := Init("Invalid/Timezone")
|
||||
if err == nil {
|
||||
t.Error("Init should fail with invalid timezone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
if err := Init("UTC"); err != nil {
|
||||
t.Fatalf("Init failed with UTC: %v", err)
|
||||
}
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
// Shanghai should be 8 hours ahead in display
|
||||
_, utcOffset := utcNow.Zone()
|
||||
_, shanghaiOffset := shanghaiNow.Zone()
|
||||
|
||||
expectedDiff := 8 * 3600 // 8 hours in seconds
|
||||
actualDiff := shanghaiOffset - utcOffset
|
||||
|
||||
if actualDiff != expectedDiff {
|
||||
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
|
||||
// Today should be at 00:00:00
|
||||
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
|
||||
t.Errorf("Today() not at start of day: %v", today)
|
||||
}
|
||||
|
||||
// Today should be same date as now
|
||||
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
|
||||
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
startOfDay := StartOfDay(testTime)
|
||||
|
||||
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
|
||||
if !startOfDay.Equal(expected) {
|
||||
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
now := Now()
|
||||
|
||||
// Truncate operates on UTC, not local time
|
||||
truncated := now.Truncate(24 * time.Hour)
|
||||
|
||||
// StartOfDay operates on local time
|
||||
startOfDay := StartOfDay(now)
|
||||
|
||||
// These will likely be different for non-UTC timezones
|
||||
t.Logf("Now: %v", now)
|
||||
t.Logf("Truncate(24h): %v", truncated)
|
||||
t.Logf("StartOfDay: %v", startOfDay)
|
||||
|
||||
// The truncated time may not be at local midnight
|
||||
// StartOfDay is always at local midnight
|
||||
if startOfDay.Hour() != 0 {
|
||||
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSTAwareness(t *testing.T) {
|
||||
// Test with a timezone that has DST (America/New_York)
|
||||
err := Init("America/New_York")
|
||||
if err != nil {
|
||||
t.Skipf("America/New_York timezone not available: %v", err)
|
||||
}
|
||||
|
||||
// Just verify it doesn't crash
|
||||
_ = Today()
|
||||
_ = Now()
|
||||
_ = StartOfDay(Now())
|
||||
}
|
||||
package timezone
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
// Test with valid timezone
|
||||
err := Init("Asia/Shanghai")
|
||||
if err != nil {
|
||||
t.Fatalf("Init failed with valid timezone: %v", err)
|
||||
}
|
||||
|
||||
// Verify time.Local was set
|
||||
if time.Local.String() != "Asia/Shanghai" {
|
||||
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
|
||||
}
|
||||
|
||||
// Verify our location variable
|
||||
if Location().String() != "Asia/Shanghai" {
|
||||
t.Errorf("Location() not set correctly, got %s", Location().String())
|
||||
}
|
||||
|
||||
// Test Name()
|
||||
if Name() != "Asia/Shanghai" {
|
||||
t.Errorf("Name() not set correctly, got %s", Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitInvalidTimezone(t *testing.T) {
|
||||
err := Init("Invalid/Timezone")
|
||||
if err == nil {
|
||||
t.Error("Init should fail with invalid timezone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeNowAffected(t *testing.T) {
|
||||
// Reset to UTC first
|
||||
if err := Init("UTC"); err != nil {
|
||||
t.Fatalf("Init failed with UTC: %v", err)
|
||||
}
|
||||
utcNow := time.Now()
|
||||
|
||||
// Switch to Shanghai (UTC+8)
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
shanghaiNow := time.Now()
|
||||
|
||||
// The times should be the same instant, but different timezone representation
|
||||
// Shanghai should be 8 hours ahead in display
|
||||
_, utcOffset := utcNow.Zone()
|
||||
_, shanghaiOffset := shanghaiNow.Zone()
|
||||
|
||||
expectedDiff := 8 * 3600 // 8 hours in seconds
|
||||
actualDiff := shanghaiOffset - utcOffset
|
||||
|
||||
if actualDiff != expectedDiff {
|
||||
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToday(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
today := Today()
|
||||
now := Now()
|
||||
|
||||
// Today should be at 00:00:00
|
||||
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
|
||||
t.Errorf("Today() not at start of day: %v", today)
|
||||
}
|
||||
|
||||
// Today should be same date as now
|
||||
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
|
||||
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
// Create a time at 15:30:45
|
||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||
startOfDay := StartOfDay(testTime)
|
||||
|
||||
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
|
||||
if !startOfDay.Equal(expected) {
|
||||
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateVsStartOfDay(t *testing.T) {
|
||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||
// and why StartOfDay is more reliable for timezone-aware code
|
||||
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||
}
|
||||
|
||||
now := Now()
|
||||
|
||||
// Truncate operates on UTC, not local time
|
||||
truncated := now.Truncate(24 * time.Hour)
|
||||
|
||||
// StartOfDay operates on local time
|
||||
startOfDay := StartOfDay(now)
|
||||
|
||||
// These will likely be different for non-UTC timezones
|
||||
t.Logf("Now: %v", now)
|
||||
t.Logf("Truncate(24h): %v", truncated)
|
||||
t.Logf("StartOfDay: %v", startOfDay)
|
||||
|
||||
// The truncated time may not be at local midnight
|
||||
// StartOfDay is always at local midnight
|
||||
if startOfDay.Hour() != 0 {
|
||||
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSTAwareness(t *testing.T) {
|
||||
// Test with a timezone that has DST (America/New_York)
|
||||
err := Init("America/New_York")
|
||||
if err != nil {
|
||||
t.Skipf("America/New_York timezone not available: %v", err)
|
||||
}
|
||||
|
||||
// Just verify it doesn't crash
|
||||
_ = Today()
|
||||
_ = Now()
|
||||
_ = StartOfDay(Now())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package usagestats
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
package usagestats
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
|
||||
@@ -1,214 +1,214 @@
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 性能统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary struct {
|
||||
Days int `json:"days"`
|
||||
ActualDaysUsed int `json:"actual_days_used"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||
Today *struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
} `json:"today"`
|
||||
HighestCostDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
} `json:"highest_cost_day"`
|
||||
HighestRequestDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"highest_request_day"`
|
||||
}
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
}
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 性能统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary struct {
|
||||
Days int `json:"days"`
|
||||
ActualDaysUsed int `json:"actual_days_used"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||
Today *struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
} `json:"today"`
|
||||
HighestCostDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
} `json:"highest_cost_day"`
|
||||
HighestRequestDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"highest_request_day"`
|
||||
}
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,145 +1,145 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
|
||||
@@ -1,60 +1,60 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
|
||||
@@ -1,127 +1,127 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_zero_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.NoError(s.T(), err, "expected nil error for missing key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "increment_increases_count_and_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||
require.Equal(s.T(), 2, count, "count mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "expected nil error after delete")
|
||||
require.Equal(s.T(), 0, count, "expected zero count after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "increment_increases_count",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||
|
||||
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||
require.NoError(s.T(), err, "Get dailyKey")
|
||||
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_expiry_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test-expiry"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||
require.NoError(s.T(), err, "TTL dailyKey")
|
||||
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_zero_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.NoError(s.T(), err, "expected nil error for missing key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "increment_increases_count_and_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||
require.Equal(s.T(), 2, count, "count mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "expected nil error after delete")
|
||||
require.Equal(s.T(), 0, count, "expected zero count after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "increment_increases_count",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||
|
||||
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||
require.NoError(s.T(), err, "Get dailyKey")
|
||||
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_expiry_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test-expiry"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||
require.NoError(s.T(), err, "TTL dailyKey")
|
||||
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
|
||||
@@ -1,46 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "apikey:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "apikey:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "apikey:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "apikey:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := apiKeyRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "apikey:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "apikey:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "apikey:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "apikey:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := apiKeyRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,335 +1,335 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
} else {
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
key.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userEntityToService(u *dbent.User) *service.User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
} else {
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
key.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userEntityToService(u *dbent.User) *service.User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
@@ -1,385 +1,385 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
// User preloaded
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), count)
|
||||
}
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().Nil(got1.GroupID)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User)
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||
|
||||
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
GroupID: groupID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
// User preloaded
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), count)
|
||||
}
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().Nil(got1.GroupID)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User)
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||
|
||||
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
GroupID: groupID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
@@ -1,183 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// billingSubKey generates the Redis key for subscription cache.
|
||||
func billingSubKey(userID, groupID int64) string {
|
||||
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
}
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := billingBalanceKey(userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := billingSubKey(userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := billingSubKey(userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// billingSubKey generates the Redis key for subscription cache.
|
||||
func billingSubKey(userID, groupID int64) string {
|
||||
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
}
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := billingBalanceKey(userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := billingSubKey(userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := billingSubKey(userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -1,283 +1,283 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
|
||||
@@ -1,87 +1,87 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,398 +1,398 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
errContain string
|
||||
wantUUID string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||
},
|
||||
wantUUID: "org-1",
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContain: "401",
|
||||
},
|
||||
{
|
||||
name: "invalid_json_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
if tt.errContain != "" {
|
||||
require.ErrorContains(s.T(), err, tt.errContain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantUUID, got)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantCode string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "parses_redirect_uri",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||
})
|
||||
},
|
||||
wantCode: "AUTH#STATE",
|
||||
validate: func(captured requestCapture) {
|
||||
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_code_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantCode, code)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
isSetupToken bool
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "rt",
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
isSetupToken: false,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
// Regular OAuth should not include expires_in
|
||||
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "setup_token_includes_expires_in",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 31536000,
|
||||
})
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: true,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
// Setup token should include expires_in with 1 year value
|
||||
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
|
||||
"setup token should include expires_in: 31536000")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_json_format",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "new_refresh_token",
|
||||
Scope: "user:profile user:inference",
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
// 验证使用 JSON 格式(不是 form 格式)
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||
"expected JSON content-type, got: %s", captured.contentType)
|
||||
// 验证 JSON body 内容
|
||||
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns_new_refresh_token",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rotated_rt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
errContain string
|
||||
wantUUID string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||
},
|
||||
wantUUID: "org-1",
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContain: "401",
|
||||
},
|
||||
{
|
||||
name: "invalid_json_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
if tt.errContain != "" {
|
||||
require.ErrorContains(s.T(), err, tt.errContain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantUUID, got)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantCode string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "parses_redirect_uri",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||
})
|
||||
},
|
||||
wantCode: "AUTH#STATE",
|
||||
validate: func(captured requestCapture) {
|
||||
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_code_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantCode, code)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
isSetupToken bool
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "rt",
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
isSetupToken: false,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
// Regular OAuth should not include expires_in
|
||||
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "setup_token_includes_expires_in",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 31536000,
|
||||
})
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: true,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
// Setup token should include expires_in with 1 year value
|
||||
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
|
||||
"setup token should include expires_in: 31536000")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_json_format",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "new_refresh_token",
|
||||
Scope: "user:profile user:inference",
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
// 验证使用 JSON 格式(不是 form 格式)
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||
"expected JSON content-type, got: %s", captured.contentType)
|
||||
// 验证 JSON body 内容
|
||||
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns_new_refresh_token",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rotated_rt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||
}
|
||||
|
||||
@@ -1,59 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
|
||||
@@ -1,105 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeUsageServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
fetcher *claudeUsageService
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||
type usageRequestCapture struct {
|
||||
authorization string
|
||||
anthropicBeta string
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||
|
||||
// Assertions on captured request data
|
||||
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
require.ErrorContains(s.T(), err, "nope")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "decode response failed")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeUsageServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
fetcher *claudeUsageService
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||
type usageRequestCapture struct {
|
||||
authorization string
|
||||
anthropicBeta string
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||
|
||||
// Assertions on captured request data
|
||||
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
require.ErrorContains(s.T(), err, "nope")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "decode response failed")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||
}
|
||||
|
||||
@@ -1,395 +1,395 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 并发控制缓存常量定义
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
|
||||
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
|
||||
//
|
||||
// 新实现改用 Redis 有序集合(Sorted Set):
|
||||
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
|
||||
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
|
||||
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
|
||||
// 4. 单次 Redis 调用完成计数,减少网络往返
|
||||
const (
|
||||
// 并发槽位键前缀(有序集合)
|
||||
// 格式: concurrency:account:{accountID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||
accountWaitKeyPrefix = "wait:account:"
|
||||
|
||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||
defaultSlotTTLMinutes = 15
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL(秒)
|
||||
// ARGV[3] = requestID
|
||||
acquireScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local requestID = ARGV[3]
|
||||
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- 清理过期槽位
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, requestID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到并发上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxConcurrency then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||||
// 使用 Redis TIME 命令获取服务器时间
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
getCountScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
-- 使用 Redis 服务器时间
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementAccountWaitScript - account-level wait queue count
|
||||
incrementAccountWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||
getAccountsLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local accountID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:account:' .. accountID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'wait:account:' .. accountID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, accountID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
slotTTLSeconds int // 槽位过期时间(秒)
|
||||
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
||||
}
|
||||
|
||||
// NewConcurrencyCache 创建并发控制缓存
|
||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
||||
if slotTTLMinutes <= 0 {
|
||||
slotTTLMinutes = defaultSlotTTLMinutes
|
||||
}
|
||||
if waitQueueTTLSeconds <= 0 {
|
||||
waitQueueTTLSeconds = slotTTLMinutes * 60
|
||||
}
|
||||
return &concurrencyCache{
|
||||
rdb: rdb,
|
||||
slotTTLSeconds: slotTTLMinutes * 60,
|
||||
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func accountWaitKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Account wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
key := accountWaitKey(accountID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return 0, err
|
||||
}
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, acc := range accounts {
|
||||
args = append(args, acc.ID, acc.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.AccountLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
}
|
||||
|
||||
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[accountID] = &service.AccountLoadInfo{
|
||||
AccountID: accountID,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 并发控制缓存常量定义
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
|
||||
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
|
||||
//
|
||||
// 新实现改用 Redis 有序集合(Sorted Set):
|
||||
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
|
||||
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
|
||||
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
|
||||
// 4. 单次 Redis 调用完成计数,减少网络往返
|
||||
const (
|
||||
// 并发槽位键前缀(有序集合)
|
||||
// 格式: concurrency:account:{accountID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||
accountWaitKeyPrefix = "wait:account:"
|
||||
|
||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||
defaultSlotTTLMinutes = 15
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL(秒)
|
||||
// ARGV[3] = requestID
|
||||
acquireScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local requestID = ARGV[3]
|
||||
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- 清理过期槽位
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, requestID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到并发上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxConcurrency then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||||
// 使用 Redis TIME 命令获取服务器时间
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
getCountScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
-- 使用 Redis 服务器时间
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementAccountWaitScript - account-level wait queue count
|
||||
incrementAccountWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||
getAccountsLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local accountID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:account:' .. accountID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'wait:account:' .. accountID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, accountID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
slotTTLSeconds int // 槽位过期时间(秒)
|
||||
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
||||
}
|
||||
|
||||
// NewConcurrencyCache 创建并发控制缓存
|
||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
||||
if slotTTLMinutes <= 0 {
|
||||
slotTTLMinutes = defaultSlotTTLMinutes
|
||||
}
|
||||
if waitQueueTTLSeconds <= 0 {
|
||||
waitQueueTTLSeconds = slotTTLMinutes * 60
|
||||
}
|
||||
return &concurrencyCache{
|
||||
rdb: rdb,
|
||||
slotTTLSeconds: slotTTLMinutes * 60,
|
||||
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func accountWaitKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Account wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
key := accountWaitKey(accountID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return 0, err
|
||||
}
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, acc := range accounts {
|
||||
args = append(args, acc.ID, acc.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.AccountLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
}
|
||||
|
||||
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[accountID] = &service.AccountLoadInfo{
|
||||
AccountID: accountID,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,135 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 基准测试用 TTL 配置
|
||||
const benchSlotTTLMinutes = 15
|
||||
|
||||
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
|
||||
|
||||
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
|
||||
func BenchmarkAccountConcurrency(b *testing.B) {
|
||||
rdb := newBenchmarkRedisClient(b)
|
||||
defer func() {
|
||||
_ = rdb.Close()
|
||||
}()
|
||||
|
||||
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, size := range []int{10, 100, 1000} {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
key := accountSlotKey(accountID)
|
||||
|
||||
b.StopTimer()
|
||||
members := make([]redis.Z, 0, size)
|
||||
now := float64(time.Now().Unix())
|
||||
for i := 0; i < size; i++ {
|
||||
members = append(members, redis.Z{
|
||||
Score: now,
|
||||
Member: fmt.Sprintf("req_%d", i),
|
||||
})
|
||||
}
|
||||
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
|
||||
b.Fatalf("初始化有序集合失败: %v", err)
|
||||
}
|
||||
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
|
||||
b.Fatalf("设置有序集合 TTL 失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
|
||||
b.Fatalf("获取并发数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
b.Fatalf("清理有序集合失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
keys := make([]string, 0, size)
|
||||
|
||||
b.StopTimer()
|
||||
pipe := rdb.Pipeline()
|
||||
for i := 0; i < size; i++ {
|
||||
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
|
||||
keys = append(keys, key)
|
||||
pipe.Set(ctx, key, "1", benchSlotTTL)
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
b.Fatalf("初始化扫描键失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
|
||||
b.Fatalf("SCAN 计数失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
b.Fatalf("清理扫描键失败: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 基准测试用 TTL 配置
|
||||
const benchSlotTTLMinutes = 15
|
||||
|
||||
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
|
||||
|
||||
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
|
||||
func BenchmarkAccountConcurrency(b *testing.B) {
|
||||
rdb := newBenchmarkRedisClient(b)
|
||||
defer func() {
|
||||
_ = rdb.Close()
|
||||
}()
|
||||
|
||||
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, size := range []int{10, 100, 1000} {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
key := accountSlotKey(accountID)
|
||||
|
||||
b.StopTimer()
|
||||
members := make([]redis.Z, 0, size)
|
||||
now := float64(time.Now().Unix())
|
||||
for i := 0; i < size; i++ {
|
||||
members = append(members, redis.Z{
|
||||
Score: now,
|
||||
Member: fmt.Sprintf("req_%d", i),
|
||||
})
|
||||
}
|
||||
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
|
||||
b.Fatalf("初始化有序集合失败: %v", err)
|
||||
}
|
||||
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
|
||||
b.Fatalf("设置有序集合 TTL 失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
|
||||
b.Fatalf("获取并发数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
b.Fatalf("清理有序集合失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
keys := make([]string, 0, size)
|
||||
|
||||
b.StopTimer()
|
||||
pipe := rdb.Pipeline()
|
||||
for i := 0; i < size; i++ {
|
||||
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
|
||||
keys = append(keys, key)
|
||||
pipe.Set(ctx, key, "1", benchSlotTTL)
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
b.Fatalf("初始化扫描键失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
|
||||
b.Fatalf("SCAN 计数失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
b.Fatalf("清理扫描键失败: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
@@ -1,412 +1,412 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
accountID := int64(30)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
|
||||
require.False(s.T(), ok, "expected account wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL account waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
|
||||
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
|
||||
// Setup: Create accounts with different load states
|
||||
account1 := int64(100)
|
||||
account2 := int64(101)
|
||||
account3 := int64(102)
|
||||
|
||||
// Account 1: 2/3 slots used, 1 waiting
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 2: 1/2 slots used, 0 waiting
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 3: 0/1 slots used, 0 waiting (idle)
|
||||
|
||||
// Query batch load
|
||||
accounts := []service.AccountWithConcurrency{
|
||||
{ID: account1, MaxConcurrency: 3},
|
||||
{ID: account2, MaxConcurrency: 2},
|
||||
{ID: account3, MaxConcurrency: 1},
|
||||
}
|
||||
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), loadMap, 3)
|
||||
|
||||
// Verify account1: (2 + 1) / 3 = 100%
|
||||
load1 := loadMap[account1]
|
||||
require.NotNil(s.T(), load1)
|
||||
require.Equal(s.T(), account1, load1.AccountID)
|
||||
require.Equal(s.T(), 2, load1.CurrentConcurrency)
|
||||
require.Equal(s.T(), 1, load1.WaitingCount)
|
||||
require.Equal(s.T(), 100, load1.LoadRate)
|
||||
|
||||
// Verify account2: (1 + 0) / 2 = 50%
|
||||
load2 := loadMap[account2]
|
||||
require.NotNil(s.T(), load2)
|
||||
require.Equal(s.T(), account2, load2.AccountID)
|
||||
require.Equal(s.T(), 1, load2.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load2.WaitingCount)
|
||||
require.Equal(s.T(), 50, load2.LoadRate)
|
||||
|
||||
// Verify account3: (0 + 0) / 1 = 0%
|
||||
load3 := loadMap[account3]
|
||||
require.NotNil(s.T(), load3)
|
||||
require.Equal(s.T(), account3, load3.AccountID)
|
||||
require.Equal(s.T(), 0, load3.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load3.WaitingCount)
|
||||
require.Equal(s.T(), 0, load3.LoadRate)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||
// Test with empty account list
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||
require.NoError(s.T(), err)
|
||||
require.Empty(s.T(), loadMap)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
|
||||
accountID := int64(200)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
// Acquire 3 slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Verify 3 slots exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, cur)
|
||||
|
||||
// Manually set old timestamps for req1 and req2 (simulate expired slots)
|
||||
now := time.Now().Unix()
|
||||
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Run cleanup
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify only 1 slot remains (req3)
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur)
|
||||
|
||||
// Verify req3 still exists
|
||||
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), members, 1)
|
||||
require.Equal(s.T(), "req3", members[0])
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
accountID := int64(201)
|
||||
|
||||
// Acquire 2 fresh slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Run cleanup (should not remove anything)
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify both slots still exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
accountID := int64(30)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
|
||||
require.False(s.T(), ok, "expected account wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL account waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
|
||||
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
|
||||
// Setup: Create accounts with different load states
|
||||
account1 := int64(100)
|
||||
account2 := int64(101)
|
||||
account3 := int64(102)
|
||||
|
||||
// Account 1: 2/3 slots used, 1 waiting
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 2: 1/2 slots used, 0 waiting
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 3: 0/1 slots used, 0 waiting (idle)
|
||||
|
||||
// Query batch load
|
||||
accounts := []service.AccountWithConcurrency{
|
||||
{ID: account1, MaxConcurrency: 3},
|
||||
{ID: account2, MaxConcurrency: 2},
|
||||
{ID: account3, MaxConcurrency: 1},
|
||||
}
|
||||
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), loadMap, 3)
|
||||
|
||||
// Verify account1: (2 + 1) / 3 = 100%
|
||||
load1 := loadMap[account1]
|
||||
require.NotNil(s.T(), load1)
|
||||
require.Equal(s.T(), account1, load1.AccountID)
|
||||
require.Equal(s.T(), 2, load1.CurrentConcurrency)
|
||||
require.Equal(s.T(), 1, load1.WaitingCount)
|
||||
require.Equal(s.T(), 100, load1.LoadRate)
|
||||
|
||||
// Verify account2: (1 + 0) / 2 = 50%
|
||||
load2 := loadMap[account2]
|
||||
require.NotNil(s.T(), load2)
|
||||
require.Equal(s.T(), account2, load2.AccountID)
|
||||
require.Equal(s.T(), 1, load2.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load2.WaitingCount)
|
||||
require.Equal(s.T(), 50, load2.LoadRate)
|
||||
|
||||
// Verify account3: (0 + 0) / 1 = 0%
|
||||
load3 := loadMap[account3]
|
||||
require.NotNil(s.T(), load3)
|
||||
require.Equal(s.T(), account3, load3.AccountID)
|
||||
require.Equal(s.T(), 0, load3.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load3.WaitingCount)
|
||||
require.Equal(s.T(), 0, load3.LoadRate)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||
// Test with empty account list
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||
require.NoError(s.T(), err)
|
||||
require.Empty(s.T(), loadMap)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
|
||||
accountID := int64(200)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
// Acquire 3 slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Verify 3 slots exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, cur)
|
||||
|
||||
// Manually set old timestamps for req1 and req2 (simulate expired slots)
|
||||
now := time.Now().Unix()
|
||||
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Run cleanup
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify only 1 slot remains (req3)
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur)
|
||||
|
||||
// Verify req3 still exists
|
||||
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), members, 1)
|
||||
require.Equal(s.T(), "req3", members[0])
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
accountID := int64(201)
|
||||
|
||||
// Acquire 2 fresh slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Run cleanup (should not remove anything)
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify both slots still exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
type dbPoolSettings struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
type dbPoolSettings struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
@@ -1,50 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
|
||||
@@ -1,52 +1,52 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -1,92 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
|
||||
@@ -1,45 +1,45 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user