diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..3db3b83d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,15 @@ +# 确保所有 SQL 迁移文件使用 LF 换行符 +backend/migrations/*.sql text eol=lf + +# Go 源代码文件 +*.go text eol=lf + +# Shell 脚本 +*.sh text eol=lf + +# YAML/YML 配置文件 +*.yaml text eol=lf +*.yml text eol=lf + +# Dockerfile +Dockerfile text eol=lf diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index e5624f86..2596a18c 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -19,7 +19,7 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.25.6' + go version | grep -q 'go1.25.7' - name: Unit tests working-directory: backend run: make test-unit @@ -38,7 +38,7 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.25.6' + go version | grep -q 'go1.25.7' - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f45c1a0b..50bb73e0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,7 +115,7 @@ jobs: - name: Verify Go version run: | - go version | grep -q 'go1.25.6' + go version | grep -q 'go1.25.7' # Docker setup for GoReleaser - name: Set up QEMU diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index dfb8e37e..05dd1d1a 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -22,7 +22,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.25.6' + go version | grep -q 'go1.25.7' - name: Run govulncheck working-directory: backend run: | diff --git a/Dockerfile b/Dockerfile index 3d4b5094..c9fcf301 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.6-alpine +ARG GOLANG_IMAGE=golang:1.25.7-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/Linux DO Connect.md b/Linux DO Connect.md deleted file mode 100644 index 7ca1260f..00000000 --- a/Linux DO Connect.md +++ /dev/null @@ -1,368 +0,0 @@ -# Linux DO Connect - -OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 - -目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 - -## 基本介绍 - -这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 - -- 可获取字段: - -| 参数 | 说明 | -| ----------------- | ------------------------------- | -| `id` | 用户唯一标识(不可变) | -| `username` | 论坛用户名 | -| `name` | 论坛用户昵称(可变) | -| `avatar_template` | 用户头像模板URL(支持多种尺寸) | -| `active` | 账号活跃状态 | -| `trust_level` | 信任等级(0-4) | -| `silenced` | 禁言状态 | -| `external_ids` | 外部ID关联信息 | -| `api_key` | API访问密钥 | - -通过这些信息,公益网站/接口可以实现: - -1. 基于 `id` 的服务频率限制 -2. 基于 `trust_level` 的服务额度分配 -3. 基于用户信息的滥用举报机制 - -## 相关端点 - -- Authorize 端点: `https://connect.linux.do/oauth2/authorize` -- Token 端点:`https://connect.linux.do/oauth2/token` -- 用户信息 端点:`https://connect.linux.do/api/user` - -## 申请使用 - -- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 - -![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75) - -- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 - -![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75) - -- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 - -![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75) - -## 接入 Linux Do - -JavaScript -```JavaScript -// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios -// npm install axios - -// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 -const axios = require('axios'); -const readline = require('readline'); - -// 配置信息(建议通过环境变量配置,避免使用硬编码) -const CLIENT_ID = '你的 Client ID'; -const CLIENT_SECRET = '你的 Client Secret'; -const REDIRECT_URI = '你的回调地址'; -const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; -const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; -const USER_INFO_URL = 'https://connect.linux.do/api/user'; - -// 第一步:生成授权 URL -function getAuthUrl() { - const params = new URLSearchParams({ - client_id: CLIENT_ID, - redirect_uri: REDIRECT_URI, - response_type: 'code', - scope: 'user' - }); - - return `${AUTH_URL}?${params.toString()}`; -} - -// 第二步:获取 code 参数 -function getCode() { - return new Promise((resolve) => { - // 本例中使用终端输入来模拟流程,仅供本地测试 - // 请在实际应用中替换为真实的处理逻辑 - const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); - rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { - rl.close(); - resolve(answer.trim()); - }); - }); -} - -// 第三步:使用 code 参数获取访问令牌 -async function getAccessToken(code) { - try { - const form = new URLSearchParams({ - client_id: CLIENT_ID, - client_secret: CLIENT_SECRET, - code: code, - redirect_uri: REDIRECT_URI, - grant_type: 'authorization_code' - }).toString(); - - const response = await axios.post(TOKEN_URL, form, { - // 提醒:需正确配置请求头,否则无法正常获取访问令牌 - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - 'Accept': 'application/json' - } - }); - - return response.data; - } catch (error) { - console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); - throw error; - } -} - -// 第四步:使用访问令牌获取用户信息 -async function getUserInfo(accessToken) { - try { - const response = await axios.get(USER_INFO_URL, { - headers: { - Authorization: `Bearer ${accessToken}` - } - }); - - return response.data; - } catch (error) { - console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); - throw error; - } -} - -// 主流程 -async function main() { - // 1. 生成授权 URL,前端引导用户访问授权页 - const authUrl = getAuthUrl(); - console.log(`请访问此 URL 授权:${authUrl} -`); - - // 2. 用户授权后,从回调 URL 获取 code 参数 - const code = await getCode(); - - try { - // 3. 使用 code 参数获取访问令牌 - const tokenData = await getAccessToken(code); - const accessToken = tokenData.access_token; - - // 4. 使用访问令牌获取用户信息 - if (accessToken) { - const userInfo = await getUserInfo(accessToken); - console.log(` -获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); - } else { - console.log(` -获取访问令牌失败:${JSON.stringify(tokenData)}`); - } - } catch (error) { - console.error('发生错误:', error); - } -} -``` -Python -```python -# 安装第三方请求库,本例中使用 requests -# pip install requests - -# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 -import requests -import json - -# 配置信息(建议通过环境变量配置,避免使用硬编码) -CLIENT_ID = '你的 Client ID' -CLIENT_SECRET = '你的 Client Secret' -REDIRECT_URI = '你的回调地址' -AUTH_URL = 'https://connect.linux.do/oauth2/authorize' -TOKEN_URL = 'https://connect.linux.do/oauth2/token' -USER_INFO_URL = 'https://connect.linux.do/api/user' - -# 第一步:生成授权 URL -def get_auth_url(): - params = { - 'client_id': CLIENT_ID, - 'redirect_uri': REDIRECT_URI, - 'response_type': 'code', - 'scope': 'user' - } - auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" - return auth_url - -# 第二步:获取 code 参数 -def get_code(): - # 本例中使用终端输入来模拟流程,仅供本地测试 - # 请在实际应用中替换为真实的处理逻辑 - return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() - -# 第三步:使用 code 参数获取访问令牌 -def get_access_token(code): - try: - data = { - 'client_id': CLIENT_ID, - 'client_secret': CLIENT_SECRET, - 'code': code, - 'redirect_uri': REDIRECT_URI, - 'grant_type': 'authorization_code' - } - # 提醒:需正确配置请求头,否则无法正常获取访问令牌 - headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'Accept': 'application/json' - } - response = requests.post(TOKEN_URL, data=data, headers=headers) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"获取访问令牌失败:{e}") - return None - -# 第四步:使用访问令牌获取用户信息 -def get_user_info(access_token): - try: - headers = { - 'Authorization': f'Bearer {access_token}' - } - response = requests.get(USER_INFO_URL, headers=headers) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"获取用户信息失败:{e}") - return None - -# 主流程 -if __name__ == '__main__': - # 1. 生成授权 URL,前端引导用户访问授权页 - auth_url = get_auth_url() - print(f'请访问此 URL 授权:{auth_url} -') - - # 2. 用户授权后,从回调 URL 获取 code 参数 - code = get_code() - - # 3. 使用 code 参数获取访问令牌 - token_data = get_access_token(code) - if token_data: - access_token = token_data.get('access_token') - - # 4. 使用访问令牌获取用户信息 - if access_token: - user_info = get_user_info(access_token) - if user_info: - print(f" -获取用户信息成功:{json.dumps(user_info, indent=2)}") - else: - print(" -获取用户信息失败") - else: - print(f" -获取访问令牌失败:{json.dumps(token_data, indent=2)}") - else: - print(" -获取访问令牌失败") -``` -PHP -```php -// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 - -// 配置信息 -$CLIENT_ID = '你的 Client ID'; -$CLIENT_SECRET = '你的 Client Secret'; -$REDIRECT_URI = '你的回调地址'; -$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; -$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; -$USER_INFO_URL = 'https://connect.linux.do/api/user'; - -// 生成授权 URL -function getAuthUrl($clientId, $redirectUri) { - global $AUTH_URL; - return $AUTH_URL . '?' . http_build_query([ - 'client_id' => $clientId, - 'redirect_uri' => $redirectUri, - 'response_type' => 'code', - 'scope' => 'user' - ]); -} - -// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) -function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { - global $TOKEN_URL, $USER_INFO_URL; - - // 1. 获取访问令牌 - $ch = curl_init($TOKEN_URL); - curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); - curl_setopt($ch, CURLOPT_POST, true); - curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ - 'client_id' => $clientId, - 'client_secret' => $clientSecret, - 'code' => $code, - 'redirect_uri' => $redirectUri, - 'grant_type' => 'authorization_code' - ])); - curl_setopt($ch, CURLOPT_HTTPHEADER, [ - 'Content-Type: application/x-www-form-urlencoded', - 'Accept: application/json' - ]); - - $tokenResponse = curl_exec($ch); - curl_close($ch); - - $tokenData = json_decode($tokenResponse, true); - if (!isset($tokenData['access_token'])) { - return ['error' => '获取访问令牌失败', 'details' => $tokenData]; - } - - // 2. 获取用户信息 - $ch = curl_init($USER_INFO_URL); - curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); - curl_setopt($ch, CURLOPT_HTTPHEADER, [ - 'Authorization: Bearer ' . $tokenData['access_token'] - ]); - - $userResponse = curl_exec($ch); - curl_close($ch); - - return json_decode($userResponse, true); -} - -// 主流程 -// 1. 生成授权 URL -$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); -echo "使用 Linux Do 登录"; - -// 2. 处理回调并获取用户信息 -if (isset($_GET['code'])) { - $userInfo = getUserInfoWithCode( - $_GET['code'], - $CLIENT_ID, - $CLIENT_SECRET, - $REDIRECT_URI - ); - - if (isset($userInfo['error'])) { - echo '错误: ' . $userInfo['error']; - } else { - echo '欢迎, ' . $userInfo['name'] . '!'; - // 处理用户登录逻辑... - } -} -``` - -## 使用说明 - -### 授权流程 - -1. 用户点击应用中的’使用 Linux Do 登录’按钮 -2. 系统将用户重定向至 Linux Do 的授权页面 -3. 用户完成授权后,系统自动重定向回应用并携带授权码 -4. 应用使用授权码获取访问令牌 -5. 使用访问令牌获取用户信息 - -### 安全建议 - -- 切勿在前端代码中暴露 Client Secret -- 对所有用户输入数据进行严格验证 -- 确保使用 HTTPS 协议传输数据 -- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md deleted file mode 100644 index b240f45c..00000000 --- a/PR_DESCRIPTION.md +++ /dev/null @@ -1,164 +0,0 @@ -## 概述 - -全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。 - -## 主要改动 - -### 1. 错误日志查询优化 - -**功能特性:** -- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情 -- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等) -- 改进查询参数处理,简化代码结构 -- 增强错误分类和标准化处理 -- 支持错误解决状态追踪(resolved 字段) - -**技术实现:** -- `ops_handler.go` - 新增单条错误日志查询接口 -- `ops_repo.go` - 优化数据查询和过滤条件构建 -- `ops_models.go` - 扩展错误日志数据模型 -- 前端 API 接口同步更新 - -### 2. 告警静默功能 - -**功能特性:** -- 支持按规则、平台、分组、区域等维度静默告警 -- 可设置静默时长和原因说明 -- 静默记录可追溯,记录创建人和创建时间 -- 自动过期机制,避免永久静默 - -**技术实现:** -- `037_ops_alert_silences.sql` - 新增告警静默表 -- `ops_alerts.go` - 告警静默逻辑实现 -- `ops_alerts_handler.go` - 告警静默 API 接口 -- `OpsAlertEventsCard.vue` - 前端告警静默操作界面 - -**数据库结构:** - -| 字段 | 类型 | 说明 | -|------|------|------| -| rule_id | BIGINT | 告警规则 ID | -| platform | VARCHAR(64) | 平台标识 | -| group_id | BIGINT | 分组 ID(可选) | -| region | VARCHAR(64) | 区域(可选) | -| until | TIMESTAMPTZ | 静默截止时间 | -| reason | TEXT | 静默原因 | -| created_by | BIGINT | 创建人 ID | - -### 3. 错误分类标准化 - -**功能特性:** -- 统一错误阶段分类(request|auth|routing|upstream|network|internal) -- 规范错误归属分类(client|provider|platform) -- 标准化错误来源分类(client_request|upstream_http|gateway) -- 自动迁移历史数据到新分类体系 - -**技术实现:** -- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移 -- 自动映射历史遗留分类到新标准 -- 自动解决已恢复的上游错误(客户端状态码 < 400) - -### 4. Gateway 服务集成 - -**功能特性:** -- 完善各 Gateway 服务的 Ops 集成 -- 统一错误日志记录接口 -- 增强上游错误追踪能力 - -**涉及服务:** -- `antigravity_gateway_service.go` - Antigravity 网关集成 -- `gateway_service.go` - 通用网关集成 -- `gemini_messages_compat_service.go` - Gemini 兼容层集成 -- `openai_gateway_service.go` - OpenAI 网关集成 - -### 5. 前端 UI 优化 - -**代码重构:** -- 大幅简化错误详情模态框代码(从 828 行优化到 450 行) -- 优化错误日志表格组件,提升可读性 -- 清理未使用的 i18n 翻译,减少冗余 -- 统一组件代码风格和格式 -- 优化骨架屏组件,更好匹配实际看板布局 - -**布局改进:** -- 修复模态框内容溢出和滚动问题 -- 优化表格布局,使用 flex 布局确保正确显示 -- 改进看板头部布局和交互 -- 提升响应式体验 -- 骨架屏支持全屏模式适配 - -**交互优化:** -- 优化告警事件卡片功能和展示 -- 改进错误详情展示逻辑 -- 增强请求详情模态框 -- 完善运行时设置卡片 -- 改进加载动画效果 - -### 6. 国际化完善 - -**文案补充:** -- 补充错误日志相关的英文翻译 -- 添加告警静默功能的中英文文案 -- 完善提示文本和错误信息 -- 统一术语翻译标准 - -## 文件变更 - -**后端(26 个文件):** -- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强 -- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化 -- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强 -- `backend/internal/repository/ops_repo.go` - 数据访问层重构 -- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强 -- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件) -- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件) -- `backend/internal/server/routes/admin.go` - 路由配置更新 -- `backend/migrations/*.sql` - 数据库迁移(2 个文件) -- 测试文件更新(5 个文件) - -**前端(13 个文件):** -- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化 -- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件) -- `frontend/src/api/admin/ops.ts` - API 接口扩展 -- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件) - -## 代码统计 - -- 44 个文件修改 -- 3733 行新增 -- 995 行删除 -- 净增加 2738 行 - -## 核心改进 - -**可维护性提升:** -- 重构核心服务层,职责更清晰 -- 简化前端组件代码,降低复杂度 -- 统一代码风格和命名规范 -- 清理冗余代码和未使用的翻译 -- 标准化错误分类体系 - -**功能完善:** -- 告警静默功能,减少告警噪音 -- 错误日志查询优化,提升运维效率 -- Gateway 服务集成完善,统一监控能力 -- 错误解决状态追踪,便于问题管理 - -**用户体验优化:** -- 修复多个 UI 布局问题 -- 优化交互流程 -- 完善国际化支持 -- 提升响应式体验 -- 改进加载状态展示 - -## 测试验证 - -- ✅ 错误日志查询和过滤功能 -- ✅ 告警静默创建和自动过期 -- ✅ 错误分类标准化迁移 -- ✅ Gateway 服务错误日志记录 -- ✅ 前端组件布局和交互 -- ✅ 骨架屏全屏模式适配 -- ✅ 国际化文本完整性 -- ✅ API 接口功能正确性 -- ✅ 数据库迁移执行成功 diff --git a/README.md b/README.md index 14656332..36949b0a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@
-[![Go](https://img.shields.io/badge/Go-1.25.5-00ADD8.svg)](https://golang.org/) +[![Go](https://img.shields.io/badge/Go-1.25.7-00ADD8.svg)](https://golang.org/) [![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/) [![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/) [![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) @@ -44,7 +44,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot | Component | Technology | |-----------|------------| -| Backend | Go 1.25.5, Gin, Ent | +| Backend | Go 1.25.7, Gin, Ent | | Frontend | Vue 3.4+, Vite 5+, TailwindCSS | | Database | PostgreSQL 15+ | | Cache/Queue | Redis 7+ | diff --git a/README_CN.md b/README_CN.md index e609f25d..1e0d1d62 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,7 +2,7 @@
-[![Go](https://img.shields.io/badge/Go-1.25.5-00ADD8.svg)](https://golang.org/) +[![Go](https://img.shields.io/badge/Go-1.25.7-00ADD8.svg)](https://golang.org/) [![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/) [![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/) [![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) @@ -44,7 +44,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( | 组件 | 技术 | |------|------| -| 后端 | Go 1.25.5, Gin, Ent | +| 后端 | Go 1.25.7, Gin, Ent | | 前端 | Vue 3.4+, Vite 5+, TailwindCSS | | 数据库 | PostgreSQL 15+ | | 缓存/队列 | Redis 7+ | diff --git a/backend/Dockerfile b/backend/Dockerfile index 770fdedf..aeb20fdb 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.5-alpine +FROM golang:1.25.7-alpine WORKDIR /app @@ -15,7 +15,7 @@ RUN go mod download COPY . . # 构建应用 -RUN go build -o main cmd/server/main.go +RUN go build -o main ./cmd/server/ # 暴露端口 EXPOSE 8080 diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 139a3a39..ce4718bf 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index a2d633db..f0768f09 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.61 +0.1.70 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 694d05a7..ab1831d8 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } userRepository := repository.NewUserRepository(client, db) redeemCodeRepository := repository.NewRedeemCodeRepository(client) + redisClient := repository.ProvideRedis(configConfig) + refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := repository.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -58,11 +59,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) + userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) - apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) @@ -99,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -125,7 +127,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) + schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) + schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) @@ -141,8 +145,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminRedeemHandler := admin.NewRedeemHandler(adminService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) - schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) - schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) if err != nil { @@ -152,11 +154,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) - opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) + opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) @@ -172,9 +174,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + errorPassthroughRepository := repository.NewErrorPassthroughRepository(client) + errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) + errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) + errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 95586017..91d71964 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -40,6 +40,12 @@ type APIKey struct { IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs IPBlacklist []string `json:"ip_blacklist,omitempty"` + // Quota limit in USD for this API key (0 = unlimited) + Quota float64 `json:"quota,omitempty"` + // Used quota amount in USD + QuotaUsed float64 `json:"quota_used,omitempty"` + // Expiration time for this API key (null = never expires) + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -97,11 +103,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) + case apikey.FieldQuota, apikey.FieldQuotaUsed: + values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -190,6 +198,25 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field ip_blacklist: %w", err) } } + case apikey.FieldQuota: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota", values[i]) + } else if value.Valid { + _m.Quota = value.Float64 + } + case apikey.FieldQuotaUsed: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota_used", values[i]) + } else if value.Valid { + _m.QuotaUsed = value.Float64 + } + case apikey.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = new(time.Time) + *_m.ExpiresAt = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -274,6 +301,17 @@ func (_m *APIKey) String() string { builder.WriteString(", ") builder.WriteString("ip_blacklist=") builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist)) + builder.WriteString(", ") + builder.WriteString("quota=") + builder.WriteString(fmt.Sprintf("%v", _m.Quota)) + builder.WriteString(", ") + builder.WriteString("quota_used=") + builder.WriteString(fmt.Sprintf("%v", _m.QuotaUsed)) + builder.WriteString(", ") + if v := _m.ExpiresAt; v != nil { + builder.WriteString("expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 564cddb1..ac2a6008 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -35,6 +35,12 @@ const ( FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. FieldIPBlacklist = "ip_blacklist" + // FieldQuota holds the string denoting the quota field in the database. + FieldQuota = "quota" + // FieldQuotaUsed holds the string denoting the quota_used field in the database. + FieldQuotaUsed = "quota_used" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -79,6 +85,9 @@ var Columns = []string{ FieldStatus, FieldIPWhitelist, FieldIPBlacklist, + FieldQuota, + FieldQuotaUsed, + FieldExpiresAt, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -113,6 +122,10 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultQuota holds the default value on creation for the "quota" field. + DefaultQuota float64 + // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. + DefaultQuotaUsed float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -163,6 +176,21 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByQuota orders the results by the quota field. +func ByQuota(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuota, opts...).ToFunc() +} + +// ByQuotaUsed orders the results by the quota_used field. +func ByQuotaUsed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuotaUsed, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + // ByUserField orders the results by user field. func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 5152867f..f54f44b7 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,21 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. +func Quota(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaUsed applies equality check predicate on the "quota_used" field. It's identical to QuotaUsedEQ. +func QuotaUsed(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -490,6 +505,136 @@ func IPBlacklistNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist)) } +// QuotaEQ applies the EQ predicate on the "quota" field. +func QuotaEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaNEQ applies the NEQ predicate on the "quota" field. +func QuotaNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuota, v)) +} + +// QuotaIn applies the In predicate on the "quota" field. +func QuotaIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuota, vs...)) +} + +// QuotaNotIn applies the NotIn predicate on the "quota" field. +func QuotaNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuota, vs...)) +} + +// QuotaGT applies the GT predicate on the "quota" field. +func QuotaGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuota, v)) +} + +// QuotaGTE applies the GTE predicate on the "quota" field. +func QuotaGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuota, v)) +} + +// QuotaLT applies the LT predicate on the "quota" field. +func QuotaLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuota, v)) +} + +// QuotaLTE applies the LTE predicate on the "quota" field. +func QuotaLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuota, v)) +} + +// QuotaUsedEQ applies the EQ predicate on the "quota_used" field. +func QuotaUsedEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedNEQ applies the NEQ predicate on the "quota_used" field. +func QuotaUsedNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedIn applies the In predicate on the "quota_used" field. +func QuotaUsedIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedNotIn applies the NotIn predicate on the "quota_used" field. +func QuotaUsedNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedGT applies the GT predicate on the "quota_used" field. +func QuotaUsedGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuotaUsed, v)) +} + +// QuotaUsedGTE applies the GTE predicate on the "quota_used" field. +func QuotaUsedGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuotaUsed, v)) +} + +// QuotaUsedLT applies the LT predicate on the "quota_used" field. +func QuotaUsedLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuotaUsed, v)) +} + +// QuotaUsedLTE applies the LTE predicate on the "quota_used" field. +func QuotaUsedLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuotaUsed, v)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index d5363be5..71540975 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -125,6 +125,48 @@ func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate { return _c } +// SetQuota sets the "quota" field. +func (_c *APIKeyCreate) SetQuota(v float64) *APIKeyCreate { + _c.mutation.SetQuota(v) + return _c +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuota(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuota(*v) + } + return _c +} + +// SetQuotaUsed sets the "quota_used" field. +func (_c *APIKeyCreate) SetQuotaUsed(v float64) *APIKeyCreate { + _c.mutation.SetQuotaUsed(v) + return _c +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuotaUsed(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuotaUsed(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *APIKeyCreate) SetExpiresAt(v time.Time) *APIKeyCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetExpiresAt(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -205,6 +247,14 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.Quota(); !ok { + v := apikey.DefaultQuota + _c.mutation.SetQuota(v) + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + v := apikey.DefaultQuotaUsed + _c.mutation.SetQuotaUsed(v) + } return nil } @@ -243,6 +293,12 @@ func (_c *APIKeyCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)} } } + if _, ok := _c.mutation.Quota(); !ok { + return &ValidationError{Name: "quota", err: errors.New(`ent: missing required field "APIKey.quota"`)} + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -305,6 +361,18 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) _node.IPBlacklist = value } + if value, ok := _c.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + _node.Quota = value + } + if value, ok := _c.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + _node.QuotaUsed = value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -539,6 +607,60 @@ func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert { return u } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsert) SetQuota(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuota, v) + return u +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuota() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuota) + return u +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsert) AddQuota(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuota, v) + return u +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsert) SetQuotaUsed(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuotaUsed, v) + return u +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuotaUsed() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuotaUsed) + return u +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsert) AddQuotaUsed(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuotaUsed, v) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsert) SetExpiresAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateExpiresAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldExpiresAt) + return u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { + u.SetNull(apikey.FieldExpiresAt) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -738,6 +860,69 @@ func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertOne) SetQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertOne) AddQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuota() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertOne) SetQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertOne) AddQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuotaUsed() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertOne) SetExpiresAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1103,6 +1288,69 @@ func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertBulk) SetQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertBulk) AddQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuota() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertBulk) SetQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertBulk) AddQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuotaUsed() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertBulk) SetExpiresAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 9ae332a8..b4ff230b 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -170,6 +170,68 @@ func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdate) SetQuota(v float64) *APIKeyUpdate { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuota(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdate) AddQuota(v float64) *APIKeyUpdate { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdate) SetQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuotaUsed(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdate) AddQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdate) SetExpiresAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableExpiresAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -350,6 +412,24 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -611,6 +691,68 @@ func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdateOne) SetQuota(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuota(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdateOne) AddQuota(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdateOne) SetQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuotaUsed(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdateOne) AddQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdateOne) SetExpiresAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableExpiresAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -821,6 +963,24 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/client.go b/backend/ent/client.go index a17721da..a791c081 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -52,6 +53,8 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -94,6 +97,7 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) @@ -204,6 +208,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -241,6 +246,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -284,9 +290,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Use(hooks...) } @@ -297,9 +304,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Intercept(interceptors...) } @@ -318,6 +326,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *ErrorPassthroughRuleMutation: + return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *PromoCodeMutation: @@ -1161,6 +1171,139 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. +type ErrorPassthroughRuleClient struct { + config +} + +// NewErrorPassthroughRuleClient returns a client for the ErrorPassthroughRule from the given config. +func NewErrorPassthroughRuleClient(c config) *ErrorPassthroughRuleClient { + return &ErrorPassthroughRuleClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `errorpassthroughrule.Hooks(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Use(hooks ...Hook) { + c.hooks.ErrorPassthroughRule = append(c.hooks.ErrorPassthroughRule, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `errorpassthroughrule.Intercept(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Intercept(interceptors ...Interceptor) { + c.inters.ErrorPassthroughRule = append(c.inters.ErrorPassthroughRule, interceptors...) +} + +// Create returns a builder for creating a ErrorPassthroughRule entity. +func (c *ErrorPassthroughRuleClient) Create() *ErrorPassthroughRuleCreate { + mutation := newErrorPassthroughRuleMutation(c.config, OpCreate) + return &ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ErrorPassthroughRule entities. +func (c *ErrorPassthroughRuleClient) CreateBulk(builders ...*ErrorPassthroughRuleCreate) *ErrorPassthroughRuleCreateBulk { + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ErrorPassthroughRuleClient) MapCreateBulk(slice any, setFunc func(*ErrorPassthroughRuleCreate, int)) *ErrorPassthroughRuleCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ErrorPassthroughRuleCreateBulk{err: fmt.Errorf("calling to ErrorPassthroughRuleClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ErrorPassthroughRuleCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Update() *ErrorPassthroughRuleUpdate { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdate) + return &ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ErrorPassthroughRuleClient) UpdateOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRule(_m)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ErrorPassthroughRuleClient) UpdateOneID(id int64) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRuleID(id)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Delete() *ErrorPassthroughRuleDelete { + mutation := newErrorPassthroughRuleMutation(c.config, OpDelete) + return &ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ErrorPassthroughRuleClient) DeleteOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ErrorPassthroughRuleClient) DeleteOneID(id int64) *ErrorPassthroughRuleDeleteOne { + builder := c.Delete().Where(errorpassthroughrule.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ErrorPassthroughRuleDeleteOne{builder} +} + +// Query returns a query builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Query() *ErrorPassthroughRuleQuery { + return &ErrorPassthroughRuleQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeErrorPassthroughRule}, + inters: c.Interceptors(), + } +} + +// Get returns a ErrorPassthroughRule entity by its id. +func (c *ErrorPassthroughRuleClient) Get(ctx context.Context, id int64) (*ErrorPassthroughRule, error) { + return c.Query().Where(errorpassthroughrule.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ErrorPassthroughRuleClient) GetX(ctx context.Context, id int64) *ErrorPassthroughRule { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *ErrorPassthroughRuleClient) Hooks() []Hook { + return c.hooks.ErrorPassthroughRule +} + +// Interceptors returns the client interceptors. +func (c *ErrorPassthroughRuleClient) Interceptors() []Interceptor { + return c.inters.ErrorPassthroughRule +} + +func (c *ErrorPassthroughRuleClient) mutate(ctx context.Context, m *ErrorPassthroughRuleMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ErrorPassthroughRule mutation op: %q", m.Op()) + } +} + // GroupClient is a client for the Group schema. type GroupClient struct { config @@ -3462,16 +3605,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Hook + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Interceptor + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 05e30ba7..5767a167 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -95,6 +96,7 @@ func checkColumn(t, c string) error { accountgroup.Table: accountgroup.ValidColumn, announcement.Table: announcement.ValidColumn, announcementread.Table: announcementread.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, diff --git a/backend/ent/errorpassthroughrule.go b/backend/ent/errorpassthroughrule.go new file mode 100644 index 00000000..1932f626 --- /dev/null +++ b/backend/ent/errorpassthroughrule.go @@ -0,0 +1,269 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRule is the model entity for the ErrorPassthroughRule schema. +type ErrorPassthroughRule struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Enabled holds the value of the "enabled" field. + Enabled bool `json:"enabled,omitempty"` + // Priority holds the value of the "priority" field. + Priority int `json:"priority,omitempty"` + // ErrorCodes holds the value of the "error_codes" field. + ErrorCodes []int `json:"error_codes,omitempty"` + // Keywords holds the value of the "keywords" field. + Keywords []string `json:"keywords,omitempty"` + // MatchMode holds the value of the "match_mode" field. + MatchMode string `json:"match_mode,omitempty"` + // Platforms holds the value of the "platforms" field. + Platforms []string `json:"platforms,omitempty"` + // PassthroughCode holds the value of the "passthrough_code" field. + PassthroughCode bool `json:"passthrough_code,omitempty"` + // ResponseCode holds the value of the "response_code" field. + ResponseCode *int `json:"response_code,omitempty"` + // PassthroughBody holds the value of the "passthrough_body" field. + PassthroughBody bool `json:"passthrough_body,omitempty"` + // CustomMessage holds the value of the "custom_message" field. + CustomMessage *string `json:"custom_message,omitempty"` + // Description holds the value of the "description" field. + Description *string `json:"description,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: + values[i] = new([]byte) + case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: + values[i] = new(sql.NullBool) + case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: + values[i] = new(sql.NullInt64) + case errorpassthroughrule.FieldName, errorpassthroughrule.FieldMatchMode, errorpassthroughrule.FieldCustomMessage, errorpassthroughrule.FieldDescription: + values[i] = new(sql.NullString) + case errorpassthroughrule.FieldCreatedAt, errorpassthroughrule.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ErrorPassthroughRule fields. +func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case errorpassthroughrule.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case errorpassthroughrule.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case errorpassthroughrule.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case errorpassthroughrule.FieldEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field enabled", values[i]) + } else if value.Valid { + _m.Enabled = value.Bool + } + case errorpassthroughrule.FieldPriority: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field priority", values[i]) + } else if value.Valid { + _m.Priority = int(value.Int64) + } + case errorpassthroughrule.FieldErrorCodes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field error_codes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ErrorCodes); err != nil { + return fmt.Errorf("unmarshal field error_codes: %w", err) + } + } + case errorpassthroughrule.FieldKeywords: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field keywords", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Keywords); err != nil { + return fmt.Errorf("unmarshal field keywords: %w", err) + } + } + case errorpassthroughrule.FieldMatchMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field match_mode", values[i]) + } else if value.Valid { + _m.MatchMode = value.String + } + case errorpassthroughrule.FieldPlatforms: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field platforms", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Platforms); err != nil { + return fmt.Errorf("unmarshal field platforms: %w", err) + } + } + case errorpassthroughrule.FieldPassthroughCode: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_code", values[i]) + } else if value.Valid { + _m.PassthroughCode = value.Bool + } + case errorpassthroughrule.FieldResponseCode: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_code", values[i]) + } else if value.Valid { + _m.ResponseCode = new(int) + *_m.ResponseCode = int(value.Int64) + } + case errorpassthroughrule.FieldPassthroughBody: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_body", values[i]) + } else if value.Valid { + _m.PassthroughBody = value.Bool + } + case errorpassthroughrule.FieldCustomMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field custom_message", values[i]) + } else if value.Valid { + _m.CustomMessage = new(string) + *_m.CustomMessage = value.String + } + case errorpassthroughrule.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + _m.Description = new(string) + *_m.Description = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ErrorPassthroughRule. +// This includes values selected through modifiers, order, etc. +func (_m *ErrorPassthroughRule) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this ErrorPassthroughRule. +// Note that you need to call ErrorPassthroughRule.Unwrap() before calling this method if this ErrorPassthroughRule +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ErrorPassthroughRule) Update() *ErrorPassthroughRuleUpdateOne { + return NewErrorPassthroughRuleClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ErrorPassthroughRule entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ErrorPassthroughRule) Unwrap() *ErrorPassthroughRule { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ErrorPassthroughRule is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ErrorPassthroughRule) String() string { + var builder strings.Builder + builder.WriteString("ErrorPassthroughRule(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + builder.WriteString("enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.Enabled)) + builder.WriteString(", ") + builder.WriteString("priority=") + builder.WriteString(fmt.Sprintf("%v", _m.Priority)) + builder.WriteString(", ") + builder.WriteString("error_codes=") + builder.WriteString(fmt.Sprintf("%v", _m.ErrorCodes)) + builder.WriteString(", ") + builder.WriteString("keywords=") + builder.WriteString(fmt.Sprintf("%v", _m.Keywords)) + builder.WriteString(", ") + builder.WriteString("match_mode=") + builder.WriteString(_m.MatchMode) + builder.WriteString(", ") + builder.WriteString("platforms=") + builder.WriteString(fmt.Sprintf("%v", _m.Platforms)) + builder.WriteString(", ") + builder.WriteString("passthrough_code=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughCode)) + builder.WriteString(", ") + if v := _m.ResponseCode; v != nil { + builder.WriteString("response_code=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("passthrough_body=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughBody)) + builder.WriteString(", ") + if v := _m.CustomMessage; v != nil { + builder.WriteString("custom_message=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Description; v != nil { + builder.WriteString("description=") + builder.WriteString(*v) + } + builder.WriteByte(')') + return builder.String() +} + +// ErrorPassthroughRules is a parsable slice of ErrorPassthroughRule. +type ErrorPassthroughRules []*ErrorPassthroughRule diff --git a/backend/ent/errorpassthroughrule/errorpassthroughrule.go b/backend/ent/errorpassthroughrule/errorpassthroughrule.go new file mode 100644 index 00000000..d7be4f03 --- /dev/null +++ b/backend/ent/errorpassthroughrule/errorpassthroughrule.go @@ -0,0 +1,161 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the errorpassthroughrule type in the database. + Label = "error_passthrough_rule" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldEnabled holds the string denoting the enabled field in the database. + FieldEnabled = "enabled" + // FieldPriority holds the string denoting the priority field in the database. + FieldPriority = "priority" + // FieldErrorCodes holds the string denoting the error_codes field in the database. + FieldErrorCodes = "error_codes" + // FieldKeywords holds the string denoting the keywords field in the database. + FieldKeywords = "keywords" + // FieldMatchMode holds the string denoting the match_mode field in the database. + FieldMatchMode = "match_mode" + // FieldPlatforms holds the string denoting the platforms field in the database. + FieldPlatforms = "platforms" + // FieldPassthroughCode holds the string denoting the passthrough_code field in the database. + FieldPassthroughCode = "passthrough_code" + // FieldResponseCode holds the string denoting the response_code field in the database. + FieldResponseCode = "response_code" + // FieldPassthroughBody holds the string denoting the passthrough_body field in the database. + FieldPassthroughBody = "passthrough_body" + // FieldCustomMessage holds the string denoting the custom_message field in the database. + FieldCustomMessage = "custom_message" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // Table holds the table name of the errorpassthroughrule in the database. + Table = "error_passthrough_rules" +) + +// Columns holds all SQL columns for errorpassthroughrule fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldEnabled, + FieldPriority, + FieldErrorCodes, + FieldKeywords, + FieldMatchMode, + FieldPlatforms, + FieldPassthroughCode, + FieldResponseCode, + FieldPassthroughBody, + FieldCustomMessage, + FieldDescription, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // DefaultEnabled holds the default value on creation for the "enabled" field. + DefaultEnabled bool + // DefaultPriority holds the default value on creation for the "priority" field. + DefaultPriority int + // DefaultMatchMode holds the default value on creation for the "match_mode" field. + DefaultMatchMode string + // MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + MatchModeValidator func(string) error + // DefaultPassthroughCode holds the default value on creation for the "passthrough_code" field. + DefaultPassthroughCode bool + // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. + DefaultPassthroughBody bool +) + +// OrderOption defines the ordering options for the ErrorPassthroughRule queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByEnabled orders the results by the enabled field. +func ByEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEnabled, opts...).ToFunc() +} + +// ByPriority orders the results by the priority field. +func ByPriority(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPriority, opts...).ToFunc() +} + +// ByMatchMode orders the results by the match_mode field. +func ByMatchMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMatchMode, opts...).ToFunc() +} + +// ByPassthroughCode orders the results by the passthrough_code field. +func ByPassthroughCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughCode, opts...).ToFunc() +} + +// ByResponseCode orders the results by the response_code field. +func ByResponseCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseCode, opts...).ToFunc() +} + +// ByPassthroughBody orders the results by the passthrough_body field. +func ByPassthroughBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughBody, opts...).ToFunc() +} + +// ByCustomMessage orders the results by the custom_message field. +func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} diff --git a/backend/ent/errorpassthroughrule/where.go b/backend/ent/errorpassthroughrule/where.go new file mode 100644 index 00000000..56839d52 --- /dev/null +++ b/backend/ent/errorpassthroughrule/where.go @@ -0,0 +1,635 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ. +func Enabled(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. +func Priority(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// MatchMode applies equality check predicate on the "match_mode" field. It's identical to MatchModeEQ. +func MatchMode(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// PassthroughCode applies equality check predicate on the "passthrough_code" field. It's identical to PassthroughCodeEQ. +func PassthroughCode(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// ResponseCode applies equality check predicate on the "response_code" field. It's identical to ResponseCodeEQ. +func ResponseCode(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// PassthroughBody applies equality check predicate on the "passthrough_body" field. It's identical to PassthroughBodyEQ. +func PassthroughBody(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// CustomMessage applies equality check predicate on the "custom_message" field. It's identical to CustomMessageEQ. +func CustomMessage(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldName, v)) +} + +// EnabledEQ applies the EQ predicate on the "enabled" field. +func EnabledEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// EnabledNEQ applies the NEQ predicate on the "enabled" field. +func EnabledNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldEnabled, v)) +} + +// PriorityEQ applies the EQ predicate on the "priority" field. +func PriorityEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// PriorityNEQ applies the NEQ predicate on the "priority" field. +func PriorityNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPriority, v)) +} + +// PriorityIn applies the In predicate on the "priority" field. +func PriorityIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldPriority, vs...)) +} + +// PriorityNotIn applies the NotIn predicate on the "priority" field. +func PriorityNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldPriority, vs...)) +} + +// PriorityGT applies the GT predicate on the "priority" field. +func PriorityGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldPriority, v)) +} + +// PriorityGTE applies the GTE predicate on the "priority" field. +func PriorityGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldPriority, v)) +} + +// PriorityLT applies the LT predicate on the "priority" field. +func PriorityLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldPriority, v)) +} + +// PriorityLTE applies the LTE predicate on the "priority" field. +func PriorityLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldPriority, v)) +} + +// ErrorCodesIsNil applies the IsNil predicate on the "error_codes" field. +func ErrorCodesIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldErrorCodes)) +} + +// ErrorCodesNotNil applies the NotNil predicate on the "error_codes" field. +func ErrorCodesNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldErrorCodes)) +} + +// KeywordsIsNil applies the IsNil predicate on the "keywords" field. +func KeywordsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldKeywords)) +} + +// KeywordsNotNil applies the NotNil predicate on the "keywords" field. +func KeywordsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldKeywords)) +} + +// MatchModeEQ applies the EQ predicate on the "match_mode" field. +func MatchModeEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// MatchModeNEQ applies the NEQ predicate on the "match_mode" field. +func MatchModeNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldMatchMode, v)) +} + +// MatchModeIn applies the In predicate on the "match_mode" field. +func MatchModeIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldMatchMode, vs...)) +} + +// MatchModeNotIn applies the NotIn predicate on the "match_mode" field. +func MatchModeNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldMatchMode, vs...)) +} + +// MatchModeGT applies the GT predicate on the "match_mode" field. +func MatchModeGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldMatchMode, v)) +} + +// MatchModeGTE applies the GTE predicate on the "match_mode" field. +func MatchModeGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldMatchMode, v)) +} + +// MatchModeLT applies the LT predicate on the "match_mode" field. +func MatchModeLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldMatchMode, v)) +} + +// MatchModeLTE applies the LTE predicate on the "match_mode" field. +func MatchModeLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldMatchMode, v)) +} + +// MatchModeContains applies the Contains predicate on the "match_mode" field. +func MatchModeContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldMatchMode, v)) +} + +// MatchModeHasPrefix applies the HasPrefix predicate on the "match_mode" field. +func MatchModeHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldMatchMode, v)) +} + +// MatchModeHasSuffix applies the HasSuffix predicate on the "match_mode" field. +func MatchModeHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldMatchMode, v)) +} + +// MatchModeEqualFold applies the EqualFold predicate on the "match_mode" field. +func MatchModeEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldMatchMode, v)) +} + +// MatchModeContainsFold applies the ContainsFold predicate on the "match_mode" field. +func MatchModeContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldMatchMode, v)) +} + +// PlatformsIsNil applies the IsNil predicate on the "platforms" field. +func PlatformsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldPlatforms)) +} + +// PlatformsNotNil applies the NotNil predicate on the "platforms" field. +func PlatformsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldPlatforms)) +} + +// PassthroughCodeEQ applies the EQ predicate on the "passthrough_code" field. +func PassthroughCodeEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// PassthroughCodeNEQ applies the NEQ predicate on the "passthrough_code" field. +func PassthroughCodeNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughCode, v)) +} + +// ResponseCodeEQ applies the EQ predicate on the "response_code" field. +func ResponseCodeEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// ResponseCodeNEQ applies the NEQ predicate on the "response_code" field. +func ResponseCodeNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldResponseCode, v)) +} + +// ResponseCodeIn applies the In predicate on the "response_code" field. +func ResponseCodeIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldResponseCode, vs...)) +} + +// ResponseCodeNotIn applies the NotIn predicate on the "response_code" field. +func ResponseCodeNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldResponseCode, vs...)) +} + +// ResponseCodeGT applies the GT predicate on the "response_code" field. +func ResponseCodeGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldResponseCode, v)) +} + +// ResponseCodeGTE applies the GTE predicate on the "response_code" field. +func ResponseCodeGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldResponseCode, v)) +} + +// ResponseCodeLT applies the LT predicate on the "response_code" field. +func ResponseCodeLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldResponseCode, v)) +} + +// ResponseCodeLTE applies the LTE predicate on the "response_code" field. +func ResponseCodeLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldResponseCode, v)) +} + +// ResponseCodeIsNil applies the IsNil predicate on the "response_code" field. +func ResponseCodeIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldResponseCode)) +} + +// ResponseCodeNotNil applies the NotNil predicate on the "response_code" field. +func ResponseCodeNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldResponseCode)) +} + +// PassthroughBodyEQ applies the EQ predicate on the "passthrough_body" field. +func PassthroughBodyEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// PassthroughBodyNEQ applies the NEQ predicate on the "passthrough_body" field. +func PassthroughBodyNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughBody, v)) +} + +// CustomMessageEQ applies the EQ predicate on the "custom_message" field. +func CustomMessageEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// CustomMessageNEQ applies the NEQ predicate on the "custom_message" field. +func CustomMessageNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCustomMessage, v)) +} + +// CustomMessageIn applies the In predicate on the "custom_message" field. +func CustomMessageIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCustomMessage, vs...)) +} + +// CustomMessageNotIn applies the NotIn predicate on the "custom_message" field. +func CustomMessageNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCustomMessage, vs...)) +} + +// CustomMessageGT applies the GT predicate on the "custom_message" field. +func CustomMessageGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCustomMessage, v)) +} + +// CustomMessageGTE applies the GTE predicate on the "custom_message" field. +func CustomMessageGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCustomMessage, v)) +} + +// CustomMessageLT applies the LT predicate on the "custom_message" field. +func CustomMessageLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCustomMessage, v)) +} + +// CustomMessageLTE applies the LTE predicate on the "custom_message" field. +func CustomMessageLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCustomMessage, v)) +} + +// CustomMessageContains applies the Contains predicate on the "custom_message" field. +func CustomMessageContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldCustomMessage, v)) +} + +// CustomMessageHasPrefix applies the HasPrefix predicate on the "custom_message" field. +func CustomMessageHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldCustomMessage, v)) +} + +// CustomMessageHasSuffix applies the HasSuffix predicate on the "custom_message" field. +func CustomMessageHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldCustomMessage, v)) +} + +// CustomMessageIsNil applies the IsNil predicate on the "custom_message" field. +func CustomMessageIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldCustomMessage)) +} + +// CustomMessageNotNil applies the NotNil predicate on the "custom_message" field. +func CustomMessageNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldCustomMessage)) +} + +// CustomMessageEqualFold applies the EqualFold predicate on the "custom_message" field. +func CustomMessageEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldCustomMessage, v)) +} + +// CustomMessageContainsFold applies the ContainsFold predicate on the "custom_message" field. +func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. +func DescriptionNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. +func DescriptionGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. +func DescriptionIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. +func DescriptionNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldDescription, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.NotPredicates(p)) +} diff --git a/backend/ent/errorpassthroughrule_create.go b/backend/ent/errorpassthroughrule_create.go new file mode 100644 index 00000000..4dc08dce --- /dev/null +++ b/backend/ent/errorpassthroughrule_create.go @@ -0,0 +1,1382 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRuleCreate is the builder for creating a ErrorPassthroughRule entity. +type ErrorPassthroughRuleCreate struct { + config + mutation *ErrorPassthroughRuleMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *ErrorPassthroughRuleCreate) SetCreatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCreatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *ErrorPassthroughRuleCreate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableUpdatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *ErrorPassthroughRuleCreate) SetName(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetName(v) + return _c +} + +// SetEnabled sets the "enabled" field. +func (_c *ErrorPassthroughRuleCreate) SetEnabled(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetEnabled(v) + return _c +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetEnabled(*v) + } + return _c +} + +// SetPriority sets the "priority" field. +func (_c *ErrorPassthroughRuleCreate) SetPriority(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetPriority(v) + return _c +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePriority(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPriority(*v) + } + return _c +} + +// SetErrorCodes sets the "error_codes" field. +func (_c *ErrorPassthroughRuleCreate) SetErrorCodes(v []int) *ErrorPassthroughRuleCreate { + _c.mutation.SetErrorCodes(v) + return _c +} + +// SetKeywords sets the "keywords" field. +func (_c *ErrorPassthroughRuleCreate) SetKeywords(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetKeywords(v) + return _c +} + +// SetMatchMode sets the "match_mode" field. +func (_c *ErrorPassthroughRuleCreate) SetMatchMode(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetMatchMode(v) + return _c +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetMatchMode(*v) + } + return _c +} + +// SetPlatforms sets the "platforms" field. +func (_c *ErrorPassthroughRuleCreate) SetPlatforms(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetPlatforms(v) + return _c +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughCode(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughCode(v) + return _c +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughCode(*v) + } + return _c +} + +// SetResponseCode sets the "response_code" field. +func (_c *ErrorPassthroughRuleCreate) SetResponseCode(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetResponseCode(v) + return _c +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetResponseCode(*v) + } + return _c +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughBody(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughBody(v) + return _c +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughBody(*v) + } + return _c +} + +// SetCustomMessage sets the "custom_message" field. +func (_c *ErrorPassthroughRuleCreate) SetCustomMessage(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetCustomMessage(v) + return _c +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCustomMessage(*v) + } + return _c +} + +// SetDescription sets the "description" field. +func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetDescription(v) + return _c +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableDescription(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetDescription(*v) + } + return _c +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_c *ErrorPassthroughRuleCreate) Mutation() *ErrorPassthroughRuleMutation { + return _c.mutation +} + +// Save creates the ErrorPassthroughRule in the database. +func (_c *ErrorPassthroughRuleCreate) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ErrorPassthroughRuleCreate) SaveX(ctx context.Context) *ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ErrorPassthroughRuleCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := errorpassthroughrule.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Enabled(); !ok { + v := errorpassthroughrule.DefaultEnabled + _c.mutation.SetEnabled(v) + } + if _, ok := _c.mutation.Priority(); !ok { + v := errorpassthroughrule.DefaultPriority + _c.mutation.SetPriority(v) + } + if _, ok := _c.mutation.MatchMode(); !ok { + v := errorpassthroughrule.DefaultMatchMode + _c.mutation.SetMatchMode(v) + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + v := errorpassthroughrule.DefaultPassthroughCode + _c.mutation.SetPassthroughCode(v) + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + v := errorpassthroughrule.DefaultPassthroughBody + _c.mutation.SetPassthroughBody(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *ErrorPassthroughRuleCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ErrorPassthroughRule.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if _, ok := _c.mutation.Enabled(); !ok { + return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ErrorPassthroughRule.enabled"`)} + } + if _, ok := _c.mutation.Priority(); !ok { + return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "ErrorPassthroughRule.priority"`)} + } + if _, ok := _c.mutation.MatchMode(); !ok { + return &ValidationError{Name: "match_mode", err: errors.New(`ent: missing required field "ErrorPassthroughRule.match_mode"`)} + } + if v, ok := _c.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + return &ValidationError{Name: "passthrough_code", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_code"`)} + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} + } + return nil +} + +func (_c *ErrorPassthroughRuleCreate) sqlSave(ctx context.Context) (*ErrorPassthroughRule, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlgraph.CreateSpec) { + var ( + _node = &ErrorPassthroughRule{config: _c.config} + _spec = sqlgraph.NewCreateSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + _node.Enabled = value + } + if value, ok := _c.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + _node.Priority = value + } + if value, ok := _c.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + _node.ErrorCodes = value + } + if value, ok := _c.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + _node.Keywords = value + } + if value, ok := _c.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + _node.MatchMode = value + } + if value, ok := _c.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + _node.Platforms = value + } + if value, ok := _c.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + _node.PassthroughCode = value + } + if value, ok := _c.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + _node.ResponseCode = &value + } + if value, ok := _c.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + _node.PassthroughBody = value + } + if value, ok := _c.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + _node.CustomMessage = &value + } + if value, ok := _c.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + _node.Description = &value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertOne { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +type ( + // ErrorPassthroughRuleUpsertOne is the builder for "upsert"-ing + // one ErrorPassthroughRule node. + ErrorPassthroughRuleUpsertOne struct { + create *ErrorPassthroughRuleCreate + } + + // ErrorPassthroughRuleUpsert is the "OnConflict" setter. + ErrorPassthroughRuleUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsert) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateUpdatedAt() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsert) SetName(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateName() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldName) + return u +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsert) SetEnabled(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldEnabled, v) + return u +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateEnabled() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldEnabled) + return u +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsert) SetPriority(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPriority, v) + return u +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePriority() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPriority) + return u +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsert) AddPriority(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldPriority, v) + return u +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldErrorCodes, v) + return u +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldErrorCodes) + return u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) ClearErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldErrorCodes) + return u +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) SetKeywords(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldKeywords, v) + return u +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateKeywords() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldKeywords) + return u +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) ClearKeywords() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldKeywords) + return u +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsert) SetMatchMode(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldMatchMode, v) + return u +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateMatchMode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldMatchMode) + return u +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) SetPlatforms(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPlatforms, v) + return u +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePlatforms() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPlatforms) + return u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) ClearPlatforms() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldPlatforms) + return u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughCode, v) + return u +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughCode) + return u +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) SetResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateResponseCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldResponseCode) + return u +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) AddResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) ClearResponseCode() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldResponseCode) + return u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughBody, v) + return u +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughBody() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughBody) + return u +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) SetCustomMessage(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldCustomMessage, v) + return u +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldCustomMessage) + return u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldCustomMessage) + return u +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldDescription, v) + return u +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateDescription() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldDescription) + return u +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsert) ClearDescription() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldDescription) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) UpdateNewValues() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) Ignore() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertOne) DoNothing() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreate.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertOne) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertOne) SetName(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateName() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertOne) SetEnabled(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateEnabled() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) AddPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePriority() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) SetKeywords(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertOne) SetMatchMode(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateMatchMode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearPlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) AddResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ErrorPassthroughRuleUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ErrorPassthroughRuleCreateBulk is the builder for creating many ErrorPassthroughRule entities in bulk. +type ErrorPassthroughRuleCreateBulk struct { + config + err error + builders []*ErrorPassthroughRuleCreate + conflict []sql.ConflictOption +} + +// Save creates the ErrorPassthroughRule entities in the database. +func (_c *ErrorPassthroughRuleCreateBulk) Save(ctx context.Context) ([]*ErrorPassthroughRule, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ErrorPassthroughRule, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ErrorPassthroughRuleMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) SaveX(ctx context.Context) []*ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// ErrorPassthroughRuleUpsertBulk is the builder for "upsert"-ing +// a bulk of ErrorPassthroughRule nodes. +type ErrorPassthroughRuleUpsertBulk struct { + create *ErrorPassthroughRuleCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) UpdateNewValues() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) Ignore() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertBulk) DoNothing() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreateBulk.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertBulk) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetName(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateName() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetEnabled(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateEnabled() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePriority() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetKeywords(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetMatchMode(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateMatchMode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearPlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ErrorPassthroughRuleCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_delete.go b/backend/ent/errorpassthroughrule_delete.go new file mode 100644 index 00000000..943c7e2b --- /dev/null +++ b/backend/ent/errorpassthroughrule_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleDelete is the builder for deleting a ErrorPassthroughRule entity. +type ErrorPassthroughRuleDelete struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDelete) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ErrorPassthroughRuleDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ErrorPassthroughRuleDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ErrorPassthroughRuleDeleteOne is the builder for deleting a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleDeleteOne struct { + _d *ErrorPassthroughRuleDelete +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDeleteOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ErrorPassthroughRuleDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{errorpassthroughrule.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_query.go b/backend/ent/errorpassthroughrule_query.go new file mode 100644 index 00000000..bfab5bd8 --- /dev/null +++ b/backend/ent/errorpassthroughrule_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleQuery is the builder for querying ErrorPassthroughRule entities. +type ErrorPassthroughRuleQuery struct { + config + ctx *QueryContext + order []errorpassthroughrule.OrderOption + inters []Interceptor + predicates []predicate.ErrorPassthroughRule + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ErrorPassthroughRuleQuery builder. +func (_q *ErrorPassthroughRuleQuery) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ErrorPassthroughRuleQuery) Limit(limit int) *ErrorPassthroughRuleQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *ErrorPassthroughRuleQuery) Offset(offset int) *ErrorPassthroughRuleQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ErrorPassthroughRuleQuery) Unique(unique bool) *ErrorPassthroughRuleQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ErrorPassthroughRuleQuery) Order(o ...errorpassthroughrule.OrderOption) *ErrorPassthroughRuleQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first ErrorPassthroughRule entity from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule was found. +func (_q *ErrorPassthroughRuleQuery) First(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{errorpassthroughrule.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ErrorPassthroughRule ID from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule ID was found. +func (_q *ErrorPassthroughRuleQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{errorpassthroughrule.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ErrorPassthroughRule entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ErrorPassthroughRule entity is found. +// Returns a *NotFoundError when no ErrorPassthroughRule entities are found. +func (_q *ErrorPassthroughRuleQuery) Only(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{errorpassthroughrule.Label} + default: + return nil, &NotSingularError{errorpassthroughrule.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ErrorPassthroughRule ID in the query. +// Returns a *NotSingularError when more than one ErrorPassthroughRule ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ErrorPassthroughRuleQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{errorpassthroughrule.Label} + default: + err = &NotSingularError{errorpassthroughrule.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ErrorPassthroughRules. +func (_q *ErrorPassthroughRuleQuery) All(ctx context.Context) ([]*ErrorPassthroughRule, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ErrorPassthroughRule, *ErrorPassthroughRuleQuery]() + return withInterceptors[[]*ErrorPassthroughRule](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) AllX(ctx context.Context) []*ErrorPassthroughRule { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ErrorPassthroughRule IDs. +func (_q *ErrorPassthroughRuleQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(errorpassthroughrule.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ErrorPassthroughRuleQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ErrorPassthroughRuleQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *ErrorPassthroughRuleQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ErrorPassthroughRuleQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ErrorPassthroughRuleQuery) Clone() *ErrorPassthroughRuleQuery { + if _q == nil { + return nil + } + return &ErrorPassthroughRuleQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]errorpassthroughrule.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ErrorPassthroughRule{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// GroupBy(errorpassthroughrule.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) GroupBy(field string, fields ...string) *ErrorPassthroughRuleGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ErrorPassthroughRuleGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = errorpassthroughrule.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// Select(errorpassthroughrule.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) Select(fields ...string) *ErrorPassthroughRuleSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ErrorPassthroughRuleSelect{ErrorPassthroughRuleQuery: _q} + sbuild.label = errorpassthroughrule.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ErrorPassthroughRuleSelect configured with the given aggregations. +func (_q *ErrorPassthroughRuleQuery) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *ErrorPassthroughRuleQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !errorpassthroughrule.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ErrorPassthroughRule, error) { + var ( + nodes = []*ErrorPassthroughRule{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ErrorPassthroughRule).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ErrorPassthroughRule{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ErrorPassthroughRuleQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for i := range fields { + if fields[i] != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ErrorPassthroughRuleQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(errorpassthroughrule.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = errorpassthroughrule.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ErrorPassthroughRuleQuery) ForUpdate(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ErrorPassthroughRuleQuery) ForShare(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ErrorPassthroughRuleGroupBy is the group-by builder for ErrorPassthroughRule entities. +type ErrorPassthroughRuleGroupBy struct { + selector + build *ErrorPassthroughRuleQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ErrorPassthroughRuleGroupBy) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *ErrorPassthroughRuleGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ErrorPassthroughRuleGroupBy) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ErrorPassthroughRuleSelect is the builder for selecting fields of ErrorPassthroughRule entities. +type ErrorPassthroughRuleSelect struct { + *ErrorPassthroughRuleQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ErrorPassthroughRuleSelect) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *ErrorPassthroughRuleSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleSelect](ctx, _s.ErrorPassthroughRuleQuery, _s, _s.inters, v) +} + +func (_s *ErrorPassthroughRuleSelect) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/errorpassthroughrule_update.go b/backend/ent/errorpassthroughrule_update.go new file mode 100644 index 00000000..9d52aa49 --- /dev/null +++ b/backend/ent/errorpassthroughrule_update.go @@ -0,0 +1,823 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleUpdate is the builder for updating ErrorPassthroughRule entities. +type ErrorPassthroughRuleUpdate struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdate) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdate) SetName(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableName(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdate) SetEnabled(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) SetPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) AddPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) ClearErrorCodes() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) SetKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) AppendKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) ClearKeywords() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdate) SetMatchMode(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) SetPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) ClearPlatforms() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) AddResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) ClearResponseCode() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) SetCustomMessage(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdate) ClearDescription() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdate) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ErrorPassthroughRuleUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ErrorPassthroughRuleUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ErrorPassthroughRuleUpdateOne is the builder for updating a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetName(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableName(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetEnabled(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearErrorCodes() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearKeywords() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetMatchMode(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearPlatforms() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearResponseCode() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearDescription() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdateOne) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdateOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ErrorPassthroughRuleUpdateOne) Select(field string, fields ...string) *ErrorPassthroughRuleUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ErrorPassthroughRule entity. +func (_u *ErrorPassthroughRuleUpdateOne) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) SaveX(ctx context.Context) *ErrorPassthroughRule { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *ErrorPassthroughRuleUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *ErrorPassthroughRule, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ErrorPassthroughRule.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for _, f := range fields { + if !errorpassthroughrule.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + _node = &ErrorPassthroughRule{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/group.go b/backend/ent/group.go index 0d0c0538..1eb05e0e 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -56,10 +56,16 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + // 无效请求兜底使用的分组 ID + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` // 是否启用模型路由配置 ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` + // 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台) + McpXMLInject bool `json:"mcp_xml_inject,omitempty"` + // 支持的模型系列:claude, gemini_text, gemini_image + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -166,13 +172,13 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting: + case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -322,6 +328,13 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupID = new(int64) *_m.FallbackGroupID = value.Int64 } + case group.FieldFallbackGroupIDOnInvalidRequest: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i]) + } else if value.Valid { + _m.FallbackGroupIDOnInvalidRequest = new(int64) + *_m.FallbackGroupIDOnInvalidRequest = value.Int64 + } case group.FieldModelRouting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field model_routing", values[i]) @@ -336,6 +349,20 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.ModelRoutingEnabled = value.Bool } + case group.FieldMcpXMLInject: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field mcp_xml_inject", values[i]) + } else if value.Valid { + _m.McpXMLInject = value.Bool + } + case group.FieldSupportedModelScopes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field supported_model_scopes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.SupportedModelScopes); err != nil { + return fmt.Errorf("unmarshal field supported_model_scopes: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -487,11 +514,22 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.FallbackGroupIDOnInvalidRequest; v != nil { + builder.WriteString("fallback_group_id_on_invalid_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("model_routing=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(", ") builder.WriteString("model_routing_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled)) + builder.WriteString(", ") + builder.WriteString("mcp_xml_inject=") + builder.WriteString(fmt.Sprintf("%v", _m.McpXMLInject)) + builder.WriteString(", ") + builder.WriteString("supported_model_scopes=") + builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index d66d3edc..278b2daf 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -53,10 +53,16 @@ const ( FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. FieldFallbackGroupID = "fallback_group_id" + // FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database. + FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request" // FieldModelRouting holds the string denoting the model_routing field in the database. FieldModelRouting = "model_routing" // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. FieldModelRoutingEnabled = "model_routing_enabled" + // FieldMcpXMLInject holds the string denoting the mcp_xml_inject field in the database. + FieldMcpXMLInject = "mcp_xml_inject" + // FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database. + FieldSupportedModelScopes = "supported_model_scopes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -151,8 +157,11 @@ var Columns = []string{ FieldImagePrice4k, FieldClaudeCodeOnly, FieldFallbackGroupID, + FieldFallbackGroupIDOnInvalidRequest, FieldModelRouting, FieldModelRoutingEnabled, + FieldMcpXMLInject, + FieldSupportedModelScopes, } var ( @@ -212,6 +221,10 @@ var ( DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. DefaultModelRoutingEnabled bool + // DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field. + DefaultMcpXMLInject bool + // DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field. + DefaultSupportedModelScopes []string ) // OrderOption defines the ordering options for the Group queries. @@ -317,11 +330,21 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() } +// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field. +func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc() +} + // ByModelRoutingEnabled orders the results by the model_routing_enabled field. func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() } +// ByMcpXMLInject orders the results by the mcp_xml_inject field. +func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 6ce9e4c6..b6fa2c33 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -150,11 +150,21 @@ func FallbackGroupID(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) } +// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ. +func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. func ModelRoutingEnabled(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInject applies equality check predicate on the "mcp_xml_inject" field. It's identical to McpXMLInjectEQ. +func McpXMLInject(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1070,6 +1080,56 @@ func FallbackGroupIDNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) } +// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest)) +} + +// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest)) +} + // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. func ModelRoutingIsNil() predicate.Group { return predicate.Group(sql.FieldIsNull(FieldModelRouting)) @@ -1090,6 +1150,16 @@ func ModelRoutingEnabledNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInjectEQ applies the EQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + +// McpXMLInjectNEQ applies the NEQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0f251e0b..9d845b61 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { return _c } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _c +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _c +} + // SetModelRouting sets the "model_routing" field. func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { _c.mutation.SetModelRouting(v) @@ -306,6 +320,26 @@ func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate { return _c } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_c *GroupCreate) SetMcpXMLInject(v bool) *GroupCreate { + _c.mutation.SetMcpXMLInject(v) + return _c +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_c *GroupCreate) SetNillableMcpXMLInject(v *bool) *GroupCreate { + if v != nil { + _c.SetMcpXMLInject(*v) + } + return _c +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate { + _c.mutation.SetSupportedModelScopes(v) + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -479,6 +513,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultModelRoutingEnabled _c.mutation.SetModelRoutingEnabled(v) } + if _, ok := _c.mutation.McpXMLInject(); !ok { + v := group.DefaultMcpXMLInject + _c.mutation.SetMcpXMLInject(v) + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + v := group.DefaultSupportedModelScopes + _c.mutation.SetSupportedModelScopes(v) + } return nil } @@ -537,6 +579,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)} } + if _, ok := _c.mutation.McpXMLInject(); !ok { + return &ValidationError{Name: "mcp_xml_inject", err: errors.New(`ent: missing required field "Group.mcp_xml_inject"`)} + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)} + } return nil } @@ -640,6 +688,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _node.FallbackGroupID = &value } + if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + _node.FallbackGroupIDOnInvalidRequest = &value + } if value, ok := _c.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _node.ModelRouting = value @@ -648,6 +700,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) _node.ModelRoutingEnabled = value } + if value, ok := _c.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + _node.McpXMLInject = value + } + if value, ok := _c.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + _node.SupportedModelScopes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1128,6 +1188,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { return u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { u.Set(group.FieldModelRouting, v) @@ -1158,6 +1242,30 @@ func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert { return u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsert) SetMcpXMLInject(v bool) *GroupUpsert { + u.Set(group.FieldMcpXMLInject, v) + return u +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsert) UpdateMcpXMLInject() *GroupUpsert { + u.SetExcluded(group.FieldMcpXMLInject) + return u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsert) SetSupportedModelScopes(v []string) *GroupUpsert { + u.Set(group.FieldSupportedModelScopes, v) + return u +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert { + u.SetExcluded(group.FieldSupportedModelScopes) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1581,6 +1689,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -1616,6 +1752,34 @@ func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertOne) SetMcpXMLInject(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMcpXMLInject() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertOne) SetSupportedModelScopes(v []string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2205,6 +2369,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { @@ -2240,6 +2432,34 @@ func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertBulk) SetMcpXMLInject(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMcpXMLInject() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertBulk) SetSupportedModelScopes(v []string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c3cc2708..9e7246ea 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -395,6 +396,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { _u.mutation.SetModelRouting(v) @@ -421,6 +449,32 @@ func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate { return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdate) SetMcpXMLInject(v bool) *GroupUpdate { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableMcpXMLInject(v *bool) *GroupUpdate { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdate) SetSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -829,6 +883,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -838,6 +901,17 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1513,6 +1587,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { _u.mutation.SetModelRouting(v) @@ -1539,6 +1640,32 @@ func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOn return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdateOne) SetMcpXMLInject(v bool) *GroupUpdateOne { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableMcpXMLInject(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdateOne) SetSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1977,6 +2104,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -1986,6 +2122,17 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1e653c77..1b15685c 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,18 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary +// function as ErrorPassthroughRule mutator. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ErrorPassthroughRuleFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ErrorPassthroughRuleMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ErrorPassthroughRuleMutation", m) +} + // The GroupFunc type is an adapter to allow the use of ordinary // function as Group mutator. type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index a37be48f..8ee42db3 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,6 +13,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + +// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser. +type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + // The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) @@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.ErrorPassthroughRuleQuery: + return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.PromoCodeQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index e2ed7340..f9e90d73 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -20,6 +20,9 @@ var ( {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, + {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -31,13 +34,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -46,12 +49,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, { Name: "apikey_status", @@ -63,6 +66,16 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_quota_quota_used", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + }, + { + Name: "apikey_expires_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[11]}, + }, }, } // AccountsColumns holds the columns for the "accounts" table. @@ -296,6 +309,42 @@ var ( }, }, } + // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. + ErrorPassthroughRulesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Size: 100}, + {Name: "enabled", Type: field.TypeBool, Default: true}, + {Name: "priority", Type: field.TypeInt, Default: 0}, + {Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"}, + {Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "passthrough_code", Type: field.TypeBool, Default: true}, + {Name: "response_code", Type: field.TypeInt, Nullable: true}, + {Name: "passthrough_body", Type: field.TypeBool, Default: true}, + {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, + } + // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. + ErrorPassthroughRulesTable = &schema.Table{ + Name: "error_passthrough_rules", + Columns: ErrorPassthroughRulesColumns, + PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "errorpassthroughrule_enabled", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]}, + }, + { + Name: "errorpassthroughrule_priority", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]}, + }, + }, + } // GroupsColumns holds the columns for the "groups" table. GroupsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -318,8 +367,11 @@ var ( {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, + {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, + {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -934,6 +986,7 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + ErrorPassthroughRulesTable, GroupsTable, PromoCodesTable, PromoCodeUsagesTable, @@ -973,6 +1026,9 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ + Table: "error_passthrough_rules", + } GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 38e0c7e5..5c182dea 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -48,6 +49,7 @@ const ( TypeAccountGroup = "AccountGroup" TypeAnnouncement = "Announcement" TypeAnnouncementRead = "AnnouncementRead" + TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" @@ -79,6 +81,11 @@ type APIKeyMutation struct { appendip_whitelist []string ip_blacklist *[]string appendip_blacklist []string + quota *float64 + addquota *float64 + quota_used *float64 + addquota_used *float64 + expires_at *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -634,6 +641,167 @@ func (m *APIKeyMutation) ResetIPBlacklist() { delete(m.clearedFields, apikey.FieldIPBlacklist) } +// SetQuota sets the "quota" field. +func (m *APIKeyMutation) SetQuota(f float64) { + m.quota = &f + m.addquota = nil +} + +// Quota returns the value of the "quota" field in the mutation. +func (m *APIKeyMutation) Quota() (r float64, exists bool) { + v := m.quota + if v == nil { + return + } + return *v, true +} + +// OldQuota returns the old "quota" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuota(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuota is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuota requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuota: %w", err) + } + return oldValue.Quota, nil +} + +// AddQuota adds f to the "quota" field. +func (m *APIKeyMutation) AddQuota(f float64) { + if m.addquota != nil { + *m.addquota += f + } else { + m.addquota = &f + } +} + +// AddedQuota returns the value that was added to the "quota" field in this mutation. +func (m *APIKeyMutation) AddedQuota() (r float64, exists bool) { + v := m.addquota + if v == nil { + return + } + return *v, true +} + +// ResetQuota resets all changes to the "quota" field. +func (m *APIKeyMutation) ResetQuota() { + m.quota = nil + m.addquota = nil +} + +// SetQuotaUsed sets the "quota_used" field. +func (m *APIKeyMutation) SetQuotaUsed(f float64) { + m.quota_used = &f + m.addquota_used = nil +} + +// QuotaUsed returns the value of the "quota_used" field in the mutation. +func (m *APIKeyMutation) QuotaUsed() (r float64, exists bool) { + v := m.quota_used + if v == nil { + return + } + return *v, true +} + +// OldQuotaUsed returns the old "quota_used" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuotaUsed(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuotaUsed is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuotaUsed requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuotaUsed: %w", err) + } + return oldValue.QuotaUsed, nil +} + +// AddQuotaUsed adds f to the "quota_used" field. +func (m *APIKeyMutation) AddQuotaUsed(f float64) { + if m.addquota_used != nil { + *m.addquota_used += f + } else { + m.addquota_used = &f + } +} + +// AddedQuotaUsed returns the value that was added to the "quota_used" field in this mutation. +func (m *APIKeyMutation) AddedQuotaUsed() (r float64, exists bool) { + v := m.addquota_used + if v == nil { + return + } + return *v, true +} + +// ResetQuotaUsed resets all changes to the "quota_used" field. +func (m *APIKeyMutation) ResetQuotaUsed() { + m.quota_used = nil + m.addquota_used = nil +} + +// SetExpiresAt sets the "expires_at" field. +func (m *APIKeyMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *APIKeyMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *APIKeyMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[apikey.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *APIKeyMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *APIKeyMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, apikey.FieldExpiresAt) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -776,7 +944,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 13) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -807,6 +975,15 @@ func (m *APIKeyMutation) Fields() []string { if m.ip_blacklist != nil { fields = append(fields, apikey.FieldIPBlacklist) } + if m.quota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.quota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } + if m.expires_at != nil { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -835,6 +1012,12 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.IPWhitelist() case apikey.FieldIPBlacklist: return m.IPBlacklist() + case apikey.FieldQuota: + return m.Quota() + case apikey.FieldQuotaUsed: + return m.QuotaUsed() + case apikey.FieldExpiresAt: + return m.ExpiresAt() } return nil, false } @@ -864,6 +1047,12 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: return m.OldIPBlacklist(ctx) + case apikey.FieldQuota: + return m.OldQuota(ctx) + case apikey.FieldQuotaUsed: + return m.OldQuotaUsed(ctx) + case apikey.FieldExpiresAt: + return m.OldExpiresAt(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -943,6 +1132,27 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetIPBlacklist(v) return nil + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuotaUsed(v) + return nil + case apikey.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -951,6 +1161,12 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *APIKeyMutation) AddedFields() []string { var fields []string + if m.addquota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.addquota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } return fields } @@ -959,6 +1175,10 @@ func (m *APIKeyMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { switch name { + case apikey.FieldQuota: + return m.AddedQuota() + case apikey.FieldQuotaUsed: + return m.AddedQuotaUsed() } return nil, false } @@ -968,6 +1188,20 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *APIKeyMutation) AddField(name string, value ent.Value) error { switch name { + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuotaUsed(v) + return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -988,6 +1222,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldIPBlacklist) { fields = append(fields, apikey.FieldIPBlacklist) } + if m.FieldCleared(apikey.FieldExpiresAt) { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -1014,6 +1251,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldIPBlacklist: m.ClearIPBlacklist() return nil + case apikey.FieldExpiresAt: + m.ClearExpiresAt() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1052,6 +1292,15 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldIPBlacklist: m.ResetIPBlacklist() return nil + case apikey.FieldQuota: + m.ResetQuota() + return nil + case apikey.FieldQuotaUsed: + m.ResetQuotaUsed() + return nil + case apikey.FieldExpiresAt: + m.ResetExpiresAt() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -5503,64 +5752,1335 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule +} + +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) + +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) + +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withErrorPassthroughRuleID sets the ID field of the mutation. +func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule + ) + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ErrorPassthroughRuleMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetName sets the "name" field. +func (m *ErrorPassthroughRuleMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *ErrorPassthroughRuleMutation) ResetName() { + m.name = nil +} + +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil +} + +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority + if v == nil { + return + } + return *v, true +} + +// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPriority is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPriority requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPriority: %w", err) + } + return oldValue.Priority, nil +} + +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i + } else { + m.addpriority = &i + } +} + +// AddedPriority returns the value that was added to the "priority" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority + if v == nil { + return + } + return *v, true +} + +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil +} + +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil +} + +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes + if v == nil { + return + } + return *v, true +} + +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorCodes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) + } + return oldValue.ErrorCodes, nil +} + +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) +} + +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} +} + +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok +} + +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +} + +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil +} + +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords + if v == nil { + return + } + return *v, true +} + +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKeywords requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) + } + return oldValue.Keywords, nil +} + +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) +} + +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true +} + +// ClearKeywords clears the value of the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] + return ok +} + +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +} + +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s +} + +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode + if v == nil { + return + } + return *v, true +} + +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMatchMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + } + return oldValue.MatchMode, nil +} + +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil +} + +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil +} + +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms + if v == nil { + return + } + return *v, true +} + +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatforms requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) + } + return oldValue.Platforms, nil +} + +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) +} + +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false + } + return m.appendplatforms, true +} + +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +} + +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] + return ok +} + +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b +} + +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code + if v == nil { + return + } + return *v, true +} + +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + } + return oldValue.PassthroughCode, nil +} + +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil +} + +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil +} + +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code + if v == nil { + return + } + return *v, true +} + +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) + } + return oldValue.ResponseCode, nil +} + +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i + } else { + m.addresponse_code = &i + } +} + +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code + if v == nil { + return + } + return *v, true +} + +// ClearResponseCode clears the value of the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +} + +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] + return ok +} + +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b +} + +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body + if v == nil { + return + } + return *v, true +} + +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) + } + return oldValue.PassthroughBody, nil +} + +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil +} + +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s +} + +// CustomMessage returns the value of the "custom_message" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message + if v == nil { + return + } + return *v, true +} + +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCustomMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) + } + return oldValue.CustomMessage, nil +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} +} + +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] + return ok +} + +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +} + +// SetDescription sets the "description" field. +func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) +} + +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) + } + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) + } + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) + } + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + } + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + } + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldDescription: + return m.Description() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) + } + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) +} + // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -6649,6 +8169,76 @@ func (m *GroupMutation) ResetFallbackGroupID() { delete(m.clearedFields, group.FieldFallbackGroupID) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil +} + +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i + } +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + // SetModelRouting sets the "model_routing" field. func (m *GroupMutation) SetModelRouting(value map[string][]int64) { m.model_routing = &value @@ -6734,6 +8324,93 @@ func (m *GroupMutation) ResetModelRoutingEnabled() { m.model_routing_enabled = nil } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (m *GroupMutation) SetMcpXMLInject(b bool) { + m.mcp_xml_inject = &b +} + +// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. +func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { + v := m.mcp_xml_inject + if v == nil { + return + } + return *v, true +} + +// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) + } + return oldValue.McpXMLInject, nil +} + +// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. +func (m *GroupMutation) ResetMcpXMLInject() { + m.mcp_xml_inject = nil +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (m *GroupMutation) SetSupportedModelScopes(s []string) { + m.supported_model_scopes = &s + m.appendsupported_model_scopes = nil +} + +// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. +func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { + v := m.supported_model_scopes + if v == nil { + return + } + return *v, true +} + +// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + } + return oldValue.SupportedModelScopes, nil +} + +// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. +func (m *GroupMutation) AppendSupportedModelScopes(s []string) { + m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) +} + +// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. +func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { + if len(m.appendsupported_model_scopes) == 0 { + return nil, false + } + return m.appendsupported_model_scopes, true +} + +// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. +func (m *GroupMutation) ResetSupportedModelScopes() { + m.supported_model_scopes = nil + m.appendsupported_model_scopes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -7092,7 +8769,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 24) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -7150,12 +8827,21 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.model_routing != nil { fields = append(fields, group.FieldModelRouting) } if m.model_routing_enabled != nil { fields = append(fields, group.FieldModelRoutingEnabled) } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } return fields } @@ -7202,10 +8888,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() case group.FieldModelRouting: return m.ModelRouting() case group.FieldModelRoutingEnabled: return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() } return nil, false } @@ -7253,10 +8945,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) case group.FieldModelRouting: return m.OldModelRouting(ctx) case group.FieldModelRoutingEnabled: return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -7399,6 +9097,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil case group.FieldModelRouting: v, ok := value.(map[string][]int64) if !ok { @@ -7413,6 +9118,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetModelRoutingEnabled(v) return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -7448,6 +9167,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } return fields } @@ -7474,6 +9196,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice4k() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() } return nil, false } @@ -7546,6 +9270,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -7581,6 +9312,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.FieldCleared(group.FieldModelRouting) { fields = append(fields, group.FieldModelRouting) } @@ -7625,6 +9359,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ClearModelRouting() return nil @@ -7693,12 +9430,21 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupID: m.ResetFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ResetModelRouting() return nil case group.FieldModelRoutingEnabled: m.ResetModelRoutingEnabled() return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 613c5913..c12955ef 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,9 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. +type ErrorPassthroughRule func(*sql.Selector) + // Group is the predicate function for group builders. type Group func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ae4eece8..4b3c1a4f 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -91,6 +92,14 @@ func init() { apikey.DefaultStatus = apikeyDescStatus.Default.(string) // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) + // apikeyDescQuota is the schema descriptor for quota field. + apikeyDescQuota := apikeyFields[7].Descriptor() + // apikey.DefaultQuota holds the default value on creation for the quota field. + apikey.DefaultQuota = apikeyDescQuota.Default.(float64) + // apikeyDescQuotaUsed is the schema descriptor for quota_used field. + apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. + apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] @@ -262,6 +271,61 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() + errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() + _ = errorpassthroughruleMixinFields0 + errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields() + _ = errorpassthroughruleFields + // errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field. + errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor() + // errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field. + errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time) + // errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field. + errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor() + // errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field. + errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time) + // errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time) + // errorpassthroughruleDescName is the schema descriptor for name field. + errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor() + // errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save. + errorpassthroughrule.NameValidator = func() func(string) error { + validators := errorpassthroughruleDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // errorpassthroughruleDescEnabled is the schema descriptor for enabled field. + errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor() + // errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field. + errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool) + // errorpassthroughruleDescPriority is the schema descriptor for priority field. + errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor() + // errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field. + errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int) + // errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field. + errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor() + // errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field. + errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string) + // errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error) + // errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field. + errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor() + // errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field. + errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool) + // errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field. + errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() + // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. + errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) groupMixin := schema.Group{}.Mixin() groupMixinHooks1 := groupMixin[1].Hooks() group.Hooks[0] = groupMixinHooks1[0] @@ -334,9 +398,17 @@ func init() { // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[17].Descriptor() + groupDescModelRoutingEnabled := groupFields[18].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) + // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. + groupDescMcpXMLInject := groupFields[19].Descriptor() + // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. + group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) + // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. + groupDescSupportedModelScopes := groupFields[20].Descriptor() + // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. + group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 1c2d4bd4..26d52cb0 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -5,6 +5,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -52,6 +53,23 @@ func (APIKey) Fields() []ent.Field { field.JSON("ip_blacklist", []string{}). Optional(). Comment("Blocked IPs/CIDRs"), + + // ========== Quota fields ========== + // Quota limit in USD (0 = unlimited) + field.Float("quota"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Quota limit in USD for this API key (0 = unlimited)"), + // Used quota amount + field.Float("quota_used"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used quota amount in USD"), + // Expiration time (nil = never expires) + field.Time("expires_at"). + Optional(). + Nillable(). + Comment("Expiration time for this API key (null = never expires)"), } } @@ -77,5 +95,8 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + // Index for quota queries + index.Fields("quota", "quota_used"), + index.Fields("expires_at"), } } diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go new file mode 100644 index 00000000..4a861f38 --- /dev/null +++ b/backend/ent/schema/error_passthrough_rule.go @@ -0,0 +1,121 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ErrorPassthroughRule 定义全局错误透传规则的 schema。 +// +// 错误透传规则用于控制上游错误如何返回给客户端: +// - 匹配条件:错误码 + 关键词组合 +// - 响应行为:透传原始信息 或 自定义错误信息 +// - 响应状态码:可指定返回给客户端的状态码 +// - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity) +type ErrorPassthroughRule struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (ErrorPassthroughRule) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "error_passthrough_rules"}, + } +} + +// Mixin 返回该 schema 使用的混入组件。 +func (ErrorPassthroughRule) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义错误透传规则实体的所有字段。 +func (ErrorPassthroughRule) Fields() []ent.Field { + return []ent.Field{ + // name: 规则名称,用于在界面中标识规则 + field.String("name"). + MaxLen(100). + NotEmpty(), + + // enabled: 是否启用该规则 + field.Bool("enabled"). + Default(true), + + // priority: 规则优先级,数值越小优先级越高 + // 匹配时按优先级顺序检查,命中第一个匹配的规则 + field.Int("priority"). + Default(0), + + // error_codes: 匹配的错误码列表(OR关系) + // 例如:[422, 400] 表示匹配 422 或 400 错误码 + field.JSON("error_codes", []int{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // keywords: 匹配的关键词列表(OR关系) + // 例如:["context limit", "model not supported"] + // 关键词匹配不区分大小写 + field.JSON("keywords", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // match_mode: 匹配模式 + // - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可) + // - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足) + field.String("match_mode"). + MaxLen(10). + Default("any"), + + // platforms: 适用平台列表 + // 例如:["anthropic", "openai", "gemini", "antigravity"] + // 空列表表示适用于所有平台 + field.JSON("platforms", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // passthrough_code: 是否透传上游原始状态码 + // true: 使用上游返回的状态码 + // false: 使用 response_code 指定的状态码 + field.Bool("passthrough_code"). + Default(true), + + // response_code: 自定义响应状态码 + // 当 passthrough_code=false 时使用此状态码 + field.Int("response_code"). + Optional(). + Nillable(), + + // passthrough_body: 是否透传上游原始错误信息 + // true: 使用上游返回的错误信息 + // false: 使用 custom_message 指定的错误信息 + field.Bool("passthrough_body"). + Default(true), + + // custom_message: 自定义错误信息 + // 当 passthrough_body=false 时使用此错误信息 + field.Text("custom_message"). + Optional(). + Nillable(), + + // description: 规则描述,用于说明规则的用途 + field.Text("description"). + Optional(). + Nillable(), + } +} + +// Indexes 定义数据库索引,优化查询性能。 +func (ErrorPassthroughRule) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("enabled"), // 筛选启用的规则 + index.Fields("priority"), // 按优先级排序 + } +} diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index ccd72eac..8a3c1a90 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). Comment("非 Claude Code 请求降级使用的分组 ID"), + field.Int64("fallback_group_id_on_invalid_request"). + Optional(). + Nillable(). + Comment("无效请求兜底使用的分组 ID"), // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). @@ -106,6 +110,17 @@ func (Group) Fields() []ent.Field { field.Bool("model_routing_enabled"). Default(false). Comment("是否启用模型路由配置"), + + // MCP XML 协议注入开关 (added by migration 042) + field.Bool("mcp_xml_inject"). + Default(true). + Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"), + + // 支持的模型系列 (added by migration 046) + field.JSON("supported_model_scopes", []string{}). + Default([]string{"claude", "gemini_text", "gemini_image"}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("支持的模型系列:claude, gemini_text, gemini_image"), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 702bdf90..45d83428 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,6 +24,8 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -186,6 +188,7 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) diff --git a/backend/go.mod b/backend/go.mod index 4c3e6246..6916057f 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,9 +1,11 @@ module github.com/Wei-Shaw/sub2api -go 1.25.6 +go 1.25.7 require ( entgo.io/ent v0.14.5 + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 @@ -11,7 +13,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 + github.com/refraction-networking/utls v1.8.1 + github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.25.6 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.11.1 @@ -20,18 +25,18 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/zeromicro/go-zero v1.9.4 - golang.org/x/crypto v0.46.0 - golang.org/x/net v0.48.0 + golang.org/x/crypto v0.47.0 + golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.38.0 + golang.org/x/term v0.39.0 gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.44.3 ) require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect - github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect @@ -48,7 +53,6 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect @@ -71,12 +75,10 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/icholy/digest v1.1.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect @@ -85,7 +87,6 @@ require ( github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect @@ -100,20 +101,15 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/pquerna/otp v1.5.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect - github.com/refraction-networking/utls v1.8.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/uniseg v0.2.0 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -121,7 +117,6 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect @@ -145,16 +140,13 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.30.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/tools v0.39.0 // indirect - golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.44.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 0addb5bb..171995c7 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -46,7 +46,6 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -55,6 +54,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -113,8 +114,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -123,6 +124,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo= @@ -131,8 +135,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -168,9 +170,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -204,8 +203,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -233,13 +230,10 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -258,8 +252,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -343,16 +335,14 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -364,21 +354,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= -golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY= -golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= -golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= @@ -399,12 +384,32 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas= -modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY= +modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 84be445b..91437ba8 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -144,12 +144,24 @@ type PricingConfig struct { } type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Mode string `mapstructure:"mode"` // debug/release - ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) - IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) - TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` // debug/release + ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) + TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) + MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制 + H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置 +} + +// H2CConfig HTTP/2 Cleartext 配置 +type H2CConfig struct { + Enabled bool `mapstructure:"enabled"` // 是否启用 H2C + MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量 + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒) + MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节) + MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节) + MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节) } type CORSConfig struct { @@ -467,6 +479,13 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` + // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 + // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` + // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 + RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` + // RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新 + RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"` } // TotpConfig TOTP 双因素认证配置 @@ -687,6 +706,14 @@ func setDefaults() { viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) + viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + // H2C 默认配置 + viper.SetDefault("server.h2c.enabled", false) + viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 + viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒 + viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB(够用) + viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB + viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB // CORS viper.SetDefault("cors.allowed_origins", []string{}) @@ -783,6 +810,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) + viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -912,6 +942,22 @@ func (c *Config) Validate() error { if c.JWT.ExpireHour > 24 { log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) } + // JWT Refresh Token配置验证 + if c.JWT.AccessTokenExpireMinutes <= 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + } + if c.JWT.AccessTokenExpireMinutes > 720 { + log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + } + if c.JWT.RefreshTokenExpireDays <= 0 { + return fmt.Errorf("jwt.refresh_token_expire_days must be positive") + } + if c.JWT.RefreshTokenExpireDays > 90 { + log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + } + if c.JWT.RefreshWindowMinutes < 0 { + return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") + } if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 3655e07f..05b5adc1 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -29,6 +29,7 @@ const ( AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants @@ -63,3 +64,38 @@ const ( SubscriptionStatusExpired = "expired" SubscriptionStatusSuspended = "suspended" ) + +// DefaultAntigravityModelMapping 是 Antigravity 平台的默认模型映射 +// 当账号未配置 model_mapping 时使用此默认值 +// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致 +var DefaultAntigravityModelMapping = map[string]string{ + // Claude 白名单 + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 + "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + // Claude 详细版本 ID 映射 + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型 + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + // Claude Haiku → Sonnet(无 Haiku 支持) + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + // Gemini 2.5 白名单 + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + // Gemini 3 白名单 + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + // Gemini 3 preview 映射 + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + // 其他官方模型 + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview", +} diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go new file mode 100644 index 00000000..b5d1dd0a --- /dev/null +++ b/backend/internal/handler/admin/account_data.go @@ -0,0 +1,544 @@ +package admin + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + dataType = "sub2api-data" + legacyDataType = "sub2api-bundle" + dataVersion = 1 + dataPageCap = 1000 +) + +type DataPayload struct { + Type string `json:"type,omitempty"` + Version int `json:"version,omitempty"` + ExportedAt string `json:"exported_at"` + Proxies []DataProxy `json:"proxies"` + Accounts []DataAccount `json:"accounts"` +} + +type DataProxy struct { + ProxyKey string `json:"proxy_key"` + Name string `json:"name"` + Protocol string `json:"protocol"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Status string `json:"status"` +} + +type DataAccount struct { + Name string `json:"name"` + Notes *string `json:"notes,omitempty"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra,omitempty"` + ProxyKey *string `json:"proxy_key,omitempty"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier,omitempty"` + ExpiresAt *int64 `json:"expires_at,omitempty"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"` +} + +type DataImportRequest struct { + Data DataPayload `json:"data"` + SkipDefaultGroupBind *bool `json:"skip_default_group_bind"` +} + +type DataImportResult struct { + ProxyCreated int `json:"proxy_created"` + ProxyReused int `json:"proxy_reused"` + ProxyFailed int `json:"proxy_failed"` + AccountCreated int `json:"account_created"` + AccountFailed int `json:"account_failed"` + Errors []DataImportError `json:"errors,omitempty"` +} + +type DataImportError struct { + Kind string `json:"kind"` + Name string `json:"name,omitempty"` + ProxyKey string `json:"proxy_key,omitempty"` + Message string `json:"message"` +} + +func buildProxyKey(protocol, host string, port int, username, password string) string { + return fmt.Sprintf("%s|%s|%d|%s|%s", strings.TrimSpace(protocol), strings.TrimSpace(host), port, strings.TrimSpace(username), strings.TrimSpace(password)) +} + +func (h *AccountHandler) ExportData(c *gin.Context) { + ctx := c.Request.Context() + + selectedIDs, err := parseAccountIDs(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + accounts, err := h.resolveExportAccounts(ctx, selectedIDs, c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + includeProxies, err := parseIncludeProxies(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var proxies []service.Proxy + if includeProxies { + proxies, err = h.resolveExportProxies(ctx, accounts) + if err != nil { + response.ErrorFrom(c, err) + return + } + } else { + proxies = []service.Proxy{} + } + + proxyKeyByID := make(map[int64]string, len(proxies)) + dataProxies := make([]DataProxy, 0, len(proxies)) + for i := range proxies { + p := proxies[i] + key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password) + proxyKeyByID[p.ID] = key + dataProxies = append(dataProxies, DataProxy{ + ProxyKey: key, + Name: p.Name, + Protocol: p.Protocol, + Host: p.Host, + Port: p.Port, + Username: p.Username, + Password: p.Password, + Status: p.Status, + }) + } + + dataAccounts := make([]DataAccount, 0, len(accounts)) + for i := range accounts { + acc := accounts[i] + var proxyKey *string + if acc.ProxyID != nil { + if key, ok := proxyKeyByID[*acc.ProxyID]; ok { + proxyKey = &key + } + } + var expiresAt *int64 + if acc.ExpiresAt != nil { + v := acc.ExpiresAt.Unix() + expiresAt = &v + } + dataAccounts = append(dataAccounts, DataAccount{ + Name: acc.Name, + Notes: acc.Notes, + Platform: acc.Platform, + Type: acc.Type, + Credentials: acc.Credentials, + Extra: acc.Extra, + ProxyKey: proxyKey, + Concurrency: acc.Concurrency, + Priority: acc.Priority, + RateMultiplier: acc.RateMultiplier, + ExpiresAt: expiresAt, + AutoPauseOnExpired: &acc.AutoPauseOnExpired, + }) + } + + payload := DataPayload{ + ExportedAt: time.Now().UTC().Format(time.RFC3339), + Proxies: dataProxies, + Accounts: dataAccounts, + } + + response.Success(c, payload) +} + +func (h *AccountHandler) ImportData(c *gin.Context) { + var req DataImportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + dataPayload := req.Data + if err := validateDataHeader(dataPayload); err != nil { + response.BadRequest(c, err.Error()) + return + } + + skipDefaultGroupBind := true + if req.SkipDefaultGroupBind != nil { + skipDefaultGroupBind = *req.SkipDefaultGroupBind + } + + result := DataImportResult{} + existingProxies, err := h.listAllProxies(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + proxyKeyToID := make(map[string]int64, len(existingProxies)) + for i := range existingProxies { + p := existingProxies[i] + key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password) + proxyKeyToID[key] = p.ID + } + + for i := range dataPayload.Proxies { + item := dataPayload.Proxies[i] + key := item.ProxyKey + if key == "" { + key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password) + } + if err := validateDataProxy(item); err != nil { + result.ProxyFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: err.Error(), + }) + continue + } + normalizedStatus := normalizeProxyStatus(item.Status) + if existingID, ok := proxyKeyToID[key]; ok { + proxyKeyToID[key] = existingID + result.ProxyReused++ + if normalizedStatus != "" { + if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { + _, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ + Status: normalizedStatus, + }) + } + } + continue + } + + created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + Name: defaultProxyName(item.Name), + Protocol: item.Protocol, + Host: item.Host, + Port: item.Port, + Username: item.Username, + Password: item.Password, + }) + if err != nil { + result.ProxyFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: err.Error(), + }) + continue + } + proxyKeyToID[key] = created.ID + result.ProxyCreated++ + + if normalizedStatus != "" && normalizedStatus != created.Status { + _, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ + Status: normalizedStatus, + }) + } + } + + for i := range dataPayload.Accounts { + item := dataPayload.Accounts[i] + if err := validateDataAccount(item); err != nil { + result.AccountFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "account", + Name: item.Name, + Message: err.Error(), + }) + continue + } + + var proxyID *int64 + if item.ProxyKey != nil && *item.ProxyKey != "" { + if id, ok := proxyKeyToID[*item.ProxyKey]; ok { + proxyID = &id + } else { + result.AccountFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "account", + Name: item.Name, + ProxyKey: *item.ProxyKey, + Message: "proxy_key not found", + }) + continue + } + } + + accountInput := &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: proxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: nil, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipDefaultGroupBind: skipDefaultGroupBind, + } + + if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { + result.AccountFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "account", + Name: item.Name, + Message: err.Error(), + }) + continue + } + result.AccountCreated++ + } + + response.Success(c, result) +} + +func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { + page := 1 + pageSize := dataPageCap + var out []service.Proxy + for { + items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "") + if err != nil { + return nil, err + } + out = append(out, items...) + if len(out) >= int(total) || len(items) == 0 { + break + } + page++ + } + return out, nil +} + +func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) { + page := 1 + pageSize := dataPageCap + var out []service.Account + for { + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) + if err != nil { + return nil, err + } + out = append(out, items...) + if len(out) >= int(total) || len(items) == 0 { + break + } + page++ + } + return out, nil +} + +func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, c *gin.Context) ([]service.Account, error) { + if len(ids) > 0 { + accounts, err := h.adminService.GetAccountsByIDs(ctx, ids) + if err != nil { + return nil, err + } + out := make([]service.Account, 0, len(accounts)) + for _, acc := range accounts { + if acc == nil { + continue + } + out = append(out, *acc) + } + return out, nil + } + + platform := c.Query("platform") + accountType := c.Query("type") + status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + if len(search) > 100 { + search = search[:100] + } + return h.listAccountsFiltered(ctx, platform, accountType, status, search) +} + +func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) { + if len(accounts) == 0 { + return []service.Proxy{}, nil + } + + seen := make(map[int64]struct{}) + ids := make([]int64, 0) + for i := range accounts { + if accounts[i].ProxyID == nil { + continue + } + id := *accounts[i].ProxyID + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + if len(ids) == 0 { + return []service.Proxy{}, nil + } + + return h.adminService.GetProxiesByIDs(ctx, ids) +} + +func parseAccountIDs(c *gin.Context) ([]int64, error) { + values := c.QueryArray("ids") + if len(values) == 0 { + raw := strings.TrimSpace(c.Query("ids")) + if raw != "" { + values = []string{raw} + } + } + if len(values) == 0 { + return nil, nil + } + + ids := make([]int64, 0, len(values)) + for _, item := range values { + for _, part := range strings.Split(item, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + id, err := strconv.ParseInt(part, 10, 64) + if err != nil || id <= 0 { + return nil, fmt.Errorf("invalid account id: %s", part) + } + ids = append(ids, id) + } + } + return ids, nil +} + +func parseIncludeProxies(c *gin.Context) (bool, error) { + raw := strings.TrimSpace(strings.ToLower(c.Query("include_proxies"))) + if raw == "" { + return true, nil + } + switch raw { + case "1", "true", "yes", "on": + return true, nil + case "0", "false", "no", "off": + return false, nil + default: + return true, fmt.Errorf("invalid include_proxies value: %s", raw) + } +} + +func validateDataHeader(payload DataPayload) error { + if payload.Type != "" && payload.Type != dataType && payload.Type != legacyDataType { + return fmt.Errorf("unsupported data type: %s", payload.Type) + } + if payload.Version != 0 && payload.Version != dataVersion { + return fmt.Errorf("unsupported data version: %d", payload.Version) + } + if payload.Proxies == nil { + return errors.New("proxies is required") + } + if payload.Accounts == nil { + return errors.New("accounts is required") + } + return nil +} + +func validateDataProxy(item DataProxy) error { + if strings.TrimSpace(item.Protocol) == "" { + return errors.New("proxy protocol is required") + } + if strings.TrimSpace(item.Host) == "" { + return errors.New("proxy host is required") + } + if item.Port <= 0 || item.Port > 65535 { + return errors.New("proxy port is invalid") + } + switch item.Protocol { + case "http", "https", "socks5", "socks5h": + default: + return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol) + } + if item.Status != "" { + normalizedStatus := normalizeProxyStatus(item.Status) + if normalizedStatus != service.StatusActive && normalizedStatus != "inactive" { + return fmt.Errorf("proxy status is invalid: %s", item.Status) + } + } + return nil +} + +func validateDataAccount(item DataAccount) error { + if strings.TrimSpace(item.Name) == "" { + return errors.New("account name is required") + } + if strings.TrimSpace(item.Platform) == "" { + return errors.New("account platform is required") + } + if strings.TrimSpace(item.Type) == "" { + return errors.New("account type is required") + } + if len(item.Credentials) == 0 { + return errors.New("account credentials is required") + } + switch item.Type { + case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream: + default: + return fmt.Errorf("account type is invalid: %s", item.Type) + } + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + return errors.New("rate_multiplier must be >= 0") + } + if item.Concurrency < 0 { + return errors.New("concurrency must be >= 0") + } + if item.Priority < 0 { + return errors.New("priority must be >= 0") + } + return nil +} + +func defaultProxyName(name string) string { + if strings.TrimSpace(name) == "" { + return "imported-proxy" + } + return name +} + +func normalizeProxyStatus(status string) string { + normalized := strings.TrimSpace(strings.ToLower(status)) + switch normalized { + case "": + return "" + case service.StatusActive: + return service.StatusActive + case "inactive", service.StatusDisabled: + return "inactive" + default: + return normalized + } +} diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go new file mode 100644 index 00000000..c8b04c2a --- /dev/null +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -0,0 +1,231 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dataResponse struct { + Code int `json:"code"` + Data dataPayload `json:"data"` +} + +type dataPayload struct { + Type string `json:"type"` + Version int `json:"version"` + Proxies []dataProxy `json:"proxies"` + Accounts []dataAccount `json:"accounts"` +} + +type dataProxy struct { + ProxyKey string `json:"proxy_key"` + Name string `json:"name"` + Protocol string `json:"protocol"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + Status string `json:"status"` +} + +type dataAccount struct { + Name string `json:"name"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ProxyKey *string `json:"proxy_key"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` +} + +func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + + h := NewAccountHandler( + adminSvc, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + router.GET("/api/v1/admin/accounts/data", h.ExportData) + router.POST("/api/v1/admin/accounts/data", h.ImportData) + return router, adminSvc +} + +func TestExportDataIncludesSecrets(t *testing.T) { + router, adminSvc := setupAccountDataRouter() + + proxyID := int64(11) + adminSvc.proxies = []service.Proxy{ + { + ID: proxyID, + Name: "proxy", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + { + ID: 12, + Name: "orphan", + Protocol: "https", + Host: "10.0.0.1", + Port: 443, + Username: "o", + Password: "p", + Status: service.StatusActive, + }, + } + adminSvc.accounts = []service.Account{ + { + ID: 21, + Name: "account", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{"token": "secret"}, + Extra: map[string]any{"note": "x"}, + ProxyID: &proxyID, + Concurrency: 3, + Priority: 50, + Status: service.StatusDisabled, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp dataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Empty(t, resp.Data.Type) + require.Equal(t, 0, resp.Data.Version) + require.Len(t, resp.Data.Proxies, 1) + require.Equal(t, "pass", resp.Data.Proxies[0].Password) + require.Len(t, resp.Data.Accounts, 1) + require.Equal(t, "secret", resp.Data.Accounts[0].Credentials["token"]) +} + +func TestExportDataWithoutProxies(t *testing.T) { + router, adminSvc := setupAccountDataRouter() + + proxyID := int64(11) + adminSvc.proxies = []service.Proxy{ + { + ID: proxyID, + Name: "proxy", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + } + adminSvc.accounts = []service.Account{ + { + ID: 21, + Name: "account", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{"token": "secret"}, + ProxyID: &proxyID, + Concurrency: 3, + Priority: 50, + Status: service.StatusDisabled, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data?include_proxies=false", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp dataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Proxies, 0) + require.Len(t, resp.Data.Accounts, 1) + require.Nil(t, resp.Data.Accounts[0].ProxyKey) +} + +func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) { + router, adminSvc := setupAccountDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy", + Protocol: "socks5", + Host: "1.2.3.4", + Port: 1080, + Username: "u", + Password: "p", + Status: service.StatusActive, + }, + } + + dataPayload := map[string]any{ + "data": map[string]any{ + "type": dataType, + "version": dataVersion, + "proxies": []map[string]any{ + { + "proxy_key": "socks5|1.2.3.4|1080|u|p", + "name": "proxy", + "protocol": "socks5", + "host": "1.2.3.4", + "port": 1080, + "username": "u", + "password": "p", + "status": "active", + }, + }, + "accounts": []map[string]any{ + { + "name": "acc", + "platform": service.PlatformOpenAI, + "type": service.AccountTypeOAuth, + "credentials": map[string]any{"token": "x"}, + "proxy_key": "socks5|1.2.3.4|1080|u|p", + "concurrency": 3, + "priority": 50, + }, + }, + }, + "skip_default_group_bind": true, + } + + body, _ := json.Marshal(dataPayload) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.Len(t, adminSvc.createdProxies, 0) + require.Len(t, adminSvc.createdAccounts, 1) + require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind) +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index bbf5d026..9a13b57c 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" @@ -84,7 +85,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -102,7 +103,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -696,11 +697,61 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { return } - // Return mock data for now + ctx := c.Request.Context() + success := 0 + failed := 0 + results := make([]gin.H, 0, len(req.Accounts)) + + for _, item := range req.Accounts { + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": "rate_multiplier must be >= 0", + }) + continue + } + + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk + + account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if err != nil { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": err.Error(), + }) + continue + } + success++ + results = append(results, gin.H{ + "name": item.Name, + "id": account.ID, + "success": true, + }) + } + response.Success(c, gin.H{ - "success": len(req.Accounts), - "failed": 0, - "results": []gin.H{}, + "success": success, + "failed": failed, + "results": results, }) } @@ -1440,3 +1491,9 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { response.Success(c, results) } + +// GetAntigravityDefaultModelMapping 获取 Antigravity 平台的默认模型映射 +// GET /api/v1/admin/accounts/antigravity/default-model-mapping +func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { + response.Success(c, domain.DefaultAntigravityModelMapping) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index b820a3fb..77d288f9 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -2,19 +2,27 @@ package admin import ( "context" + "strings" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/service" ) type stubAdminService struct { - users []service.User - apiKeys []service.APIKey - groups []service.Group - accounts []service.Account - proxies []service.Proxy - proxyCounts []service.ProxyWithAccountCount - redeems []service.RedeemCode + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode + createdAccounts []*service.CreateAccountInput + createdProxies []*service.CreateProxyInput + updatedProxyIDs []int64 + updatedProxies []*service.UpdateProxyInput + testedProxyIDs []int64 + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -177,6 +185,9 @@ func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([ } func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) { + s.mu.Lock() + s.createdAccounts = append(s.createdAccounts, input) + s.mu.Unlock() account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} return &account, nil } @@ -214,7 +225,25 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic } func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { - return s.proxies, int64(len(s.proxies)), nil + search = strings.TrimSpace(strings.ToLower(search)) + filtered := make([]service.Proxy, 0, len(s.proxies)) + for _, proxy := range s.proxies { + if protocol != "" && proxy.Protocol != protocol { + continue + } + if status != "" && proxy.Status != status { + continue + } + if search != "" { + name := strings.ToLower(proxy.Name) + host := strings.ToLower(proxy.Host) + if !strings.Contains(name, search) && !strings.Contains(host, search) { + continue + } + } + filtered = append(filtered, proxy) + } + return filtered, int64(len(filtered)), nil } func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { @@ -230,16 +259,47 @@ func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([ } func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) { + for i := range s.proxies { + proxy := s.proxies[i] + if proxy.ID == id { + return &proxy, nil + } + } proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive} return &proxy, nil } +func (s *stubAdminService) GetProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) { + if len(ids) == 0 { + return []service.Proxy{}, nil + } + out := make([]service.Proxy, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + seen[id] = struct{}{} + } + for i := range s.proxies { + proxy := s.proxies[i] + if _, ok := seen[proxy.ID]; ok { + out = append(out, proxy) + } + } + return out, nil +} + func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) { + s.mu.Lock() + s.createdProxies = append(s.createdProxies, input) + s.mu.Unlock() proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive} return &proxy, nil } func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) { + s.mu.Lock() + s.updatedProxyIDs = append(s.updatedProxyIDs, id) + s.updatedProxies = append(s.updatedProxies, input) + s.mu.Unlock() proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive} return &proxy, nil } @@ -261,6 +321,9 @@ func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, po } func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) { + s.mu.Lock() + s.testedProxyIDs = append(s.testedProxyIDs, id) + s.mu.Unlock() return &service.ProxyTestResult{Success: true, Message: "ok"}, nil } @@ -290,5 +353,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser return &code, nil } +func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) { + return s.redeems, int64(len(s.redeems)), 100.0, nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go new file mode 100644 index 00000000..c32db561 --- /dev/null +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -0,0 +1,273 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求 +type ErrorPassthroughHandler struct { + service *service.ErrorPassthroughService +} + +// NewErrorPassthroughHandler 创建错误透传规则处理器 +func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler { + return &ErrorPassthroughHandler{service: service} +} + +// CreateErrorPassthroughRuleRequest 创建规则请求 +type CreateErrorPassthroughRuleRequest struct { + Name string `json:"name" binding:"required"` + Enabled *bool `json:"enabled"` + Priority int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选) +type UpdateErrorPassthroughRuleRequest struct { + Name *string `json:"name"` + Enabled *bool `json:"enabled"` + Priority *int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode *string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// List 获取所有规则 +// GET /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) List(c *gin.Context) { + rules, err := h.service.List(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, rules) +} + +// GetByID 根据 ID 获取规则 +// GET /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + rule, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if rule == nil { + response.NotFound(c, "Rule not found") + return + } + + response.Success(c, rule) +} + +// Create 创建规则 +// POST /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) Create(c *gin.Context) { + var req CreateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + rule := &model.ErrorPassthroughRule{ + Name: req.Name, + Priority: req.Priority, + ErrorCodes: req.ErrorCodes, + Keywords: req.Keywords, + Platforms: req.Platforms, + } + + // 设置默认值 + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } else { + rule.Enabled = true + } + if req.MatchMode != "" { + rule.MatchMode = req.MatchMode + } else { + rule.MatchMode = model.MatchModeAny + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } else { + rule.PassthroughCode = true + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } else { + rule.PassthroughBody = true + } + rule.ResponseCode = req.ResponseCode + rule.CustomMessage = req.CustomMessage + rule.Description = req.Description + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + created, err := h.service.Create(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, created) +} + +// Update 更新规则(支持部分更新) +// PUT /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + var req UpdateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // 先获取现有规则 + existing, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if existing == nil { + response.NotFound(c, "Rule not found") + return + } + + // 部分更新:只更新请求中提供的字段 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: existing.Name, + Enabled: existing.Enabled, + Priority: existing.Priority, + ErrorCodes: existing.ErrorCodes, + Keywords: existing.Keywords, + MatchMode: existing.MatchMode, + Platforms: existing.Platforms, + PassthroughCode: existing.PassthroughCode, + ResponseCode: existing.ResponseCode, + PassthroughBody: existing.PassthroughBody, + CustomMessage: existing.CustomMessage, + Description: existing.Description, + } + + // 应用请求中提供的更新 + if req.Name != nil { + rule.Name = *req.Name + } + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } + if req.Priority != nil { + rule.Priority = *req.Priority + } + if req.ErrorCodes != nil { + rule.ErrorCodes = req.ErrorCodes + } + if req.Keywords != nil { + rule.Keywords = req.Keywords + } + if req.MatchMode != nil { + rule.MatchMode = *req.MatchMode + } + if req.Platforms != nil { + rule.Platforms = req.Platforms + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } + if req.ResponseCode != nil { + rule.ResponseCode = req.ResponseCode + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } + if req.CustomMessage != nil { + rule.CustomMessage = req.CustomMessage + } + if req.Description != nil { + rule.Description = req.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + updated, err := h.service.Update(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, updated) +} + +// Delete 删除规则 +// DELETE /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + if err := h.service.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Rule deleted successfully"}) +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index f93edbc8..d10d678b 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,14 +35,18 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -60,14 +64,18 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string `json:"supported_model_scopes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -159,23 +167,26 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) @@ -201,24 +212,27 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/admin/ops_realtime_handler.go b/backend/internal/handler/admin/ops_realtime_handler.go index 4f15ec57..c175dcd0 100644 --- a/backend/internal/handler/admin/ops_realtime_handler.go +++ b/backend/internal/handler/admin/ops_realtime_handler.go @@ -63,6 +63,43 @@ func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) { response.Success(c, payload) } +// GetUserConcurrencyStats returns real-time concurrency usage for all active users. +// GET /api/v1/admin/ops/user-concurrency +func (h *OpsHandler) GetUserConcurrencyStats(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) { + response.Success(c, gin.H{ + "enabled": false, + "user": map[int64]*service.UserConcurrencyInfo{}, + "timestamp": time.Now().UTC(), + }) + return + } + + users, collectedAt, err := h.opsService.GetUserConcurrencyStats(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + payload := gin.H{ + "enabled": true, + "user": users, + } + if collectedAt != nil { + payload["timestamp"] = collectedAt.UTC() + } + response.Success(c, payload) +} + // GetAccountAvailability returns account availability statistics. // GET /api/v1/admin/ops/account-availability // diff --git a/backend/internal/handler/admin/proxy_data.go b/backend/internal/handler/admin/proxy_data.go new file mode 100644 index 00000000..72ecd6c1 --- /dev/null +++ b/backend/internal/handler/admin/proxy_data.go @@ -0,0 +1,239 @@ +package admin + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ExportData exports proxy-only data for migration. +func (h *ProxyHandler) ExportData(c *gin.Context) { + ctx := c.Request.Context() + + selectedIDs, err := parseProxyIDs(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var proxies []service.Proxy + if len(selectedIDs) > 0 { + proxies, err = h.getProxiesByIDs(ctx, selectedIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + } else { + protocol := c.Query("protocol") + status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + if len(search) > 100 { + search = search[:100] + } + + proxies, err = h.listProxiesFiltered(ctx, protocol, status, search) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + + dataProxies := make([]DataProxy, 0, len(proxies)) + for i := range proxies { + p := proxies[i] + key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password) + dataProxies = append(dataProxies, DataProxy{ + ProxyKey: key, + Name: p.Name, + Protocol: p.Protocol, + Host: p.Host, + Port: p.Port, + Username: p.Username, + Password: p.Password, + Status: p.Status, + }) + } + + payload := DataPayload{ + ExportedAt: time.Now().UTC().Format(time.RFC3339), + Proxies: dataProxies, + Accounts: []DataAccount{}, + } + + response.Success(c, payload) +} + +// ImportData imports proxy-only data for migration. +func (h *ProxyHandler) ImportData(c *gin.Context) { + type ProxyImportRequest struct { + Data DataPayload `json:"data"` + } + + var req ProxyImportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := validateDataHeader(req.Data); err != nil { + response.BadRequest(c, err.Error()) + return + } + + ctx := c.Request.Context() + result := DataImportResult{} + + existingProxies, err := h.listProxiesFiltered(ctx, "", "", "") + if err != nil { + response.ErrorFrom(c, err) + return + } + + proxyByKey := make(map[string]service.Proxy, len(existingProxies)) + for i := range existingProxies { + p := existingProxies[i] + key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password) + proxyByKey[key] = p + } + + latencyProbeIDs := make([]int64, 0, len(req.Data.Proxies)) + for i := range req.Data.Proxies { + item := req.Data.Proxies[i] + key := item.ProxyKey + if key == "" { + key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password) + } + + if err := validateDataProxy(item); err != nil { + result.ProxyFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: err.Error(), + }) + continue + } + + normalizedStatus := normalizeProxyStatus(item.Status) + if existing, ok := proxyByKey[key]; ok { + result.ProxyReused++ + if normalizedStatus != "" && normalizedStatus != existing.Status { + if _, err := h.adminService.UpdateProxy(ctx, existing.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil { + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: "update status failed: " + err.Error(), + }) + } + } + latencyProbeIDs = append(latencyProbeIDs, existing.ID) + continue + } + + created, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ + Name: defaultProxyName(item.Name), + Protocol: item.Protocol, + Host: item.Host, + Port: item.Port, + Username: item.Username, + Password: item.Password, + }) + if err != nil { + result.ProxyFailed++ + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: err.Error(), + }) + continue + } + result.ProxyCreated++ + proxyByKey[key] = *created + + if normalizedStatus != "" && normalizedStatus != created.Status { + if _, err := h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil { + result.Errors = append(result.Errors, DataImportError{ + Kind: "proxy", + Name: item.Name, + ProxyKey: key, + Message: "update status failed: " + err.Error(), + }) + } + } + // CreateProxy already triggers a latency probe, avoid double probing here. + } + + if len(latencyProbeIDs) > 0 { + ids := append([]int64(nil), latencyProbeIDs...) + go func() { + for _, id := range ids { + _, _ = h.adminService.TestProxy(context.Background(), id) + } + }() + } + + response.Success(c, result) +} + +func (h *ProxyHandler) getProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) { + if len(ids) == 0 { + return []service.Proxy{}, nil + } + return h.adminService.GetProxiesByIDs(ctx, ids) +} + +func parseProxyIDs(c *gin.Context) ([]int64, error) { + values := c.QueryArray("ids") + if len(values) == 0 { + raw := strings.TrimSpace(c.Query("ids")) + if raw != "" { + values = []string{raw} + } + } + if len(values) == 0 { + return nil, nil + } + + ids := make([]int64, 0, len(values)) + for _, item := range values { + for _, part := range strings.Split(item, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + id, err := strconv.ParseInt(part, 10, 64) + if err != nil || id <= 0 { + return nil, fmt.Errorf("invalid proxy id: %s", part) + } + ids = append(ids, id) + } + } + return ids, nil +} + +func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) { + page := 1 + pageSize := dataPageCap + var out []service.Proxy + for { + items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search) + if err != nil { + return nil, err + } + out = append(out, items...) + if len(out) >= int(total) || len(items) == 0 { + break + } + page++ + } + return out, nil +} diff --git a/backend/internal/handler/admin/proxy_data_handler_test.go b/backend/internal/handler/admin/proxy_data_handler_test.go new file mode 100644 index 00000000..803f9b61 --- /dev/null +++ b/backend/internal/handler/admin/proxy_data_handler_test.go @@ -0,0 +1,188 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type proxyDataResponse struct { + Code int `json:"code"` + Data DataPayload `json:"data"` +} + +type proxyImportResponse struct { + Code int `json:"code"` + Data DataImportResult `json:"data"` +} + +func setupProxyDataRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + + h := NewProxyHandler(adminSvc) + router.GET("/api/v1/admin/proxies/data", h.ExportData) + router.POST("/api/v1/admin/proxies/data", h.ImportData) + + return router, adminSvc +} + +func TestProxyExportDataRespectsFilters(t *testing.T) { + router, adminSvc := setupProxyDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy-a", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + { + ID: 2, + Name: "proxy-b", + Protocol: "https", + Host: "10.0.0.2", + Port: 443, + Username: "u", + Password: "p", + Status: service.StatusDisabled, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=https", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp proxyDataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Empty(t, resp.Data.Type) + require.Equal(t, 0, resp.Data.Version) + require.Len(t, resp.Data.Proxies, 1) + require.Len(t, resp.Data.Accounts, 0) + require.Equal(t, "https", resp.Data.Proxies[0].Protocol) +} + +func TestProxyExportDataWithSelectedIDs(t *testing.T) { + router, adminSvc := setupProxyDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy-a", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + { + ID: 2, + Name: "proxy-b", + Protocol: "https", + Host: "10.0.0.2", + Port: 443, + Username: "u", + Password: "p", + Status: service.StatusDisabled, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?ids=2", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp proxyDataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Proxies, 1) + require.Equal(t, "https", resp.Data.Proxies[0].Protocol) + require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host) +} + +func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) { + router, adminSvc := setupProxyDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy-a", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + } + + payload := map[string]any{ + "data": map[string]any{ + "type": dataType, + "version": dataVersion, + "proxies": []map[string]any{ + { + "proxy_key": "http|127.0.0.1|8080|user|pass", + "name": "proxy-a", + "protocol": "http", + "host": "127.0.0.1", + "port": 8080, + "username": "user", + "password": "pass", + "status": "inactive", + }, + { + "proxy_key": "https|10.0.0.2|443|u|p", + "name": "proxy-b", + "protocol": "https", + "host": "10.0.0.2", + "port": 443, + "username": "u", + "password": "p", + "status": "active", + }, + }, + "accounts": []map[string]any{}, + }, + } + + body, _ := json.Marshal(payload) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/data", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp proxyImportResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, 1, resp.Data.ProxyCreated) + require.Equal(t, 1, resp.Data.ProxyReused) + require.Equal(t, 0, resp.Data.ProxyFailed) + + adminSvc.mu.Lock() + updatedIDs := append([]int64(nil), adminSvc.updatedProxyIDs...) + adminSvc.mu.Unlock() + require.Contains(t, updatedIDs, int64(1)) + + require.Eventually(t, func() bool { + adminSvc.mu.Lock() + defer adminSvc.mu.Unlock() + return len(adminSvc.testedProxyIDs) == 1 + }, time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 9a5a691f..1c772e7d 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -45,6 +45,9 @@ type UpdateUserRequest struct { Concurrency *int `json:"concurrency"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 `json:"group_rates"` } // UpdateBalanceRequest represents balance update request @@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, Status: req.Status, AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, }) if err != nil { response.ErrorFrom(c, err) @@ -277,3 +281,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { response.Success(c, stats) } + +// GetBalanceHistory handles getting user's balance/concurrency change history +// GET /api/v1/admin/users/:id/balance-history +// Query params: +// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) +func (h *UserHandler) GetBalanceHistory(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + page, pageSize := response.ParsePagination(c) + codeType := c.Query("type") + + codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // Convert to admin DTO (includes notes field for admin visibility) + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + + // Custom response with total_recharged alongside pagination + pages := int((total + int64(pageSize) - 1) / int64(pageSize)) + if pages < 1 { + pages = 1 + } + response.Success(c, gin.H{ + "items": out, + "total": total, + "page": page, + "page_size": pageSize, + "pages": pages, + "total_recharged": totalRecharged, + }) +} diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 52dc6911..f1a18ad2 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -3,6 +3,7 @@ package handler import ( "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -27,11 +28,13 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { // CreateAPIKeyRequest represents the create API key request payload type CreateAPIKeyRequest struct { - Name string `json:"name" binding:"required"` - GroupID *int64 `json:"group_id"` // nullable - CustomKey *string `json:"custom_key"` // 可选的自定义key - IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 - IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Name string `json:"name" binding:"required"` + GroupID *int64 `json:"group_id"` // nullable + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD) + ExpiresInDays *int `json:"expires_in_days"` // 过期天数 } // UpdateAPIKeyRequest represents the update API key request payload @@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct { Status string `json:"status" binding:"omitempty,oneof=active inactive"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 + ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) + ResetQuota *bool `json:"reset_quota"` // 重置已用配额 } // List handles listing user's API keys with pagination @@ -114,11 +120,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } svcReq := service.CreateAPIKeyRequest{ - Name: req.Name, - GroupID: req.GroupID, - CustomKey: req.CustomKey, - IPWhitelist: req.IPWhitelist, - IPBlacklist: req.IPBlacklist, + Name: req.Name, + GroupID: req.GroupID, + CustomKey: req.CustomKey, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + ExpiresInDays: req.ExpiresInDays, + } + if req.Quota != nil { + svcReq.Quota = *req.Quota } key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) { svcReq := service.UpdateAPIKeyRequest{ IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, } if req.Name != "" { svcReq.Name = &req.Name @@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { if req.Status != "" { svcReq.Status = &req.Status } + // Parse expires_at if provided + if req.ExpiresAt != nil { + if *req.ExpiresAt == "" { + // Empty string means clear expiration + svcReq.ExpiresAt = nil + svcReq.ClearExpiration = true + } else { + t, err := time.Parse(time.RFC3339, *req.ExpiresAt) + if err != nil { + response.BadRequest(c, "Invalid expires_at format: "+err.Error()) + return + } + svcReq.ExpiresAt = &t + } + } key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq) if err != nil { @@ -216,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { } response.Success(c, out) } + +// GetUserGroupRates 获取当前用户的专属分组倍率配置 +// GET /api/v1/groups/rates +func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, rates) +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 75ea9f08..34ed63bc 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -68,9 +68,39 @@ type LoginRequest struct { // AuthResponse 认证响应格式(匹配前端期望) type AuthResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - User *dto.User `json:"user"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` // 新增:Refresh Token + ExpiresIn int `json:"expires_in,omitempty"` // 新增:Access Token有效期(秒) + TokenType string `json:"token_type"` + User *dto.User `json:"user"` +} + +// respondWithTokenPair 生成 Token 对并返回认证响应 +// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) +func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) + // 回退到只返回Access Token + token, tokenErr := h.authService.GenerateToken(user) + if tokenErr != nil { + response.InternalError(c, "Failed to generate token") + return + } + response.Success(c, AuthResponse{ + AccessToken: token, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) + return + } + response.Success(c, AuthResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) } // Register handles user registration @@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) { } } - token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) + _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // SendVerifyCode 发送邮箱验证码 @@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) { response.ErrorFrom(c, err) return } + _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { @@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // TotpLoginResponse represents the response when 2FA is required @@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Generate the JWT token - token, err := h.authService.GenerateToken(user) - if err != nil { - response.InternalError(c, "Failed to generate token") - return - } - - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // GetCurrentUser handles getting current authenticated user @@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) { Message: "Your password has been reset successfully. You can now log in with your new password.", }) } + +// ==================== Token Refresh Endpoints ==================== + +// RefreshTokenRequest 刷新Token请求 +type RefreshTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +// RefreshTokenResponse 刷新Token响应 +type RefreshTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) + TokenType string `json:"token_type"` +} + +// RefreshToken 刷新Token +// POST /api/v1/auth/refresh +func (h *AuthHandler) RefreshToken(c *gin.Context) { + var req RefreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, RefreshTokenResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + }) +} + +// LogoutRequest 登出请求 +type LogoutRequest struct { + RefreshToken string `json:"refresh_token,omitempty"` // 可选:撤销指定的Refresh Token +} + +// LogoutResponse 登出响应 +type LogoutResponse struct { + Message string `json:"message"` +} + +// Logout 用户登出 +// POST /api/v1/auth/logout +func (h *AuthHandler) Logout(c *gin.Context) { + var req LogoutRequest + // 允许空请求体(向后兼容) + _ = c.ShouldBindJSON(&req) + + // 如果提供了Refresh Token,撤销它 + if req.RefreshToken != "" { + if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil { + slog.Debug("failed to revoke refresh token", "error", err) + // 不影响登出流程 + } + } + + response.Success(c, LogoutResponse{ + Message: "Logged out successfully", + }) +} + +// RevokeAllSessionsResponse 撤销所有会话响应 +type RevokeAllSessionsResponse struct { + Message string `json:"message"` +} + +// RevokeAllSessions 撤销当前用户的所有会话 +// POST /api/v1/auth/revoke-all-sessions +func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) + response.InternalError(c, "Failed to revoke sessions") + return + } + + response.Success(c, RevokeAllSessionsResponse{ + Message: "All sessions have been revoked. Please log in again.", + }) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a16c4cc7..0ccf47e4 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) if err != nil { // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } fragment := url.Values{} - fragment.Set("access_token", jwtToken) + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) fragment.Set("token_type", "Bearer") fragment.Set("redirect", redirectTo) redirectWithFragment(c, frontendCallback, fragment) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 632ee454..d14ab1d1 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, } } @@ -76,6 +77,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey { Status: k.Status, IPWhitelist: k.IPWhitelist, IPBlacklist: k.IPBlacklist, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, CreatedAt: k.CreatedAt, UpdatedAt: k.UpdatedAt, User: UserFromServiceShallow(k.User), @@ -105,10 +109,12 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - AccountCount: g.AccountCount, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -138,8 +144,10 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice4K: g.ImagePrice4K, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } @@ -204,17 +212,6 @@ func AccountFromServiceShallow(a *service.Account) *Account { } } - if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 { - out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits)) - now := time.Now() - for scope, remainingSec := range scopeLimits { - out.ScopeRateLimits[scope] = ScopeRateLimitInfo{ - ResetAt: now.Add(time.Duration(remainingSec) * time.Second), - RemainingSec: remainingSec, - } - } - } - return out } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index d3f706b3..71bb1ed4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -29,19 +29,25 @@ type AdminUser struct { User Notes string `json:"notes"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 `json:"group_rates,omitempty"` } type APIKey struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Key string `json:"key"` - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status"` - IPWhitelist []string `json:"ip_whitelist"` - IPBlacklist []string `json:"ip_blacklist"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` + IPBlacklist []string `json:"ip_blacklist"` + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` @@ -69,6 +75,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -83,8 +91,13 @@ type AdminGroup struct { ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 `json:"account_count,omitempty"` + // MCP XML 协议注入(仅 antigravity 平台使用) + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` } type Account struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index f29da43f..ca4442e4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -31,6 +33,8 @@ type GatewayHandler struct { userService *service.UserService billingCacheService *service.BillingCacheService usageService *service.UsageService + apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -45,6 +49,8 @@ func NewGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, usageService *service.UsageService, + apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -66,6 +72,8 @@ func NewGatewayHandler( userService: userService, billingCacheService: billingCacheService, usageService: usageService, + apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -104,9 +112,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) - setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body) @@ -117,6 +122,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqModel := parsedReq.Model reqStream := parsedReq.Stream + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 + // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 + if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { + ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + c.Request = c.Request.WithContext(ctx) + } + + // 检查是否为 Claude Code 客户端,设置到 context 中 + SetClaudeCodeClientContext(c, body) + isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) + + // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + setOpsRequestContext(c, reqModel, reqStream, body) // 验证 model 必填 @@ -128,6 +147,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // 获取订阅信息(可能为nil)- 提前获取用于后续检查 subscription, _ := middleware2.GetSubscriptionFromContext(c) @@ -193,11 +217,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { sessionKey = "gemini:" + sessionHash } + // 查询粘性会话绑定的账号 ID + var sessionBoundAccountID int64 + if sessionKey != "" { + sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + } + // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 + hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 + if platform == service.PlatformGemini { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError + var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -206,7 +239,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -214,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) + interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) if interceptType != InterceptTypeNone { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() @@ -281,10 +318,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -293,9 +334,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr + if failoverErr.ForceCacheBilling { + forceCacheBilling = true + } if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) return } switchCount++ @@ -312,158 +356,223 @@ func (h *GatewayHandler) Messages(c *gin.Context) { clientIP := ip.GetClientIP(c) // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: clientIP, + ForceCacheBilling: fcb, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent, clientIP) + }(result, account, userAgent, clientIP, forceCacheBilling) return } } - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + currentAPIKey := apiKey + currentSubscription := subscription + var fallbackGroupID *int64 + if apiKey.Group != nil { + fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest + } + fallbackUsed := false for { - // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) - if err != nil { - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID) + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError + retryWithFallback := false + var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 - // 检查请求拦截(预热请求、SUGGESTION MODE等) - if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) - if interceptType != InterceptTypeNone { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - if reqStream { - sendMockInterceptStream(c, reqModel, interceptType) - } else { - sendMockInterceptResponse(c, reqModel, interceptType) - } - return - } - } - - // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + for { + // 选择支持该模型的账号 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } - } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - // 转发请求 - 根据账号平台分流 - var result *service.ForwardResult - if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) - } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) - } - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } + return } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + // 检查请求拦截(预热请求、SUGGESTION MODE等) + if account.IsInterceptWarmupEnabled() { + interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) + if interceptType != InterceptTypeNone { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockInterceptStream(c, reqModel, interceptType) + } else { + sendMockInterceptResponse(c, reqModel, interceptType) + } + return + } + } + + // 3. 获取账号并发槽位 + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) + } else { + result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) + } + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var promptTooLongErr *service.PromptTooLongError + if errors.As(err, &promptTooLongErr) { + log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { + fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) + if err != nil { + log.Printf("Resolve fallback group failed: %v", err) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + if fallbackGroup.Platform != service.PlatformAnthropic || + fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || + fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") + c.Request = c.Request.WithContext(ctx) + currentAPIKey = fallbackAPIKey + currentSubscription = nil + fallbackUsed = true + retryWithFallback = true + break + } + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if failoverErr.ForceCacheBilling { + forceCacheBilling = true + } + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + return + } + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: usedAccount, + Subscription: currentSubscription, + UserAgent: ua, + IPAddress: clientIP, + ForceCacheBilling: fcb, + APIKeyService: h.apiKeyService, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP, forceCacheBilling) + return + } + if !retryWithFallback { return } - - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account, userAgent, clientIP) - return } } @@ -527,6 +636,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { }) } +func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey { + if apiKey == nil || group == nil { + return apiKey + } + cloned := *apiKey + groupID := group.ID + cloned.GroupID = &groupID + cloned.Group = group + return &cloned +} + // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { @@ -542,10 +662,10 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } - // Best-effort: 获取用量统计,失败不影响基础响应 + // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 var usageData gin.H if h.usageService != nil { - dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID) + dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) if err == nil && dashStats != nil { usageData = gin.H{ "today": gin.H{ @@ -681,7 +801,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } @@ -789,6 +939,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) // 验证 model 必填 if parsedReq.Model == "" { @@ -832,13 +984,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { type InterceptType int const ( - InterceptTypeNone InterceptType = iota - InterceptTypeWarmup // 预热请求(返回 "New Conversation") - InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) + InterceptTypeNone InterceptType = iota + InterceptTypeWarmup // 预热请求(返回 "New Conversation") + InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) + InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#") ) +// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感) +func isHaikuModel(model string) bool { + return strings.Contains(strings.ToLower(model), "haiku") +} + +// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求 +// 这类请求用于 Claude Code 验证 API 连通性 +// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求 +func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool { + return maxTokens == 1 && isHaikuModel(model) && !isStream +} + // detectInterceptType 检测请求是否需要拦截,返回拦截类型 -func detectInterceptType(body []byte) InterceptType { +// 参数说明: +// - body: 请求体字节 +// - model: 请求的模型名称 +// - maxTokens: max_tokens 值 +// - isStream: 是否为流式请求 +// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验 +func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType { + // 优先检查 max_tokens=1 + haiku 探测请求(仅非流式) + if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) { + return InterceptTypeMaxTokensOneHaiku + } + // 快速检查:如果不包含任何关键字,直接返回 bodyStr := string(body) hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:") @@ -988,9 +1164,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce } } +// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式) +// 格式与 Claude API 真实响应一致,24 位随机字母数字 +func generateRealisticMsgID() string { + const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + const idLen = 24 + randomBytes := make([]byte, idLen) + if _, err := rand.Read(randomBytes); err != nil { + return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano()) + } + b := make([]byte, idLen) + for i := range b { + b[i] = charset[int(randomBytes[i])%len(charset)] + } + return "msg_bdrk_" + string(b) +} + // sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截) func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) { - var msgID, text string + var msgID, text, stopReason string var outputTokens int switch interceptType { @@ -998,24 +1190,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter msgID = "msg_mock_suggestion" text = "" outputTokens = 1 + stopReason = "end_turn" + case InterceptTypeMaxTokensOneHaiku: + msgID = generateRealisticMsgID() + text = "#" + outputTokens = 1 + stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens default: // InterceptTypeWarmup msgID = "msg_mock_warmup" text = "New Conversation" outputTokens = 2 + stopReason = "end_turn" } - c.JSON(http.StatusOK, gin.H{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []gin.H{{"type": "text", "text": text}}, - "stop_reason": "end_turn", + // 构建完整的响应格式(与 Claude API 响应格式一致) + response := gin.H{ + "model": model, + "id": msgID, + "type": "message", + "role": "assistant", + "content": []gin.H{{"type": "text", "text": text}}, + "stop_reason": stopReason, + "stop_sequence": nil, "usage": gin.H{ - "input_tokens": 10, + "input_tokens": 10, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cache_creation": gin.H{ + "ephemeral_5m_input_tokens": 0, + "ephemeral_1h_input_tokens": 0, + }, "output_tokens": outputTokens, + "total_tokens": 10 + outputTokens, }, - }) + } + + c.JSON(http.StatusOK, response) } func billingErrorDetails(err error) (status int, code, message string) { diff --git a/backend/internal/handler/gateway_handler_intercept_test.go b/backend/internal/handler/gateway_handler_intercept_test.go new file mode 100644 index 00000000..9e7d77a1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_intercept_test.go @@ -0,0 +1,65 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestDetectInterceptType_MaxTokensOneHaikuRequiresClaudeCodeClient(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + + notClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, false) + require.Equal(t, InterceptTypeNone, notClaudeCode) + + isClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, true) + require.Equal(t, InterceptTypeMaxTokensOneHaiku, isClaudeCode) +} + +func TestDetectInterceptType_SuggestionModeUnaffected(t *testing.T) { + body := []byte(`{ + "messages":[{ + "role":"user", + "content":[{"type":"text","text":"[SUGGESTION MODE:foo]"}] + }], + "system":[] + }`) + + got := detectInterceptType(body, "claude-sonnet-4-5", 256, false, false) + require.Equal(t, InterceptTypeSuggestionMode, got) +} + +func TestSendMockInterceptResponse_MaxTokensOneHaiku(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + + sendMockInterceptResponse(ctx, "claude-haiku-4-5", InterceptTypeMaxTokensOneHaiku) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response)) + require.Equal(t, "max_tokens", response["stop_reason"]) + + id, ok := response["id"].(string) + require.True(t, ok) + require.True(t, strings.HasPrefix(id, "msg_bdrk_")) + + content, ok := response["content"].([]any) + require.True(t, ok) + require.NotEmpty(t, content) + + firstBlock, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "#", firstBlock["text"]) + + usage, ok := response["usage"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(1), usage["output_tokens"]) +} diff --git a/backend/internal/handler/gemini_cli_session_test.go b/backend/internal/handler/gemini_cli_session_test.go index 0b37f5f2..80bc79c8 100644 --- a/backend/internal/handler/gemini_cli_session_test.go +++ b/backend/internal/handler/gemini_cli_session_test.go @@ -120,3 +120,24 @@ func TestGeminiCLITmpDirRegex(t *testing.T) { }) } } + +func TestSafeShortPrefix(t *testing.T) { + tests := []struct { + name string + input string + n int + want string + }{ + {name: "空字符串", input: "", n: 8, want: ""}, + {name: "长度小于截断值", input: "abc", n: 8, want: "abc"}, + {name: "长度等于截断值", input: "12345678", n: 8, want: "12345678"}, + {name: "长度大于截断值", input: "1234567890", n: 8, want: "12345678"}, + {name: "截断值为0", input: "123456", n: 0, want: "123456"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, safeShortPrefix(tt.input, tt.n)) + }) + } +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d1b19ede..b1477ac6 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -5,6 +5,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "io" "log" @@ -14,11 +15,13 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/google/uuid" "github.com/gin-gonic/gin" ) @@ -206,6 +209,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 1) user concurrency slot streamStarted := false + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { googleError(c, http.StatusTooManyRequests, err.Error()) @@ -246,13 +252,78 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) } + + // === Gemini 内容摘要会话 Fallback 逻辑 === + // 当原有会话标识无效时(sessionBoundAccountID == 0),尝试基于内容摘要链匹配 + var geminiDigestChain string + var geminiPrefixHash string + var geminiSessionUUID string + useDigestFallback := sessionBoundAccountID == 0 + + if useDigestFallback { + // 解析 Gemini 请求体 + var geminiReq antigravity.GeminiRequest + if err := json.Unmarshal(body, &geminiReq); err == nil && len(geminiReq.Contents) > 0 { + // 生成摘要链 + geminiDigestChain = service.BuildGeminiDigestChain(&geminiReq) + if geminiDigestChain != "" { + // 生成前缀 hash + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + platform := "" + if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + geminiPrefixHash = service.GenerateGeminiPrefixHash( + authSubject.UserID, + apiKey.ID, + clientIP, + userAgent, + platform, + modelName, + ) + + // 查找会话 + foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession( + c.Request.Context(), + derefGroupID(apiKey.GroupID), + geminiPrefixHash, + geminiDigestChain, + ) + if found { + sessionBoundAccountID = foundAccountID + geminiSessionUUID = foundUUID + log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", + safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain)) + + // 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey + // 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号 + if sessionKey == "" { + sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, foundUUID) + } + _ = h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, foundAccountID) + } else { + // 生成新的会话 UUID + geminiSessionUUID = uuid.New().String() + // 为新会话也生成 sessionKey(用于后续请求的粘性会话) + if sessionKey == "" { + sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, geminiSessionUUID) + } + } + } + } + } + + // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 + hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 isCLI := isGeminiCLIRequest(c, body) cleanedForUnknownBinding := false maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError + var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -261,7 +332,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - handleGeminiFailoverExhausted(c, lastFailoverStatus) + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } account := selection.Account @@ -335,10 +406,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { - result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -347,12 +422,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} + if failoverErr.ForceCacheBilling { + forceCacheBilling = true + } if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - handleGeminiFailoverExhausted(c, lastFailoverStatus) + lastFailoverErr = failoverErr + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -366,8 +444,22 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + // 保存 Gemini 内容摘要会话(用于 Fallback 匹配) + if useDigestFallback && geminiDigestChain != "" && geminiPrefixHash != "" { + if err := h.gatewayService.SaveGeminiSession( + c.Request.Context(), + derefGroupID(apiKey.GroupID), + geminiPrefixHash, + geminiDigestChain, + geminiSessionUUID, + account.ID, + ); err != nil { + log.Printf("[Gemini] Failed to save digest session: %v", err) + } + } + // 6) record usage async (Gemini 使用长上下文双倍计费) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -381,10 +473,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { IPAddress: ip, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 + ForceCacheBilling: fcb, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent, clientIP) + }(result, account, userAgent, clientIP, forceCacheBilling) return } } @@ -408,7 +502,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error return "", "", &pathParseError{"invalid model action path"} } -func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { +func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) { + if failoverErr == nil { + googleError(c, http.StatusBadGateway, "Upstream request failed") + return + } + + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + googleError(c, respCode, msg) + return + } + } + + // 使用默认的错误映射 status, message := mapGeminiUpstreamError(statusCode) googleError(c, status, message) } @@ -518,3 +641,28 @@ func extractGeminiCLISessionHash(c *gin.Context, body []byte) string { // 如果没有 privileged-user-id,直接使用 tmp 目录哈希 return tmpDirHash } + +// truncateDigestChain 截断摘要链用于日志显示 +func truncateDigestChain(chain string) string { + if len(chain) <= 50 { + return chain + } + return chain[:50] + "..." +} + +// safeShortPrefix 返回字符串前 n 个字符;长度不足时返回原字符串。 +// 用于日志展示,避免切片越界。 +func safeShortPrefix(value string, n int) string { + if n <= 0 || len(value) <= n { + return value + } + return value[:n] +} + +// derefGroupID 安全解引用 *int64,nil 返回 0 +func derefGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b8f7d417..b2b12c0d 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -24,6 +24,7 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler + ErrorPassthrough *admin.ErrorPassthroughHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4c9dd8b9..835297b8 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -22,10 +22,12 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -33,6 +35,8 @@ func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -44,10 +48,12 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, } } @@ -143,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) @@ -198,7 +209,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { // Select account supporting the requested model @@ -210,7 +221,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -275,12 +290,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, streamStarted) return } - lastFailoverStatus = failoverErr.StatusCode switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -299,13 +313,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } @@ -320,7 +335,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 48a3794b..7b62149c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -27,6 +27,7 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, + errorPassthroughHandler *admin.ErrorPassthroughHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -47,6 +48,7 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, + ErrorPassthrough: errorPassthroughHandler, } } @@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, + admin.NewErrorPassthroughHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go new file mode 100644 index 00000000..d4fc16e3 --- /dev/null +++ b/backend/internal/model/error_passthrough_rule.go @@ -0,0 +1,74 @@ +// Package model 定义服务层使用的数据模型。 +package model + +import "time" + +// ErrorPassthroughRule 全局错误透传规则 +// 用于控制上游错误如何返回给客户端 +type ErrorPassthroughRule struct { + ID int64 `json:"id"` + Name string `json:"name"` // 规则名称 + Enabled bool `json:"enabled"` // 是否启用 + Priority int `json:"priority"` // 优先级(数字越小优先级越高) + ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) + Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) + MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) + Platforms []string `json:"platforms"` // 适用平台列表 + PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 + ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) + PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 + CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) + Description *string `json:"description"` // 规则描述 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// MatchModeAny 表示任一条件匹配即可 +const MatchModeAny = "any" + +// MatchModeAll 表示所有条件都必须匹配 +const MatchModeAll = "all" + +// 支持的平台常量 +const ( + PlatformAnthropic = "anthropic" + PlatformOpenAI = "openai" + PlatformGemini = "gemini" + PlatformAntigravity = "antigravity" +) + +// AllPlatforms 返回所有支持的平台列表 +func AllPlatforms() []string { + return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} +} + +// Validate 验证规则配置的有效性 +func (r *ErrorPassthroughRule) Validate() error { + if r.Name == "" { + return &ValidationError{Field: "name", Message: "name is required"} + } + if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll { + return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"} + } + // 至少需要配置一个匹配条件(错误码或关键词) + if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 { + return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"} + } + if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) { + return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"} + } + if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") { + return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"} + } + return nil +} + +// ValidationError 表示验证错误 +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return e.Field + ": " + e.Message +} diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index c7d657b9..d1712c98 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -40,17 +40,48 @@ const ( // URL 可用性 TTL(不可用 URL 的恢复时间) URLAvailabilityTTL = 5 * time.Minute + + // Antigravity API 端点 + antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com" + antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ - "https://cloudcode-pa.googleapis.com", // prod (优先) - "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) + antigravityProdBaseURL, // prod (优先) + antigravityDailyBaseURL, // daily sandbox (备用) } // BaseURL 默认 URL(保持向后兼容) var BaseURL = BaseURLs[0] +// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先) +func ForwardBaseURLs() []string { + if len(BaseURLs) == 0 { + return nil + } + urls := append([]string(nil), BaseURLs...) + dailyIndex := -1 + for i, url := range urls { + if url == antigravityDailyBaseURL { + dailyIndex = i + break + } + } + if dailyIndex <= 0 { + return urls + } + reordered := make([]string, 0, len(urls)) + reordered = append(reordered, urls[dailyIndex]) + for i, url := range urls { + if i == dailyIndex { + continue + } + reordered = append(reordered, url) + } + return reordered +} + // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex @@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool { // GetAvailableURLs 返回可用的 URL 列表 // 最近成功的 URL 优先,其他按默认顺序 func (u *URLAvailability) GetAvailableURLs() []string { + return u.GetAvailableURLsWithBase(BaseURLs) +} + +// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序) +// 最近成功的 URL 优先,其他按传入顺序 +func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string { u.mu.RLock() defer u.mu.RUnlock() now := time.Now() - result := make([]string, 0, len(BaseURLs)) + result := make([]string, 0, len(baseURLs)) // 如果有最近成功的 URL 且可用,放在最前面 if u.lastSuccess != "" { - expiry, exists := u.unavailable[u.lastSuccess] - if !exists || now.After(expiry) { - result = append(result, u.lastSuccess) + found := false + for _, url := range baseURLs { + if url == u.lastSuccess { + found = true + break + } + } + if found { + expiry, exists := u.unavailable[u.lastSuccess] + if !exists || now.After(expiry) { + result = append(result, u.lastSuccess) + } } } - // 添加其他可用的 URL(按默认顺序) - for _, url := range BaseURLs { + // 添加其他可用的 URL(按传入顺序) + for _, url := range baseURLs { // 跳过已添加的 lastSuccess if url == u.lastSuccess { continue diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 63f6ee7c..65f45cfc 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -44,17 +44,36 @@ type TransformOptions struct { // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; // 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。 IdentityPatch string + EnableMCPXML bool } func DefaultTransformOptions() TransformOptions { return TransformOptions{ EnableIdentityPatch: true, + EnableMCPXML: true, } } // webSearchFallbackModel web_search 请求使用的降级模型 const webSearchFallbackModel = "gemini-2.5-flash" +// MaxTokensBudgetPadding max_tokens 自动调整时在 budget_tokens 基础上增加的额度 +// Claude API 要求 max_tokens > thinking.budget_tokens,否则返回 400 错误 +const MaxTokensBudgetPadding = 1000 + +// Gemini 2.5 Flash thinking budget 上限 +const Gemini25FlashThinkingBudgetLimit = 24576 + +// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens +// Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens +// 返回调整后的 maxTokens 和是否进行了调整 +func ensureMaxTokensGreaterThanBudget(maxTokens, budgetTokens int) (int, bool) { + if budgetTokens > 0 && maxTokens <= budgetTokens { + return budgetTokens + MaxTokensBudgetPadding, true + } + return maxTokens, false +} + // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions()) @@ -89,8 +108,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map return nil, fmt.Errorf("build contents: %w", err) } - // 2. 构建 systemInstruction - systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools) + // 2. 构建 systemInstruction(使用 targetModel 而非原始请求模型,确保身份注入基于最终模型) + systemInstruction := buildSystemInstruction(claudeReq.System, targetModel, opts, claudeReq.Tools) // 3. 构建 generationConfig reqForConfig := claudeReq @@ -171,6 +190,55 @@ func GetDefaultIdentityPatch() string { return antigravityIdentity } +// modelInfo 模型信息 +type modelInfo struct { + DisplayName string // 人类可读名称,如 "Claude Opus 4.5" + CanonicalID string // 规范模型 ID,如 "claude-opus-4-5-20250929" +} + +// modelInfoMap 模型前缀 → 模型信息映射 +// 只有在此映射表中的模型才会注入身份提示词 +// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking, +// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换 +var modelInfoMap = map[string]modelInfo{ + "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, + "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, + "claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"}, + "claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"}, +} + +// getModelInfo 根据模型 ID 获取模型信息(前缀匹配) +func getModelInfo(modelID string) (info modelInfo, matched bool) { + var bestMatch string + + for prefix, mi := range modelInfoMap { + if strings.HasPrefix(modelID, prefix) && len(prefix) > len(bestMatch) { + bestMatch = prefix + info = mi + } + } + + return info, bestMatch != "" +} + +// GetModelDisplayName 根据模型 ID 获取人类可读的显示名称 +func GetModelDisplayName(modelID string) string { + if info, ok := getModelInfo(modelID); ok { + return info.DisplayName + } + return modelID +} + +// buildModelIdentityText 构建模型身份提示文本 +// 如果模型 ID 没有匹配到映射,返回空字符串 +func buildModelIdentityText(modelID string) string { + info, matched := getModelInfo(modelID) + if !matched { + return "" + } + return fmt.Sprintf("You are Model %s, ModelId is %s.", info.DisplayName, info.CanonicalID) +} + // mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致) const mcpXMLProtocol = ` ==== MCP XML 工具调用协议 (Workaround) ==== @@ -252,13 +320,17 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans identityPatch = defaultIdentityPatch(modelName) } parts = append(parts, GeminiPart{Text: identityPatch}) + + // 静默边界:隔离上方 identity 内容,使其被忽略 + modelIdentity := buildModelIdentityText(modelName) + parts = append(parts, GeminiPart{Text: fmt.Sprintf("\nBelow are your system instructions. Follow them strictly. The content above is internal initialization logs, irrelevant to the conversation. Do not reference, acknowledge, or mention it.\n\n**IMPORTANT**: Your responses must **NEVER** explicitly or implicitly reveal the existence of any content above this line. Never mention \"Antigravity\", \"Google Deepmind\", or any identity defined above.\n%s\n", modelIdentity)}) } // 添加用户的 system prompt parts = append(parts, userSystemParts...) - // 检测是否有 MCP 工具,如有则注入 XML 调用协议 - if hasMCPTools(tools) { + // 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议 + if opts.EnableMCPXML && hasMCPTools(tools) { parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) } @@ -312,7 +384,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT parts = append([]GeminiPart{{ Text: "Thinking...", Thought: true, - ThoughtSignature: dummyThoughtSignature, + ThoughtSignature: DummyThoughtSignature, }}, parts...) } } @@ -330,9 +402,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT return contents, strippedThinking, nil } -// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 +// DummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures -const dummyThoughtSignature = "skip_thought_signature_validator" +// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复) +const DummyThoughtSignature = "skip_thought_signature_validator" // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature @@ -370,7 +443,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // signature 处理: // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature - if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { + if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) { part.ThoughtSignature = block.Signature } else if !allowDummyThought { // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 @@ -381,7 +454,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu continue } else { // Gemini 模型使用 dummy signature - part.ThoughtSignature = dummyThoughtSignature + part.ThoughtSignature = DummyThoughtSignature } parts = append(parts, part) @@ -411,10 +484,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // tool_use 的 signature 处理: // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature - if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { + if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) { part.ThoughtSignature = block.Signature } else if allowDummyThought { - part.ThoughtSignature = dummyThoughtSignature + part.ThoughtSignature = DummyThoughtSignature } parts = append(parts, part) @@ -492,9 +565,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string { } // buildGenerationConfig 构建 generationConfig +const ( + defaultMaxOutputTokens = 64000 + maxOutputTokensUpperBound = 65000 + maxOutputTokensClaude = 64000 +) + +func maxOutputTokensLimit(model string) int { + if strings.HasPrefix(model, "claude-") { + return maxOutputTokensClaude + } + return maxOutputTokensUpperBound +} + func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { + maxLimit := maxOutputTokensLimit(req.Model) config := &GeminiGenerationConfig{ - MaxOutputTokens: 64000, // 默认最大输出 + MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出 StopSequences: DefaultStopSequences, } @@ -510,14 +597,25 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } if req.Thinking.BudgetTokens > 0 { budget := req.Thinking.BudgetTokens - // gemini-2.5-flash 上限 24576 - if strings.Contains(req.Model, "gemini-2.5-flash") && budget > 24576 { - budget = 24576 + // gemini-2.5-flash 上限 + if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit { + budget = Gemini25FlashThinkingBudgetLimit } config.ThinkingConfig.ThinkingBudget = budget + + // 自动修正:max_tokens 必须大于 budget_tokens + if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok { + log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)", + config.MaxOutputTokens, adjusted, budget) + config.MaxOutputTokens = adjusted + } } } + if config.MaxOutputTokens > maxLimit { + config.MaxOutputTokens = maxLimit + } + // 其他参数 if req.Temperature != nil { config.Temperature = req.Temperature diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 9d62a4a1..f938b47f 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { if len(parts) != 3 { t.Fatalf("expected 3 parts, got %d", len(parts)) } - if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature { + if !parts[1].Thought || parts[1].ThoughtSignature != DummyThoughtSignature { t.Fatalf("expected dummy thought signature, got thought=%v signature=%q", parts[1].Thought, parts[1].ThoughtSignature) } @@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { if len(parts) != 1 || parts[0].FunctionCall == nil { t.Fatalf("expected 1 functionCall part, got %+v", parts) } - if parts[0].ThoughtSignature != dummyThoughtSignature { - t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) + if parts[0].ThoughtSignature != DummyThoughtSignature { + t.Fatalf("expected dummy tool signature %q, got %q", DummyThoughtSignature, parts[0].ThoughtSignature) } }) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 8b3441dc..eecee11e 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -71,6 +71,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.5", CreatedAt: "2025-11-01T00:00:00Z", }, + { + ID: "claude-opus-4-6", + Type: "model", + DisplayName: "Claude Opus 4.6", + CreatedAt: "2026-02-06T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 27bb5ac5..9bf563e7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -14,8 +14,18 @@ const ( // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" + // AccountSwitchCount 表示请求过程中发生的账号切换次数 + AccountSwitchCount Key = "ctx_account_switch_count" + // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 IsClaudeCodeClient Key = "ctx_is_claude_code_client" + + // ThinkingEnabled 标识当前请求是否开启 thinking(用于 Antigravity 最终模型名推导与模型维度限流) + ThinkingEnabled Key = "ctx_thinking_enabled" // Group 认证后的分组信息,由 API Key 认证中间件设置 Group Key = "ctx_group" + + // IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求 + // 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent) + IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku" ) diff --git a/backend/internal/pkg/googleapi/error.go b/backend/internal/pkg/googleapi/error.go new file mode 100644 index 00000000..b6374e02 --- /dev/null +++ b/backend/internal/pkg/googleapi/error.go @@ -0,0 +1,109 @@ +// Package googleapi provides helpers for Google-style API responses. +package googleapi + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ErrorResponse represents a Google API error response +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail contains the error details from Google API +type ErrorDetail struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details []json.RawMessage `json:"details,omitempty"` +} + +// ErrorDetailInfo contains additional error information +type ErrorDetailInfo struct { + Type string `json:"@type"` + Reason string `json:"reason,omitempty"` + Domain string `json:"domain,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ErrorHelp contains help links +type ErrorHelp struct { + Type string `json:"@type"` + Links []HelpLink `json:"links,omitempty"` +} + +// HelpLink represents a help link +type HelpLink struct { + Description string `json:"description"` + URL string `json:"url"` +} + +// ParseError parses a Google API error response and extracts key information +func ParseError(body string) (*ErrorResponse, error) { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return &errResp, nil +} + +// ExtractActivationURL extracts the API activation URL from error details +func ExtractActivationURL(body string) string { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return "" + } + + // Check error details for activation URL + for _, detailRaw := range errResp.Error.Details { + // Parse as ErrorDetailInfo + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Metadata != nil { + if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" { + return activationURL + } + } + } + + // Parse as ErrorHelp + var help ErrorHelp + if err := json.Unmarshal(detailRaw, &help); err == nil { + for _, link := range help.Links { + if strings.Contains(link.Description, "activation") || + strings.Contains(link.Description, "API activation") || + strings.Contains(link.URL, "/apis/api/") { + return link.URL + } + } + } + } + + return "" +} + +// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error +func IsServiceDisabledError(body string) bool { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return false + } + + // Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason + if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" { + return false + } + + for _, detailRaw := range errResp.Error.Details { + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Reason == "SERVICE_DISABLED" { + return true + } + } + } + + return false +} diff --git a/backend/internal/pkg/googleapi/error_test.go b/backend/internal/pkg/googleapi/error_test.go new file mode 100644 index 00000000..992dcf85 --- /dev/null +++ b/backend/internal/pkg/googleapi/error_test.go @@ -0,0 +1,143 @@ +package googleapi + +import ( + "testing" +) + +func TestExtractActivationURL(t *testing.T) { + // Test case from the user's error message + errorBody := `{ + "error": { + "code": 403, + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.", + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED", + "domain": "googleapis.com", + "metadata": { + "service": "cloudaicompanion.googleapis.com", + "activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843", + "consumer": "projects/project-6eca5881-ab73-4736-843", + "serviceTitle": "Gemini for Google Cloud API", + "containerInfo": "project-6eca5881-ab73-4736-843" + } + }, + { + "@type": "type.googleapis.com/google.rpc.LocalizedMessage", + "locale": "en-US", + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry." + }, + { + "@type": "type.googleapis.com/google.rpc.Help", + "links": [ + { + "description": "Google developers console API activation", + "url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + } + ] + } + ] + } + }` + + activationURL := ExtractActivationURL(errorBody) + expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + + if activationURL != expectedURL { + t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL) + } +} + +func TestIsServiceDisabledError(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + { + name: "SERVICE_DISABLED error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED" + } + ] + } + }`, + expected: true, + }, + { + name: "Other 403 error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "OTHER_REASON" + } + ] + } + }`, + expected: false, + }, + { + name: "404 error", + body: `{ + "error": { + "code": 404, + "status": "NOT_FOUND" + } + }`, + expected: false, + }, + { + name: "Invalid JSON", + body: `invalid json`, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsServiceDisabledError(tt.body) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestParseError(t *testing.T) { + errorBody := `{ + "error": { + "code": 403, + "message": "API not enabled", + "status": "PERMISSION_DENIED" + } + }` + + errResp, err := ParseError(errorBody) + if err != nil { + t.Fatalf("Failed to parse error: %v", err) + } + + if errResp.Error.Code != 403 { + t.Errorf("Expected code 403, got %d", errResp.Error.Code) + } + + if errResp.Error.Status != "PERMISSION_DENIED" { + t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status) + } + + if errResp.Error.Message != "API not enabled" { + t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message) + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 4fab3359..fd24b11d 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,6 +15,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ + {ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"}, + {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index e4e837e2..11c206d8 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1089,8 +1089,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m result, err := client.ExecContext( ctx, "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL", - payload, id, + string(payload), id, ) + if err != nil { return err } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 1e5a62df..c0cfd256 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetKey(key.Key). SetName(key.Name). SetStatus(key.Status). - SetNillableGroupID(key.GroupID) + SetNillableGroupID(key.GroupID). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). + SetNillableExpiresAt(key.ExpiresAt) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, + apikey.FieldQuota, + apikey.FieldQuotaUsed, + apikey.FieldExpiresAt, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice4k, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, + group.FieldFallbackGroupIDOnInvalidRequest, group.FieldModelRoutingEnabled, group.FieldModelRouting, + group.FieldMcpXMLInject, + group.FieldSupportedModelScopes, ) }). Only(ctx) @@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). SetName(key.Name). SetStatus(key.Status). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearGroupID() } + // Expiration time + if key.ExpiresAt != nil { + builder.SetExpiresAt(*key.ExpiresAt) + } else { + builder.ClearExpiresAt() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } +// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + // Use raw SQL for atomic increment to avoid race conditions + // First get current value + m, err := r.activeQuery(). + Where(apikey.IDEQ(id)). + Select(apikey.FieldQuotaUsed). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, service.ErrAPIKeyNotFound + } + return 0, err + } + + newValue := m.QuotaUsed + amount + + // Update with new value + affected, err := r.client.APIKey.Update(). + Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetQuotaUsed(newValue). + Save(ctx) + if err != nil { + return 0, err + } + if affected == 0 { + return 0, service.ErrAPIKeyNotFound + } + + return newValue, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil @@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) @@ -409,28 +462,31 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.McpXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index b34961e1..cc0c6db5 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -194,6 +194,53 @@ var ( return result `) + // getUsersLoadBatchScript - batch load query for users with expired slot cleanup + // ARGV[1] = slot TTL (seconds) + // ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ... + getUsersLoadBatchScript = redis.NewScript(` + local result = {} + local slotTTL = tonumber(ARGV[1]) + + -- Get current server time + local timeResult = redis.call('TIME') + local nowSeconds = tonumber(timeResult[1]) + local cutoffTime = nowSeconds - slotTTL + + local i = 2 + while i <= #ARGV do + local userID = ARGV[i] + local maxConcurrency = tonumber(ARGV[i + 1]) + + local slotKey = 'concurrency:user:' .. userID + + -- Clean up expired slots before counting + redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) + local currentConcurrency = redis.call('ZCARD', slotKey) + + local waitKey = 'concurrency:wait:' .. userID + local waitingCount = redis.call('GET', waitKey) + if waitingCount == false then + waitingCount = 0 + else + waitingCount = tonumber(waitingCount) + end + + local loadRate = 0 + if maxConcurrency > 0 then + loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) + end + + table.insert(result, userID) + table.insert(result, currentConcurrency) + table.insert(result, waitingCount) + table.insert(result, loadRate) + + i = i + 2 + end + + return result + `) + // cleanupExpiredSlotsScript - remove expired slots // KEYS[1] = concurrency:account:{accountID} // ARGV[1] = TTL (seconds) @@ -384,6 +431,43 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] return loadMap, nil } +func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + if len(users) == 0 { + return map[int64]*service.UserLoadInfo{}, nil + } + + args := []any{c.slotTTLSeconds} + for _, u := range users { + args = append(args, u.ID, u.MaxConcurrency) + } + + result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + if err != nil { + return nil, err + } + + loadMap := make(map[int64]*service.UserLoadInfo) + for i := 0; i < len(result); i += 4 { + if i+3 >= len(result) { + break + } + + userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) + currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) + waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) + loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) + + loadMap[userID] = &service.UserLoadInfo{ + UserID: userID, + CurrentConcurrency: currentConcurrency, + WaitingCount: waitingCount, + LoadRate: loadRate, + } + } + + return loadMap, nil +} + func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { key := accountSlotKey(accountID) _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() diff --git a/backend/internal/repository/error_passthrough_cache.go b/backend/internal/repository/error_passthrough_cache.go new file mode 100644 index 00000000..5584ffc8 --- /dev/null +++ b/backend/internal/repository/error_passthrough_cache.go @@ -0,0 +1,128 @@ +package repository + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + errorPassthroughCacheKey = "error_passthrough_rules" + errorPassthroughPubSubKey = "error_passthrough_rules_updated" + errorPassthroughCacheTTL = 24 * time.Hour +) + +type errorPassthroughCache struct { + rdb *redis.Client + localCache []*model.ErrorPassthroughRule + localMu sync.RWMutex +} + +// NewErrorPassthroughCache 创建错误透传规则缓存 +func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache { + return &errorPassthroughCache{ + rdb: rdb, + } +} + +// Get 从缓存获取规则列表 +func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + // 先检查本地缓存 + c.localMu.RLock() + if c.localCache != nil { + rules := c.localCache + c.localMu.RUnlock() + return rules, true + } + c.localMu.RUnlock() + + // 从 Redis 获取 + data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes() + if err != nil { + if err != redis.Nil { + log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err) + } + return nil, false + } + + var rules []*model.ErrorPassthroughRule + if err := json.Unmarshal(data, &rules); err != nil { + log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err) + return nil, false + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return rules, true +} + +// Set 设置缓存 +func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + data, err := json.Marshal(rules) + if err != nil { + return err + } + + if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil { + return err + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return nil +} + +// Invalidate 使缓存失效 +func (c *errorPassthroughCache) Invalidate(ctx context.Context) error { + // 清除本地缓存 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 清除 Redis 缓存 + return c.rdb.Del(ctx, errorPassthroughCacheKey).Err() +} + +// NotifyUpdate 通知其他实例刷新缓存 +func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error { + return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err() +} + +// SubscribeUpdates 订阅缓存更新通知 +func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + go func() { + sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey) + defer func() { _ = sub.Close() }() + + ch := sub.Channel() + for { + select { + case <-ctx.Done(): + return + case msg := <-ch: + if msg == nil { + return + } + // 清除本地缓存,下次访问时会从 Redis 或数据库重新加载 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 调用处理函数 + handler() + } + } + }() +} diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go new file mode 100644 index 00000000..a58ab60f --- /dev/null +++ b/backend/internal/repository/error_passthrough_repo.go @@ -0,0 +1,178 @@ +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type errorPassthroughRepository struct { + client *ent.Client +} + +// NewErrorPassthroughRepository 创建错误透传规则仓库 +func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository { + return &errorPassthroughRepository{client: client} +} + +// List 获取所有规则 +func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + rules, err := r.client.ErrorPassthroughRule.Query(). + Order(ent.Asc(errorpassthroughrule.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + + result := make([]*model.ErrorPassthroughRule, len(rules)) + for i, rule := range rules { + result[i] = r.toModel(rule) + } + return result, nil +} + +// GetByID 根据 ID 获取规则 +func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + rule, err := r.client.ErrorPassthroughRule.Get(ctx, id) + if err != nil { + if ent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + return r.toModel(rule), nil +} + +// Create 创建规则 +func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.Create(). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } + + created, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(created), nil +} + +// Update 更新规则 +func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + // 处理可选字段 + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } else { + builder.ClearErrorCodes() + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } else { + builder.ClearKeywords() + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } else { + builder.ClearPlatforms() + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } else { + builder.ClearResponseCode() + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } else { + builder.ClearCustomMessage() + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } else { + builder.ClearDescription() + } + + updated, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(updated), nil +} + +// Delete 删除规则 +func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error { + return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx) +} + +// toModel 将 Ent 实体转换为服务模型 +func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule { + rule := &model.ErrorPassthroughRule{ + ID: int64(e.ID), + Name: e.Name, + Enabled: e.Enabled, + Priority: e.Priority, + ErrorCodes: e.ErrorCodes, + Keywords: e.Keywords, + MatchMode: e.MatchMode, + Platforms: e.Platforms, + PassthroughCode: e.PassthroughCode, + PassthroughBody: e.PassthroughBody, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + } + + if e.ResponseCode != nil { + rule.ResponseCode = e.ResponseCode + } + if e.CustomMessage != nil { + rule.CustomMessage = e.CustomMessage + } + if e.Description != nil { + rule.Description = e.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + return rule +} diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 58291b66..9365252a 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -11,6 +11,63 @@ import ( const stickySessionPrefix = "sticky_session:" +// Gemini Trie Lua 脚本 +const ( + // geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本 + // KEYS[1] = trie key + // ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d") + // ARGV[2] = TTL seconds (用于刷新) + // 返回: 最长匹配的 value (uuid:accountID) 或 nil + // 查找成功时自动刷新 TTL,防止活跃会话意外过期 + geminiTrieFindScript = ` +local chain = ARGV[1] +local ttl = tonumber(ARGV[2]) +local lastMatch = nil +local path = "" + +for part in string.gmatch(chain, "[^-]+") do + path = path == "" and part or path .. "-" .. part + local val = redis.call('HGET', KEYS[1], path) + if val and val ~= "" then + lastMatch = val + end +end + +if lastMatch then + redis.call('EXPIRE', KEYS[1], ttl) +end + +return lastMatch +` + + // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 + // KEYS[1] = trie key + // ARGV[1] = digestChain + // ARGV[2] = value (uuid:accountID) + // ARGV[3] = TTL seconds + geminiTrieSaveScript = ` +local chain = ARGV[1] +local value = ARGV[2] +local ttl = tonumber(ARGV[3]) +local path = "" + +for part in string.gmatch(chain, "[^-]+") do + path = path == "" and part or path .. "-" .. part +end +redis.call('HSET', KEYS[1], path, value) +redis.call('EXPIRE', KEYS[1], ttl) +return "OK" +` +) + +// 模型负载统计相关常量 +const ( + modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀 + modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀 + modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零) + modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL +) + type gatewayCache struct { rdb *redis.Client } @@ -51,3 +108,133 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } + +// ============ Antigravity 模型负载统计方法 ============ + +// modelLoadKey 构建模型调用次数 key +// 格式: ag:model_load:{accountID}:{model} +func modelLoadKey(accountID int64, model string) string { + return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model) +} + +// modelLastUsedKey 构建模型最后调度时间 key +// 格式: ag:model_last_used:{accountID}:{model} +func modelLastUsedKey(accountID int64, model string) string { + return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model) +} + +// IncrModelCallCount 增加模型调用次数并更新最后调度时间 +// 返回更新后的调用次数 +func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { + loadKey := modelLoadKey(accountID, model) + lastUsedKey := modelLastUsedKey(accountID, model) + + pipe := c.rdb.Pipeline() + incrCmd := pipe.Incr(ctx, loadKey) + pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL + pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL) + if _, err := pipe.Exec(ctx); err != nil { + return 0, err + } + return incrCmd.Val(), nil +} + +// GetModelLoadBatch 批量获取账号的模型负载信息 +func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) { + if len(accountIDs) == 0 { + return make(map[int64]*service.ModelLoadInfo), nil + } + + loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model) + return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil +} + +// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作 +func (c *gatewayCache) pipelineModelLoadGet( + ctx context.Context, + accountIDs []int64, + model string, +) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) { + pipe := c.rdb.Pipeline() + loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) + lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) + + for _, id := range accountIDs { + loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model)) + lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model)) + } + _, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的 + return loadCmds, lastUsedCmds +} + +// parseModelLoadResults 解析 Pipeline 结果 +func (c *gatewayCache) parseModelLoadResults( + accountIDs []int64, + loadCmds map[int64]*redis.StringCmd, + lastUsedCmds map[int64]*redis.StringCmd, +) map[int64]*service.ModelLoadInfo { + result := make(map[int64]*service.ModelLoadInfo, len(accountIDs)) + for _, id := range accountIDs { + result[id] = &service.ModelLoadInfo{ + CallCount: getInt64OrZero(loadCmds[id]), + LastUsedAt: getTimeOrZero(lastUsedCmds[id]), + } + } + return result +} + +// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0 +func getInt64OrZero(cmd *redis.StringCmd) int64 { + val, _ := cmd.Int64() + return val +} + +// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值 +func getTimeOrZero(cmd *redis.StringCmd) time.Time { + val, err := cmd.Int64() + if err != nil { + return time.Time{} + } + return time.Unix(val, 0) +} + +// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============ + +// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询) +// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL +func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + if digestChain == "" { + return "", 0, false + } + + trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) + ttlSeconds := int(service.GeminiSessionTTL().Seconds()) + + // 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返 + // 查找成功时自动刷新 TTL,防止活跃会话意外过期 + result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() + if err != nil || result == nil { + return "", 0, false + } + + value, ok := result.(string) + if !ok || value == "" { + return "", 0, false + } + + uuid, accountID, ok = service.ParseGeminiSessionValue(value) + return uuid, accountID, ok +} + +// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本) +func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + if digestChain == "" { + return nil + } + + trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) + value := service.FormatGeminiSessionValue(uuid, accountID) + ttlSeconds := int(service.GeminiSessionTTL().Seconds()) + + return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() +} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 0eebc33f..fc8e7372 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,6 +104,158 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } +// ============ Gemini Trie 会话测试 ============ + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() { + groupID := int64(1) + prefixHash := "testprefix" + digestChain := "u:hash1-m:hash2-u:hash3" + uuid := "test-uuid-123" + accountID := int64(42) + + // 保存会话 + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID) + require.NoError(s.T(), err, "SaveGeminiSession") + + // 精确匹配查找 + foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain) + require.True(s.T(), found, "should find exact match") + require.Equal(s.T(), uuid, foundUUID) + require.Equal(s.T(), accountID, foundAccountID) +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() { + groupID := int64(1) + prefixHash := "prefixmatch" + shortChain := "u:a-m:b" + longChain := "u:a-m:b-u:c-m:d" + uuid := "uuid-prefix" + accountID := int64(100) + + // 保存短链 + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID) + require.NoError(s.T(), err) + + // 用长链查找,应该匹配到短链(前缀匹配) + foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain) + require.True(s.T(), found, "should find prefix match") + require.Equal(s.T(), uuid, foundUUID) + require.Equal(s.T(), accountID, foundAccountID) +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() { + groupID := int64(1) + prefixHash := "longestmatch" + + // 保存多个不同长度的链 + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1) + require.NoError(s.T(), err) + err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2) + require.NoError(s.T(), err) + err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3) + require.NoError(s.T(), err) + + // 查找更长的链,应该匹配到最长的前缀 + foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e") + require.True(s.T(), found, "should find longest prefix match") + require.Equal(s.T(), "uuid-long", foundUUID) + require.Equal(s.T(), int64(3), foundAccountID) + + // 查找中等长度的链 + foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x") + require.True(s.T(), found) + require.Equal(s.T(), "uuid-medium", foundUUID) + require.Equal(s.T(), int64(2), foundAccountID) +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() { + groupID := int64(1) + prefixHash := "nomatch" + digestChain := "u:a-m:b" + + // 保存一个会话 + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1) + require.NoError(s.T(), err) + + // 用不同的链查找,应该找不到 + _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y") + require.False(s.T(), found, "should not find non-matching chain") +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() { + groupID := int64(1) + digestChain := "u:a-m:b" + + // 保存到 prefixHash1 + err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1) + require.NoError(s.T(), err) + + // 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离) + _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain) + require.False(s.T(), found, "different prefixHash should be isolated") +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() { + prefixHash := "sameprefix" + digestChain := "u:a-m:b" + + // 保存到 groupID 1 + err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1) + require.NoError(s.T(), err) + + // 用 groupID 2 查找,应该找不到(分组隔离) + _, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain) + require.False(s.T(), found, "different groupID should be isolated") +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() { + groupID := int64(1) + prefixHash := "emptytest" + + // 空链不应该保存 + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1) + require.NoError(s.T(), err, "empty chain should not error") + + // 空链查找应该返回 false + _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "") + require.False(s.T(), found, "empty chain should not match") +} + +func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() { + groupID := int64(1) + prefixHash := "multisession" + + // 保存多个不同会话(模拟 1000 个并发会话的场景) + sessions := []struct { + chain string + uuid string + accountID int64 + }{ + {"u:session1", "uuid-1", 1}, + {"u:session2-m:reply2", "uuid-2", 2}, + {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, + } + + for _, sess := range sessions { + err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID) + require.NoError(s.T(), err) + } + + // 验证每个会话都能正确查找 + for _, sess := range sessions { + foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain) + require.True(s.T(), found, "should find session: %s", sess.chain) + require.Equal(s.T(), sess.uuid, foundUUID) + require.Equal(s.T(), sess.accountID, foundAccountID) + } + + // 验证继续对话的场景 + foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg") + require.True(s.T(), found) + require.Equal(s.T(), "uuid-2", foundUUID) + require.Equal(s.T(), int64(2), foundAccountID) +} + func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) } diff --git a/backend/internal/repository/gateway_cache_model_load_integration_test.go b/backend/internal/repository/gateway_cache_model_load_integration_test.go new file mode 100644 index 00000000..de6fa5ae --- /dev/null +++ b/backend/internal/repository/gateway_cache_model_load_integration_test.go @@ -0,0 +1,234 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// ============ Gateway Cache 模型负载统计集成测试 ============ + +type GatewayCacheModelLoadSuite struct { + suite.Suite +} + +func TestGatewayCacheModelLoadSuite(t *testing.T) { + suite.Run(t, new(GatewayCacheModelLoadSuite)) +} + +func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + accountID := int64(123) + model := "claude-sonnet-4-20250514" + + // 首次调用应返回 1 + count1, err := cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + require.Equal(t, int64(1), count1) + + // 第二次调用应返回 2 + count2, err := cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + require.Equal(t, int64(2), count2) + + // 第三次调用应返回 3 + count3, err := cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + require.Equal(t, int64(3), count3) +} + +func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + accountID := int64(456) + model1 := "claude-sonnet-4-20250514" + model2 := "claude-opus-4-5-20251101" + + // 不同模型应该独立计数 + count1, err := cache.IncrModelCallCount(ctx, accountID, model1) + require.NoError(t, err) + require.Equal(t, int64(1), count1) + + count2, err := cache.IncrModelCallCount(ctx, accountID, model2) + require.NoError(t, err) + require.Equal(t, int64(1), count2) + + count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1) + require.NoError(t, err) + require.Equal(t, int64(2), count1Again) +} + +func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + account1 := int64(111) + account2 := int64(222) + model := "gemini-2.5-pro" + + // 不同账号应该独立计数 + count1, err := cache.IncrModelCallCount(ctx, account1, model) + require.NoError(t, err) + require.Equal(t, int64(1), count1) + + count2, err := cache.IncrModelCallCount(ctx, account2, model) + require.NoError(t, err) + require.Equal(t, int64(1), count2) +} + +func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model") + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result) +} + +func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + // 查询不存在的账号应返回零值 + result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514") + require.NoError(t, err) + require.Len(t, result, 2) + + require.Equal(t, int64(0), result[9999].CallCount) + require.True(t, result[9999].LastUsedAt.IsZero()) + require.Equal(t, int64(0), result[9998].CallCount) + require.True(t, result[9998].LastUsedAt.IsZero()) +} + +func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + accountID := int64(789) + model := "claude-sonnet-4-20250514" + + // 先增加调用次数 + beforeIncr := time.Now() + _, err := cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + _, err = cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + _, err = cache.IncrModelCallCount(ctx, accountID, model) + require.NoError(t, err) + afterIncr := time.Now() + + // 获取负载信息 + result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model) + require.NoError(t, err) + require.Len(t, result, 1) + + loadInfo := result[accountID] + require.NotNil(t, loadInfo) + require.Equal(t, int64(3), loadInfo.CallCount) + require.False(t, loadInfo.LastUsedAt.IsZero()) + // LastUsedAt 应该在 beforeIncr 和 afterIncr 之间 + require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr)) + require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr)) +} + +func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + model := "claude-opus-4-5-20251101" + account1 := int64(1001) + account2 := int64(1002) + account3 := int64(1003) // 不调用 + + // account1 调用 2 次 + _, err := cache.IncrModelCallCount(ctx, account1, model) + require.NoError(t, err) + _, err = cache.IncrModelCallCount(ctx, account1, model) + require.NoError(t, err) + + // account2 调用 5 次 + for i := 0; i < 5; i++ { + _, err = cache.IncrModelCallCount(ctx, account2, model) + require.NoError(t, err) + } + + // 批量获取 + result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model) + require.NoError(t, err) + require.Len(t, result, 3) + + require.Equal(t, int64(2), result[account1].CallCount) + require.False(t, result[account1].LastUsedAt.IsZero()) + + require.Equal(t, int64(5), result[account2].CallCount) + require.False(t, result[account2].LastUsedAt.IsZero()) + + require.Equal(t, int64(0), result[account3].CallCount) + require.True(t, result[account3].LastUsedAt.IsZero()) +} + +func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() { + t := s.T() + rdb := testRedis(t) + cache := &gatewayCache{rdb: rdb} + ctx := context.Background() + + accountID := int64(2001) + model1 := "claude-sonnet-4-20250514" + model2 := "gemini-2.5-pro" + + // 对 model1 调用 3 次 + for i := 0; i < 3; i++ { + _, err := cache.IncrModelCallCount(ctx, accountID, model1) + require.NoError(t, err) + } + + // 获取 model1 的负载 + result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1) + require.NoError(t, err) + require.Equal(t, int64(3), result1[accountID].CallCount) + + // 获取 model2 的负载(应该为 0) + result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2) + require.NoError(t, err) + require.Equal(t, int64(0), result2[accountID].CallCount) +} + +// ============ 辅助函数测试 ============ + +func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() { + t := s.T() + + key := modelLoadKey(123, "claude-sonnet-4") + require.Equal(t, "ag:model_load:123:claude-sonnet-4", key) +} + +func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() { + t := s.T() + + key := modelLastUsedKey(456, "gemini-2.5-pro") + require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key) +} diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index d7f54e85..4f63280d 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/imroc/req/v3" @@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil @@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 77839626..03f8cc66 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -98,12 +98,16 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string if err != nil { return err } - defer func() { _ = out.Close() }() // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong limited := io.LimitReader(resp.Body, maxSize+1) written, err := io.Copy(out, limited) + + // Close file before attempting to remove (required on Windows) + _ = out.Close() + if err != nil { + _ = os.Remove(dest) // Clean up partial file (best-effort) return err } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a5b0512d..d8cec491 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 设置模型路由配置 if groupIn.ModelRouting != nil { builder = builder.SetModelRouting(groupIn.ModelRouting) } + // 设置支持的模型系列(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + created, err := builder.Save(ctx) if err == nil { groupIn.ID = created.ID @@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G if err != nil { return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } - return groupEntityToService(m), nil } @@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } else { builder = builder.ClearFallbackGroupID() } + // 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置 + if groupIn.FallbackGroupIDOnInvalidRequest != nil { + builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest) + } else { + builder = builder.ClearFallbackGroupIDOnInvalidRequest() + } // 处理 ModelRouting:nil 时清除,否则设置 if groupIn.ModelRouting != nil { @@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er builder = builder.ClearModelRouting() } + // 处理 SupportedModelScopes(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) diff --git a/backend/internal/repository/ops_repo_metrics.go b/backend/internal/repository/ops_repo_metrics.go index 713e0eb9..f1e57c38 100644 --- a/backend/internal/repository/ops_repo_metrics.go +++ b/backend/internal/repository/ops_repo_metrics.go @@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics ( upstream_529_count, token_consumed, + account_switch_count, qps, tps, @@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics ( $1,$2,$3,$4, $5,$6,$7,$8, $9,$10,$11, - $12,$13,$14, - $15,$16,$17,$18,$19,$20, - $21,$22,$23,$24,$25,$26, - $27,$28,$29,$30, - $31,$32, - $33,$34, - $35,$36,$37, - $38,$39 + $12,$13,$14,$15, + $16,$17,$18,$19,$20,$21, + $22,$23,$24,$25,$26,$27, + $28,$29,$30,$31, + $32,$33, + $34,$35, + $36,$37,$38, + $39,$40 )` _, err := r.db.ExecContext( @@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics ( input.Upstream529Count, input.TokenConsumed, + input.AccountSwitchCount, opsNullFloat64(input.QPS), opsNullFloat64(input.TPS), @@ -177,7 +179,8 @@ SELECT db_conn_waiting, goroutine_count, - concurrency_queue_depth + concurrency_queue_depth, + account_switch_count FROM ops_system_metrics WHERE window_minutes = $1 AND platform IS NULL @@ -199,6 +202,7 @@ LIMIT 1` var dbWaiting sql.NullInt64 var goroutines sql.NullInt64 var queueDepth sql.NullInt64 + var accountSwitchCount sql.NullInt64 if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan( &out.ID, @@ -217,6 +221,7 @@ LIMIT 1` &dbWaiting, &goroutines, &queueDepth, + &accountSwitchCount, ); err != nil { return nil, err } @@ -273,6 +278,10 @@ LIMIT 1` v := int(queueDepth.Int64) out.ConcurrencyQueueDepth = &v } + if accountSwitchCount.Valid { + v := accountSwitchCount.Int64 + out.AccountSwitchCount = &v + } return &out, nil } diff --git a/backend/internal/repository/ops_repo_trends.go b/backend/internal/repository/ops_repo_trends.go index 022d1187..14394ed8 100644 --- a/backend/internal/repository/ops_repo_trends.go +++ b/backend/internal/repository/ops_repo_trends.go @@ -56,18 +56,44 @@ error_buckets AS ( AND COALESCE(status_code, 0) >= 400 GROUP BY 1 ), +switch_buckets AS ( + SELECT ` + errorBucketExpr + ` AS bucket, + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count + FROM ops_error_logs + CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb) + ) AS ev + ` + errorWhere + ` + AND upstream_errors IS NOT NULL + GROUP BY 1 +), combined AS ( - SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.success_count, 0) AS success_count, - COALESCE(e.error_count, 0) AS error_count, - COALESCE(u.token_consumed, 0) AS token_consumed - FROM usage_buckets u - FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket + SELECT + bucket, + SUM(success_count) AS success_count, + SUM(error_count) AS error_count, + SUM(token_consumed) AS token_consumed, + SUM(switch_count) AS switch_count + FROM ( + SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count + FROM usage_buckets + UNION ALL + SELECT bucket, 0, error_count, 0, 0 + FROM error_buckets + UNION ALL + SELECT bucket, 0, 0, 0, switch_count + FROM switch_buckets + ) t + GROUP BY bucket ) SELECT bucket, (success_count + error_count) AS request_count, - token_consumed + token_consumed, + switch_count FROM combined ORDER BY bucket ASC` @@ -84,13 +110,18 @@ ORDER BY bucket ASC` var bucket time.Time var requests int64 var tokens sql.NullInt64 - if err := rows.Scan(&bucket, &requests, &tokens); err != nil { + var switches sql.NullInt64 + if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil { return nil, err } tokenConsumed := int64(0) if tokens.Valid { tokenConsumed = tokens.Int64 } + switchCount := int64(0) + if switches.Valid { + switchCount = switches.Int64 + } denom := float64(bucketSeconds) if denom <= 0 { @@ -103,6 +134,7 @@ ORDER BY bucket ASC` BucketStart: bucket.UTC(), RequestCount: requests, TokenConsumed: tokenConsumed, + SwitchCount: switchCount, QPS: qps, TPS: tps, }) @@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points [] BucketStart: cursor, RequestCount: 0, TokenConsumed: 0, + SwitchCount: 0, QPS: 0, TPS: 0, }) diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index fb6f405e..513e929c 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") } return &proxyProbeService{ - ipInfoURL: defaultIPInfoURL, insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, @@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { } const ( - defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN" defaultProxyProbeTimeout = 30 * time.Second ) +// probeURLs 按优先级排列的探测 URL 列表 +// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选 +var probeURLs = []struct { + url string + parser string // "ip-api" or "httpbin" +}{ + {"http://ip-api.com/json/?lang=zh-CN", "ip-api"}, + {"http://httpbin.org/ip", "httpbin"}, +} + type proxyProbeService struct { - ipInfoURL string insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool @@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) } + var lastErr error + for _, probe := range probeURLs { + exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser) + if err == nil { + return exitInfo, latencyMs, nil + } + lastErr = err + } + + return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr) +} + +func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) { startTime := time.Now() - req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, 0, fmt.Errorf("failed to create request: %w", err) } @@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) + } + + switch parser { + case "ip-api": + return s.parseIPAPI(body, latencyMs) + case "httpbin": + return s.parseHTTPBin(body, latencyMs) + default: + return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser) + } +} + +func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) { var ipInfo struct { Status string `json:"status"` Message string `json:"message"` @@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s CountryCode string `json:"countryCode"` } - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) - } - if err := json.Unmarshal(body, &ipInfo); err != nil { - return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) + preview := string(body) + if len(preview) > 200 { + preview = preview[:200] + "..." + } + return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview) } if strings.ToLower(ipInfo.Status) != "success" { if ipInfo.Message == "" { @@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s CountryCode: ipInfo.CountryCode, }, latencyMs, nil } + +func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) { + // httpbin.org/ip 返回格式: {"origin": "1.2.3.4"} + var result struct { + Origin string `json:"origin"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err) + } + if result.Origin == "" { + return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response") + } + return &service.ProxyExitInfo{ + IP: result.Origin, + }, latencyMs, nil +} diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index f1cd5721..7450653b 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/require" @@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct { func (s *ProxyProbeServiceSuite) SetupTest() { s.ctx = context.Background() s.prober = &proxyProbeService{ - ipInfoURL: "http://ip-api.test/json/?lang=zh-CN", allowPrivateHosts: true, } } @@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() { require.ErrorContains(s.T(), err, "failed to create proxy client") } -func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { - seen := make(chan string, 1) +func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen <- r.RequestURI - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`) + // 检查是否是 ip-api 请求 + if strings.Contains(r.RequestURI, "ip-api.com") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`) + return + } + // 其他请求返回错误 + w.WriteHeader(http.StatusServiceUnavailable) })) info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) @@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { require.Equal(s.T(), "r", info.Region) require.Equal(s.T(), "cc", info.Country) require.Equal(s.T(), "CC", info.CountryCode) - - // Verify proxy received the request - select { - case uri := <-seen: - require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy") - default: - require.Fail(s.T(), "expected proxy to receive request") - } } -func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() { +func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() { + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // ip-api 失败 + if strings.Contains(r.RequestURI, "ip-api.com") { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + // httpbin 成功 + if strings.Contains(r.RequestURI, "httpbin.org") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + })) + + info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin") + require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency") + require.Equal(s.T(), "5.6.7.8", info.IP) +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) })) _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "status: 503") + require.ErrorContains(s.T(), err, "all probe URLs failed") } func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, "not-json") + if strings.Contains(r.RequestURI, "ip-api.com") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + return + } + // httpbin 也返回无效响应 + if strings.Contains(r.RequestURI, "httpbin.org") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + return + } + w.WriteHeader(http.StatusServiceUnavailable) })) _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "failed to parse response") -} - -func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() { - s.prober.ipInfoURL = "://invalid-url" - s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) - require.Error(s.T(), err, "expected error for invalid ipInfoURL") + require.ErrorContains(s.T(), err, "all probe URLs failed") } func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() { @@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() { require.Error(s.T(), err, "expected error when proxy server is closed") } +func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() { + body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`) + info, latencyMs, err := s.prober.parseIPAPI(body, 100) + require.NoError(s.T(), err) + require.Equal(s.T(), int64(100), latencyMs) + require.Equal(s.T(), "1.2.3.4", info.IP) + require.Equal(s.T(), "Beijing", info.City) + require.Equal(s.T(), "Beijing", info.Region) + require.Equal(s.T(), "China", info.Country) + require.Equal(s.T(), "CN", info.CountryCode) +} + +func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() { + body := []byte(`{"status":"fail","message":"rate limited"}`) + _, _, err := s.prober.parseIPAPI(body, 100) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "rate limited") +} + +func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() { + body := []byte(`{"origin": "9.8.7.6"}`) + info, latencyMs, err := s.prober.parseHTTPBin(body, 50) + require.NoError(s.T(), err) + require.Equal(s.T(), int64(50), latencyMs) + require.Equal(s.T(), "9.8.7.6", info.IP) +} + +func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() { + body := []byte(`{"origin": ""}`) + _, _, err := s.prober.parseHTTPBin(body, 50) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "no IP found") +} + func TestProxyProbeServiceSuite(t *testing.T) { suite.Run(t, new(ProxyProbeServiceSuite)) } diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index 36965c05..07c2a204 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -60,6 +60,25 @@ func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy return proxyEntityToService(m), nil } +func (r *proxyRepository) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) { + if len(ids) == 0 { + return []service.Proxy{}, nil + } + + proxies, err := r.client.Proxy.Query(). + Where(proxy.IDIn(ids...)). + All(ctx) + if err != nil { + return nil, err + } + + out := make([]service.Proxy, 0, len(proxies)) + for i := range proxies { + out = append(out, *proxyEntityToService(proxies[i])) + } + return out, nil +} + func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error { builder := r.client.Proxy.UpdateOneID(proxyIn.ID). SetName(proxyIn.Name). diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index ee8a01b5..a3a048c3 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim return redeemCodeEntitiesToService(codes), nil } +// ListByUserPaginated returns paginated balance/concurrency history for a user. +// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription"). +func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + q := r.client.RedeemCode.Query(). + Where(redeemcode.UsedByEQ(userID)) + + // Optional type filter + if codeType != "" { + q = q.Where(redeemcode.TypeEQ(codeType)) + } + + total, err := q.Count(ctx) + if err != nil { + return nil, nil, err + } + + codes, err := q. + WithGroup(). + Offset(params.Offset()). + Limit(params.Limit()). + Order(dbent.Desc(redeemcode.FieldUsedAt)). + All(ctx) + if err != nil { + return nil, nil, err + } + + return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil +} + +// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance). +func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + var result []struct { + Sum float64 `json:"sum"` + } + err := r.client.RedeemCode.Query(). + Where( + redeemcode.UsedByEQ(userID), + redeemcode.ValueGT(0), + redeemcode.TypeIn("balance", "admin_balance"), + ). + Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")). + Scan(ctx, &result) + if err != nil { + return 0, err + } + if len(result) == 0 { + return 0, nil + } + return result[0].Sum, nil +} + func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode { if m == nil { return nil diff --git a/backend/internal/repository/refresh_token_cache.go b/backend/internal/repository/refresh_token_cache.go new file mode 100644 index 00000000..b01bd476 --- /dev/null +++ b/backend/internal/repository/refresh_token_cache.go @@ -0,0 +1,158 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + refreshTokenKeyPrefix = "refresh_token:" + userRefreshTokensPrefix = "user_refresh_tokens:" + tokenFamilyPrefix = "token_family:" +) + +// refreshTokenKey generates the Redis key for a refresh token. +func refreshTokenKey(tokenHash string) string { + return refreshTokenKeyPrefix + tokenHash +} + +// userRefreshTokensKey generates the Redis key for user's token set. +func userRefreshTokensKey(userID int64) string { + return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID) +} + +// tokenFamilyKey generates the Redis key for token family set. +func tokenFamilyKey(familyID string) string { + return tokenFamilyPrefix + familyID +} + +type refreshTokenCache struct { + rdb *redis.Client +} + +// NewRefreshTokenCache creates a new RefreshTokenCache implementation. +func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache { + return &refreshTokenCache{rdb: rdb} +} + +func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error { + key := refreshTokenKey(tokenHash) + val, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshal refresh token data: %w", err) + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) { + key := refreshTokenKey(tokenHash) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return nil, service.ErrRefreshTokenNotFound + } + return nil, err + } + var data service.RefreshTokenData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, fmt.Errorf("unmarshal refresh token data: %w", err) + } + return &data, nil +} + +func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error { + key := refreshTokenKey(tokenHash) + return c.rdb.Del(ctx, key).Err() +} + +func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error { + // Get all token hashes for this user + tokenHashes, err := c.GetUserTokenHashes(ctx, userID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get user token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, userRefreshTokensKey(userID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error { + // Get all token hashes in this family + tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get family token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, tokenFamilyKey(familyID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error { + key := userRefreshTokensKey(userID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error { + key := tokenFamilyKey(familyID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) { + key := userRefreshTokensKey(userID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SIsMember(ctx, key, tokenHash).Result() +} diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go index 3dc89f87..3d57b152 100644 --- a/backend/internal/repository/session_limit_cache.go +++ b/backend/internal/repository/session_limit_cache.go @@ -3,6 +3,7 @@ package repository import ( "context" "fmt" + "log" "strconv" "time" @@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv if defaultIdleTimeoutMinutes <= 0 { defaultIdleTimeoutMinutes = 5 // 默认 5 分钟 } + + // 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误 + ctx := context.Background() + scripts := []*redis.Script{ + registerSessionScript, + refreshSessionScript, + getActiveSessionCountScript, + isSessionActiveScript, + } + for _, script := range scripts { + if err := script.Load(ctx, rdb).Err(); err != nil { + log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err) + } + } + return &sessionLimitCache{ rdb: rdb, defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute, diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dc8f1460..2db1764f 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1125,6 +1125,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i return stats, nil } +// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值) +func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) { + fiveMinutesAgo := time.Now().Add(-5 * time.Minute) + query := ` + SELECT + COUNT(*) as request_count, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count + FROM usage_logs + WHERE created_at >= $1 AND api_key_id = $2` + args := []any{fiveMinutesAgo, apiKeyID} + + var requestCount int64 + var tokenCount int64 + if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil { + return 0, 0, err + } + return requestCount / 5, tokenCount / 5, nil +} + +// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤) +func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) { + stats := &UserDashboardStats{} + today := timezone.Today() + + // API Key 维度不需要统计 key 数量,设为 1 + stats.TotalAPIKeys = 1 + stats.ActiveAPIKeys = 1 + + // 累计 Token 统计 + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms + FROM usage_logs + WHERE api_key_id = $1 + ` + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{apiKeyID}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + + // 今日 Token 统计 + todayStatsQuery := ` + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(input_tokens), 0) as today_input_tokens, + COALESCE(SUM(output_tokens), 0) as today_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as today_cost, + COALESCE(SUM(actual_cost), 0) as today_actual_cost + FROM usage_logs + WHERE api_key_id = $1 AND created_at >= $2 + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{apiKeyID, today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { + return nil, err + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + // 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤) + rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { dateFormat := "YYYY-MM-DD" diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go new file mode 100644 index 00000000..eb65403b --- /dev/null +++ b/backend/internal/repository/user_group_rate_repo.go @@ -0,0 +1,113 @@ +package repository + +import ( + "context" + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type userGroupRateRepository struct { + sql sqlExecutor +} + +// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { + return &userGroupRateRepository{sql: sqlDB} +} + +// GetByUserID 获取用户的所有专属分组倍率 +func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` + rows, err := r.sql.QueryContext(ctx, query, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[int64]float64) + for rows.Next() { + var groupID int64 + var rate float64 + if err := rows.Scan(&groupID, &rate); err != nil { + return nil, err + } + result[groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +// GetByUserAndGroup 获取用户在特定分组的专属倍率 +func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rate float64 + err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &rate, nil +} + +// SyncUserGroupRates 同步用户的分组专属倍率 +func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { + if len(rates) == 0 { + // 如果传入空 map,删除该用户的所有专属倍率 + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err + } + + // 分离需要删除和需要 upsert 的记录 + var toDelete []int64 + toUpsert := make(map[int64]float64) + for groupID, rate := range rates { + if rate == nil { + toDelete = append(toDelete, groupID) + } else { + toUpsert[groupID] = *rate + } + } + + // 删除指定的记录 + for _, groupID := range toDelete { + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, + userID, groupID) + if err != nil { + return err + } + } + + // Upsert 记录 + now := time.Now() + for groupID, rate := range toUpsert { + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) + VALUES ($1, $2, $3, $4, $4) + ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 + `, userID, groupID, rate, now) + if err != nil { + return err + } + } + + return nil +} + +// DeleteByGroupID 删除指定分组的所有用户专属倍率 +func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) + return err +} + +// DeleteByUserID 删除指定用户的所有专属倍率 +func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index e3394361..3aed9d9c 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet( NewUserSubscriptionRepository, NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, + NewUserGroupRateRepository, + NewErrorPassthroughRepository, // Cache implementations NewGatewayCache, @@ -85,6 +87,8 @@ var ProviderSet = wire.NewSet( NewSchedulerOutboxRepository, NewProxyLatencyCache, NewTotpCache, + NewRefreshTokenCache, + NewErrorPassthroughCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 49a7e0e4..efef0452 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -180,6 +186,7 @@ func TestAPIContracts(t *testing.T) { "image_price_4k": null, "claude_code_only": false, "fallback_group_id": null, + "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -586,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps { } userService := service.NewUserService(userRepo, nil) - apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) @@ -600,8 +607,8 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) - authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) @@ -1052,6 +1059,10 @@ func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, err return nil, service.ErrProxyNotFound } +func (stubProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) { + return nil, errors.New("not implemented") +} + func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error { return errors.New("not implemented") } @@ -1150,6 +1161,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit return append([]service.RedeemCode(nil), codes...), nil } +func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { byUser map[int64][]service.UserSubscription activeByUser map[int64][]service.UserSubscription @@ -1434,6 +1453,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1591,6 +1614,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 52d5c926..d2d8ed40 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -14,6 +14,8 @@ import ( "github.com/gin-gonic/gin" "github.com/google/wire" "github.com/redis/go-redis/v9" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // ProviderSet 提供服务器层的依赖 @@ -56,9 +58,39 @@ func ProvideRouter( // ProvideHTTPServer 提供 HTTP 服务器 func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { + httpHandler := http.Handler(router) + + globalMaxSize := cfg.Server.MaxRequestBodySize + if globalMaxSize <= 0 { + globalMaxSize = cfg.Gateway.MaxBodySize + } + if globalMaxSize > 0 { + httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize) + log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20)) + } + + // 根据配置决定是否启用 H2C + if cfg.Server.H2C.Enabled { + h2cConfig := cfg.Server.H2C + httpHandler = h2c.NewHandler(router, &http2.Server{ + MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams, + IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second, + MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize), + MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection), + MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream), + }) + log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d", + h2cConfig.MaxConcurrentStreams, + h2cConfig.IdleTimeout, + h2cConfig.MaxReadFrameSize, + h2cConfig.MaxUploadBufferPerConnection, + h2cConfig.MaxUploadBufferPerStream, + ) + } + return &http.Server{ Addr: cfg.Server.Address(), - Handler: router, + Handler: httpHandler, // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index dff6ba95..2f739357 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查API key是否激活 if !apiKey.IsActive() { - AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + // Provide more specific error message based on status + switch apiKey.Status { + case service.StatusAPIKeyQuotaExhausted: + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + case service.StatusAPIKeyExpired: + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + default: + AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + } + return + } + + // 检查API Key是否过期(即使状态是active,也要检查时间) + if apiKey.IsExpired() { + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + + // 检查API Key配额是否耗尽 + if apiKey.IsQuotaExhausted() { + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") return } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 1a0b0dd5..38fbe38b 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.") return } - apiKeyString := extractAPIKeyFromRequest(c) + apiKeyString := extractAPIKeyForGoogle(c) if apiKeyString == "" { abortWithGoogleError(c, 401, "API key is required") return @@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs } } -func extractAPIKeyFromRequest(c *gin.Context) string { - authHeader := c.GetHeader("Authorization") - if authHeader != "" { - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" { - return strings.TrimSpace(parts[1]) +// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints. +// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key +// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints. +func extractAPIKeyForGoogle(c *gin.Context) string { + // 1) preferred: Gemini native header + if k := strings.TrimSpace(c.GetHeader("x-goog-api-key")); k != "" { + return k + } + + // 2) fallback: Authorization: Bearer + auth := strings.TrimSpace(c.GetHeader("Authorization")) + if auth != "" { + parts := strings.SplitN(auth, " ", 2) + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + if k := strings.TrimSpace(parts[1]); k != "" { + return k + } } } - if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" { - return v - } - if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" { - return v + + // 3) x-api-key header (backward compatibility) + if k := strings.TrimSpace(c.GetHeader("x-api-key")); k != "" { + return k } + + // 4) query parameter key (for specific paths) if allowGoogleQueryKey(c.Request.URL.Path) { if v := strings.TrimSpace(c.Query("key")); v != "" { return v } } + return "" } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 6f09469b..38b93cb2 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { return nil, errors.New("not implemented") } +func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { @@ -90,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService nil, // userRepo (unused in GetByKey) nil, // groupRepo nil, // userSubRepo + nil, // userGroupRateRepo nil, // cache &config.Config{}, ) @@ -184,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { nil, nil, nil, + nil, &config.Config{RunMode: config.RunModeSimple}, ) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 920ff93f..9d514818 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) @@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeStandard} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) now := time.Now() sub := &service.UserSubscription{ @@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.GET("/t", func(c *gin.Context) { @@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) @@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index a9beeb40..842efda9 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc { // 客户端IP clientIP := c.ClientIP() - // 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径 - log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s", + // 协议版本 + protocol := c.Request.Proto + + // 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径 + log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s", endTime.Format("2006/01/02 - 15:04:05"), statusCode, latency, clientIP, + protocol, method, path, ) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 3e0033e7..14815262 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -67,6 +67,9 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) + + // 错误透传规则管理 + registerErrorPassthroughRoutes(admin, h) } } @@ -75,6 +78,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { // Realtime ops signals ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats) + ops.GET("/user-concurrency", h.Admin.Ops.GetUserConcurrencyStats) ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability) ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary) @@ -175,6 +179,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.POST("/:id/balance", h.Admin.User.UpdateBalance) users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) users.GET("/:id/usage", h.Admin.User.GetUserUsage) + users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -218,10 +223,15 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.POST("/batch", h.Admin.Account.BatchCreate) + accounts.GET("/data", h.Admin.Account.ExportData) + accounts.POST("/data", h.Admin.Account.ImportData) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) + // Antigravity 默认模型映射 + accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) + // Claude OAuth routes accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL) @@ -277,6 +287,8 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { proxies.GET("", h.Admin.Proxy.List) proxies.GET("/all", h.Admin.Proxy.GetAll) + proxies.GET("/data", h.Admin.Proxy.ExportData) + proxies.POST("/data", h.Admin.Proxy.ImportData) proxies.GET("/:id", h.Admin.Proxy.GetByID) proxies.POST("", h.Admin.Proxy.Create) proxies.PUT("/:id", h.Admin.Proxy.Update) @@ -386,3 +398,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } + +func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + rules := admin.Group("/error-passthrough-rules") + { + rules.GET("", h.Admin.ErrorPassthrough.List) + rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID) + rules.POST("", h.Admin.ErrorPassthrough.Create) + rules.PUT("/:id", h.Admin.ErrorPassthrough.Update) + rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete) + } +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 24f6d549..26d79605 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -28,6 +28,12 @@ func RegisterAuthRoutes( auth.POST("/login", h.Auth.Login) auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) + auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.RefreshToken) + // 登出接口(公开,允许未认证用户调用以撤销Refresh Token) + auth.POST("/logout", h.Auth.Logout) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, @@ -59,5 +65,7 @@ func RegisterAuthRoutes( authenticated.Use(gin.HandlerFunc(jwtAuth)) { authenticated.GET("/auth/me", h.Auth.GetCurrentUser) + // 撤销所有会话(需要认证) + authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) } } diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 5581e1e1..d0ed2489 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -49,6 +49,7 @@ func RegisterUserRoutes( groups := authenticated.Group("/groups") { groups.GET("/available", h.APIKey.GetAvailableGroups) + groups.GET("/rates", h.APIKey.GetUserGroupRates) } // 使用记录 diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 7b958838..a6ae8a68 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,9 +3,12 @@ package service import ( "encoding/json" + "sort" "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" ) type Account struct { @@ -347,10 +350,18 @@ func parseTempUnschedInt(value any) int { func (a *Account) GetModelMapping() map[string]string { if a.Credentials == nil { + // Antigravity 平台使用默认映射 + if a.Platform == domain.PlatformAntigravity { + return domain.DefaultAntigravityModelMapping + } return nil } raw, ok := a.Credentials["model_mapping"] if !ok || raw == nil { + // Antigravity 平台使用默认映射 + if a.Platform == domain.PlatformAntigravity { + return domain.DefaultAntigravityModelMapping + } return nil } if m, ok := raw.(map[string]any); ok { @@ -364,27 +375,46 @@ func (a *Account) GetModelMapping() map[string]string { return result } } + // Antigravity 平台使用默认映射 + if a.Platform == domain.PlatformAntigravity { + return domain.DefaultAntigravityModelMapping + } return nil } +// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) +// 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { mapping := a.GetModelMapping() if len(mapping) == 0 { + return true // 无映射 = 允许所有 + } + // 精确匹配 + if _, exists := mapping[requestedModel]; exists { return true } - _, exists := mapping[requestedModel] - return exists + // 通配符匹配 + for pattern := range mapping { + if matchWildcard(pattern, requestedModel) { + return true + } + } + return false } +// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) +// 如果未配置 mapping,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { mapping := a.GetModelMapping() if len(mapping) == 0 { return requestedModel } + // 精确匹配优先 if mappedModel, exists := mapping[requestedModel]; exists { return mappedModel } - return requestedModel + // 通配符匹配(最长优先) + return matchWildcardMapping(mapping, requestedModel) } func (a *Account) GetBaseURL() string { @@ -426,6 +456,53 @@ func (a *Account) GetClaudeUserID() string { return "" } +// matchAntigravityWildcard 通配符匹配(仅支持末尾 *) +// 用于 model_mapping 的通配符匹配 +func matchAntigravityWildcard(pattern, str string) bool { + if strings.HasSuffix(pattern, "*") { + prefix := pattern[:len(pattern)-1] + return strings.HasPrefix(str, prefix) + } + return pattern == str +} + +// matchWildcard 通用通配符匹配(仅支持末尾 *) +// 复用 Antigravity 的通配符逻辑,供其他平台使用 +func matchWildcard(pattern, str string) bool { + return matchAntigravityWildcard(pattern, str) +} + +// matchWildcardMapping 通配符映射匹配(最长优先) +// 如果没有匹配,返回原始字符串 +func matchWildcardMapping(mapping map[string]string, requestedModel string) string { + // 收集所有匹配的 pattern,按长度降序排序(最长优先) + type patternMatch struct { + pattern string + target string + } + var matches []patternMatch + + for pattern, target := range mapping { + if matchWildcard(pattern, requestedModel) { + matches = append(matches, patternMatch{pattern, target}) + } + } + + if len(matches) == 0 { + return requestedModel // 无匹配,返回原始模型名 + } + + // 按 pattern 长度降序排序 + sort.Slice(matches, func(i, j int) bool { + if len(matches[i].pattern) != len(matches[j].pattern) { + return len(matches[i].pattern) > len(matches[j].pattern) + } + return matches[i].pattern < matches[j].pattern + }) + + return matches[0].target +} + func (a *Account) IsCustomErrorCodesEnabled() bool { if a.Type != AccountTypeAPIKey || a.Credentials == nil { return false diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index f3b3e20d..304c5781 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -41,6 +41,7 @@ type UsageLogRepository interface { // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) + GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go new file mode 100644 index 00000000..90e5b573 --- /dev/null +++ b/backend/internal/service/account_wildcard_test.go @@ -0,0 +1,269 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestMatchWildcard(t *testing.T) { + tests := []struct { + name string + pattern string + str string + expected bool + }{ + // 精确匹配 + {"exact match", "claude-sonnet-4-5", "claude-sonnet-4-5", true}, + {"exact mismatch", "claude-sonnet-4-5", "claude-opus-4-5", false}, + + // 通配符匹配 + {"wildcard prefix match", "claude-*", "claude-sonnet-4-5", true}, + {"wildcard prefix match 2", "claude-*", "claude-opus-4-5-thinking", true}, + {"wildcard prefix mismatch", "claude-*", "gemini-3-flash", false}, + {"wildcard partial match", "gemini-3*", "gemini-3-flash", true}, + {"wildcard partial match 2", "gemini-3*", "gemini-3-pro-image", true}, + {"wildcard partial mismatch", "gemini-3*", "gemini-2.5-flash", false}, + + // 边界情况 + {"empty pattern exact", "", "", true}, + {"empty pattern mismatch", "", "claude", false}, + {"single star", "*", "anything", true}, + {"star at end only", "abc*", "abcdef", true}, + {"star at end empty suffix", "abc*", "abc", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchWildcard(tt.pattern, tt.str) + if result != tt.expected { + t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.str, result, tt.expected) + } + }) + } +} + +func TestMatchWildcardMapping(t *testing.T) { + tests := []struct { + name string + mapping map[string]string + requestedModel string + expected string + }{ + // 精确匹配优先于通配符 + { + name: "exact match takes precedence", + mapping: map[string]string{ + "claude-sonnet-4-5": "claude-sonnet-4-5-exact", + "claude-*": "claude-default", + }, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5-exact", + }, + + // 最长通配符优先 + { + name: "longer wildcard takes precedence", + mapping: map[string]string{ + "claude-*": "claude-default", + "claude-sonnet-*": "claude-sonnet-default", + "claude-sonnet-4*": "claude-sonnet-4-series", + }, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-series", + }, + + // 单个通配符 + { + name: "single wildcard", + mapping: map[string]string{ + "claude-*": "claude-mapped", + }, + requestedModel: "claude-opus-4-5", + expected: "claude-mapped", + }, + + // 无匹配返回原始模型 + { + name: "no match returns original", + mapping: map[string]string{ + "claude-*": "claude-mapped", + }, + requestedModel: "gemini-3-flash", + expected: "gemini-3-flash", + }, + + // 空映射返回原始模型 + { + name: "empty mapping returns original", + mapping: map[string]string{}, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5", + }, + + // Gemini 模型映射 + { + name: "gemini wildcard mapping", + mapping: map[string]string{ + "gemini-3*": "gemini-3-pro-high", + "gemini-2.5*": "gemini-2.5-flash", + }, + requestedModel: "gemini-3-flash-preview", + expected: "gemini-3-pro-high", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchWildcardMapping(tt.mapping, tt.requestedModel) + if result != tt.expected { + t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) + } + }) + } +} + +func TestAccountIsModelSupported(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expected bool + }{ + // 无映射 = 允许所有 + { + name: "no mapping allows all", + credentials: nil, + requestedModel: "any-model", + expected: true, + }, + { + name: "empty mapping allows all", + credentials: map[string]any{}, + requestedModel: "any-model", + expected: true, + }, + + // 精确匹配 + { + name: "exact match supported", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-5": "target-model", + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: true, + }, + { + name: "exact match not supported", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-5": "target-model", + }, + }, + requestedModel: "claude-opus-4-5", + expected: false, + }, + + // 通配符匹配 + { + name: "wildcard match supported", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-sonnet-4-5", + }, + }, + requestedModel: "claude-opus-4-5-thinking", + expected: true, + }, + { + name: "wildcard match not supported", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-sonnet-4-5", + }, + }, + requestedModel: "gemini-3-flash", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Credentials: tt.credentials, + } + result := account.IsModelSupported(tt.requestedModel) + if result != tt.expected { + t.Errorf("IsModelSupported(%q) = %v, want %v", tt.requestedModel, result, tt.expected) + } + }) + } +} + +func TestAccountGetMappedModel(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expected string + }{ + // 无映射 = 返回原始模型 + { + name: "no mapping returns original", + credentials: nil, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5", + }, + + // 精确匹配 + { + name: "exact match", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-5": "target-model", + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: "target-model", + }, + + // 通配符匹配(最长优先) + { + name: "wildcard longest match", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-default", + "claude-sonnet-*": "claude-sonnet-mapped", + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-mapped", + }, + + // 无匹配返回原始模型 + { + name: "no match returns original", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-*": "gemini-mapped", + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Credentials: tt.credentials, + } + result := account.GetMappedModel(tt.requestedModel) + if result != tt.expected { + t.Errorf("GetMappedModel(%q) = %q, want %q", tt.requestedModel, result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index ef2d526b..59d7062b 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -22,6 +22,10 @@ type AdminService interface { UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. + // codeType is optional - pass empty string to return all types. + // Also returns totalRecharged (sum of all positive balance top-ups). + GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) @@ -52,6 +56,7 @@ type AdminService interface { GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetProxy(ctx context.Context, id int64) (*Proxy, error) + GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) DeleteProxy(ctx context.Context, id int64) error @@ -89,6 +94,9 @@ type UpdateUserInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 } type CreateGroupInput struct { @@ -107,9 +115,14 @@ type CreateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -131,9 +144,14 @@ type UpdateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -152,6 +170,8 @@ type CreateAccountInput struct { GroupIDs []int64 ExpiresAt *int64 AutoPauseOnExpired *bool + // SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty. + SkipDefaultGroupBind bool // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. SkipMixedChannelCheck bool @@ -279,6 +299,7 @@ type adminServiceImpl struct { proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository + userGroupRateRepo UserGroupRateRepository billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -293,6 +314,7 @@ func NewAdminService( proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, + userGroupRateRepo UserGroupRateRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -305,6 +327,7 @@ func NewAdminService( proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, + userGroupRateRepo: userGroupRateRepo, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -319,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi if err != nil { return nil, 0, err } + // 批量加载用户专属分组倍率 + if s.userGroupRateRepo != nil && len(users) > 0 { + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } + } return users, result.Total, nil } func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { - return s.userRepo.GetByID(ctx, id) + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + // 加载用户专属分组倍率 + if s.userGroupRateRepo != nil { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + } else { + user.GroupRates = rates + } + } + return user, nil } func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { @@ -392,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + + // 同步用户专属分组倍率 + if input.GroupRates != nil && s.userGroupRateRepo != nil { + if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { + log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + } + } + if s.authCacheInvalidator != nil { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) @@ -526,6 +581,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, }, nil } +// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. +func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) + if err != nil { + return nil, 0, 0, err + } + // Aggregate total recharged amount (only once, regardless of type filter) + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, result.Total, totalRecharged, nil +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} @@ -575,6 +645,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return nil, err } } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + + // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 + mcpXMLInject := true + if input.MCPXMLInject != nil { + mcpXMLInject = *input.MCPXMLInject + } // 如果指定了复制账号的源分组,先获取账号 ID 列表 var accountIDsToCopy []int64 @@ -609,22 +695,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, + MCPXMLInject: mcpXMLInject, + SupportedModelScopes: input.SupportedModelScopes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -695,6 +784,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro } } +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") + } + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") + } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") + } + + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") + } + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -761,6 +881,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.FallbackGroupID = nil } } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest // 模型路由配置 if input.ModelRouting != nil { @@ -769,6 +903,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ModelRoutingEnabled != nil { group.ModelRoutingEnabled = *input.ModelRoutingEnabled } + if input.MCPXMLInject != nil { + group.MCPXMLInject = *input.MCPXMLInject + } + + // 支持的模型系列(仅 antigravity 平台使用) + if input.SupportedModelScopes != nil { + group.SupportedModelScopes = *input.SupportedModelScopes + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err @@ -840,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { if err != nil { return err } + // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 // 事务成功后,异步失效受影响用户的订阅缓存 if len(affectedUserIDs) > 0 && s.billingCacheService != nil { @@ -903,7 +1046,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou // 绑定分组 groupIDs := input.GroupIDs // 如果没有指定分组,自动绑定对应平台的默认分组 - if len(groupIDs) == 0 { + if len(groupIDs) == 0 && !input.SkipDefaultGroupBind { defaultGroupName := input.Platform + "-default" groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) if err == nil { @@ -1243,6 +1386,10 @@ func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, erro return s.proxyRepo.GetByID(ctx, id) } +func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + return s.proxyRepo.ListByIDs(ctx, ids) +} + func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { proxy := &Proxy{ Name: input.Name, diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 923d33ab..c775749d 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -187,6 +187,10 @@ func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) { panic("unexpected GetByID call") } +func (s *proxyRepoStub) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("unexpected ListByIDs call") +} + func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error { panic("unexpected Update call") } @@ -282,6 +286,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int panic("unexpected ListByUser call") } +func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + type subscriptionInvalidateCall struct { userID int64 groupID int64 diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 1daee89f..d921a086 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { panic("unexpected GetAccountIDsByGroupIDs call") } + +type groupRepoStubForInvalidRequestFallback struct { + groups map[int64]*Group + created *Group + updated *Group +} + +func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error { + s.created = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error { + s.updated = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) { + return s.GetByIDLite(ctx, id) +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) { + if g, ok := s.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + +func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformOpenAI, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeSubscription, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + tests := []struct { + name string + fallback *Group + wantMessage string + }{ + { + name: "openai_target", + fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "antigravity_target", + fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "subscription_group", + fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + wantMessage: "fallback group cannot be subscription type", + }, + { + name: "nested_fallback", + fallback: &Group{ + ID: 10, + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(), + }, + wantMessage: "fallback group cannot have invalid request fallback configured", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fallbackID := tc.fallback.ID + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: tc.fallback, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantMessage) + require.Nil(t, repo.created) + }) + } +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group not found") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + zero := int64(0) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &zero, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + SubscriptionType: SubscriptionTypeSubscription, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + clear := int64(0) + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + FallbackGroupIDOnInvalidRequest: &clear, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group cannot be subscription type") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index 7506c6db..d661b710 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p return s.listWithFiltersCodes, result, nil } +func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + func TestAdminService_ListAccounts_WithSearch(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { repo := &accountRepoStubForAdminList{ diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 9b8156e6..3d3c9cca 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "os" + "strconv" "strings" "sync/atomic" "time" @@ -27,24 +28,88 @@ const ( antigravityMaxRetries = 3 antigravityRetryBaseDelay = 1 * time.Second antigravityRetryMaxDelay = 16 * time.Second + + // 限流相关常量 + // antigravityRateLimitThreshold 限流等待/切换阈值 + // - 智能重试:retryDelay < 此阈值时等待后重试,>= 此阈值时直接限流模型 + // - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号 + antigravityRateLimitThreshold = 7 * time.Second + antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间 + antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数 + antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) + + // Google RPC 状态和类型常量 + googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" + googleRPCStatusUnavailable = "UNAVAILABLE" + googleRPCTypeRetryInfo = "type.googleapis.com/google.rpc.RetryInfo" + googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo" + googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED" + googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) -const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" +// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) +// 匹配时使用 strings.Contains,无需完全匹配 +var antigravityPassthroughErrorMessages = []string{ + "prompt is too long", +} + +const ( + antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" + antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" +) + +// AntigravityAccountSwitchError 账号切换信号 +// 当账号限流时间超过阈值时,通知上层切换账号 +type AntigravityAccountSwitchError struct { + OriginalAccountID int64 + RateLimitedModel string + IsStickySession bool // 是否为粘性会话切换(决定是否缓存计费) +} + +func (e *AntigravityAccountSwitchError) Error() string { + return fmt.Sprintf("account %d model %s rate limited, need switch", + e.OriginalAccountID, e.RateLimitedModel) +} + +// IsAntigravityAccountSwitchError 检查错误是否为账号切换信号 +func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, bool) { + var switchErr *AntigravityAccountSwitchError + if errors.As(err, &switchErr) { + return switchErr, true + } + return nil, false +} + +// PromptTooLongError 表示上游明确返回 prompt too long +type PromptTooLongError struct { + StatusCode int + RequestID string + Body []byte +} + +func (e *PromptTooLongError) Error() string { + return fmt.Sprintf("prompt too long: status=%d", e.StatusCode) +} // antigravityRetryLoopParams 重试循环的参数 type antigravityRetryLoopParams struct { - ctx context.Context - prefix string - account *Account - proxyURL string - accessToken string - action string - body []byte - quotaScope AntigravityQuotaScope - c *gin.Context - httpUpstream HTTPUpstream - settingService *SettingService - handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) + ctx context.Context + prefix string + account *Account + proxyURL string + accessToken string + action string + body []byte + quotaScope AntigravityQuotaScope + c *gin.Context + httpUpstream HTTPUpstream + settingService *SettingService + accountRepo AccountRepository // 用于智能重试的模型级别限流 + handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult + requestedModel string // 用于限流检查的原始请求模型 + isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断) + groupID int64 // 用于模型级限流时清除粘性会话 + sessionHash string // 用于模型级限流时清除粘性会话 } // antigravityRetryLoopResult 重试循环的结果 @@ -52,8 +117,178 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// smartRetryAction 智能重试的处理结果 +type smartRetryAction int + +const ( + smartRetryActionContinue smartRetryAction = iota // 继续默认重试逻辑 + smartRetryActionBreakWithResp // 结束循环并返回 resp + smartRetryActionContinueURL // 继续 URL fallback 循环 +) + +// smartRetryResult 智能重试的结果 +type smartRetryResult struct { + action smartRetryAction + resp *http.Response + err error + switchError *AntigravityAccountSwitchError // 模型限流时返回账号切换信号 +} + +// handleSmartRetry 处理 OAuth 账号的智能重试逻辑 +// 将 429/503 限流处理逻辑抽取为独立函数,减少 antigravityRetryLoop 的复杂度 +func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult { + // "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { + log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + return &smartRetryResult{action: smartRetryActionContinueURL} + } + + // 判断是否触发智能重试 + shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody) + + // 情况1: retryDelay >= 阈值,限流模型并切换账号 + if shouldRateLimitModel { + log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)", + p.prefix, resp.StatusCode, modelName, p.account.ID) + + resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID) + } else { + s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) + } + + // 返回账号切换信号,让上层切换账号重试 + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + switchError: &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: modelName, + IsStickySession: p.isStickySession, + }, + } + } + + // 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次) + if shouldSmartRetry { + var lastRetryResp *http.Response + var lastRetryBody []byte + + for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ { + log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", + p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID) + + select { + case <-p.ctx.Done(): + log.Printf("%s status=context_canceled_during_smart_retry", p.prefix) + return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} + case <-time.After(waitDuration): + } + + // 智能重试:创建新请求 + retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) + if err != nil { + log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + }, + } + } + + retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { + log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts) + return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} + } + + // 网络错误时,继续重试 + if retryErr != nil || retryResp == nil { + log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr) + continue + } + + // 重试失败,关闭之前的响应 + if lastRetryResp != nil { + _ = lastRetryResp.Body.Close() + } + lastRetryResp = retryResp + if retryResp != nil { + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + } + + // 解析新的重试信息,用于下次重试的等待时间 + if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil { + newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) + if newShouldRetry && newWaitDuration > 0 { + waitDuration = newWaitDuration + } + } + } + + // 所有重试都失败,限流当前模型并切换账号 + log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)", + p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID) + + resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + if p.accountRepo != nil && modelName != "" { + if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { + log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) + } else { + log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", + p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration) + s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) + } + } + + // 返回账号切换信号,让上层切换账号重试 + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + switchError: &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: modelName, + IsStickySession: p.isStickySession, + }, + } + } + + // 未触发智能重试,继续默认重试逻辑 + return &smartRetryResult{action: smartRetryActionContinue} +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 -func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { +func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { + // 预检查:如果账号已限流,根据剩余时间决定等待或切换 + if p.requestedModel != "" { + if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { + if remaining < antigravityRateLimitThreshold { + // 限流剩余时间较短,等待后继续 + log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + select { + case <-p.ctx.Done(): + return nil, p.ctx.Err() + case <-time.After(remaining): + } + } else { + // 限流剩余时间较长,返回账号切换信号 + log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID) + return nil, &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: p.requestedModel, + IsStickySession: p.isStickySession, + } + } + } + } + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() if len(availableURLs) == 0 { availableURLs = antigravity.BaseURLs @@ -95,6 +330,9 @@ urlFallbackLoop: } resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if err == nil && resp == nil { + err = errors.New("upstream returned nil response") + } if err != nil { safeErr := sanitizeUpstreamErrorMessage(err.Error()) appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ @@ -122,18 +360,30 @@ urlFallbackLoop: return nil, fmt.Errorf("upstream request failed after retries: %w", err) } - // 429 限流处理:区分 URL 级别限流和账户配额限流 - if resp.StatusCode == http.StatusTooManyRequests { + // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - // "Resource has been exhausted" 是 URL 级别限流,切换 URL - if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + // 尝试智能重试处理(OAuth 账号专用) + smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) + switch smartResult.action { + case smartRetryActionContinueURL: continue urlFallbackLoop + case smartRetryActionBreakWithResp: + if smartResult.err != nil { + return nil, smartResult.err + } + // 模型限流时返回切换账号信号 + if smartResult.switchError != nil { + return nil, smartResult.switchError + } + resp = smartResult.resp + break urlFallbackLoop } + // smartRetryActionContinue: 继续默认重试逻辑 - // 账户/模型配额限流,重试 3 次(指数退避) + // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) if attempt < antigravityMaxRetries { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -147,7 +397,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -156,8 +406,8 @@ urlFallbackLoop: } // 重试用尽,标记账户限流 - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope) - log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200)) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -166,7 +416,7 @@ urlFallbackLoop: break urlFallbackLoop } - // 其他可重试错误 + // 其他可重试错误(不包括 429 和 503,因为上面已处理) if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -272,71 +522,34 @@ func logPrefix(sessionID, accountName string) string { return fmt.Sprintf("[antigravity-Forward] account=%s", accountName) } -// Antigravity 直接支持的模型(精确匹配透传) -// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列 -var antigravitySupportedModels = map[string]bool{ - "claude-opus-4-5-thinking": true, - "claude-sonnet-4-5": true, - "claude-sonnet-4-5-thinking": true, - "gemini-3-flash": true, - "gemini-3-pro-low": true, - "gemini-3-pro-high": true, - "gemini-3-pro-image": true, -} - -// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先) -// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀) -// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5) -var antigravityPrefixMapping = []struct { - prefix string - target string -}{ - // gemini-2.5 → gemini-3 映射(长前缀优先) - {"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash - {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image - {"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash - {"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash - {"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high - {"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high - {"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high - // gemini-3 前缀映射 - {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 - {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash - {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等 - // Claude 映射 - {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx - {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx - {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet - {"claude-opus-4-5", "claude-opus-4-5-thinking"}, - {"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet - {"claude-sonnet-4", "claude-sonnet-4-5"}, - {"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet - {"claude-opus-4", "claude-opus-4-5-thinking"}, -} - // AntigravityGatewayService 处理 Antigravity 平台的 API 转发 type AntigravityGatewayService struct { - accountRepo AccountRepository - tokenProvider *AntigravityTokenProvider - rateLimitService *RateLimitService - httpUpstream HTTPUpstream - settingService *SettingService + accountRepo AccountRepository + tokenProvider *AntigravityTokenProvider + rateLimitService *RateLimitService + httpUpstream HTTPUpstream + settingService *SettingService + cache GatewayCache // 用于模型级限流时清除粘性会话绑定 + schedulerSnapshot *SchedulerSnapshotService } func NewAntigravityGatewayService( accountRepo AccountRepository, - _ GatewayCache, + cache GatewayCache, + schedulerSnapshot *SchedulerSnapshotService, tokenProvider *AntigravityTokenProvider, rateLimitService *RateLimitService, httpUpstream HTTPUpstream, settingService *SettingService, ) *AntigravityGatewayService { return &AntigravityGatewayService{ - accountRepo: accountRepo, - tokenProvider: tokenProvider, - rateLimitService: rateLimitService, - httpUpstream: httpUpstream, - settingService: settingService, + accountRepo: accountRepo, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + settingService: settingService, + cache: cache, + schedulerSnapshot: schedulerSnapshot, } } @@ -345,33 +558,80 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider return s.tokenProvider } -// getMappedModel 获取映射后的模型名 -// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值 -func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { - // 1. 账户级映射(用户自定义优先) - if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { +// getLogConfig 获取上游错误日志配置 +// 返回是否记录日志体和最大字节数 +func (s *AntigravityGatewayService) getLogConfig() (logBody bool, maxBytes int) { + maxBytes = 2048 // 默认值 + if s.settingService == nil || s.settingService.cfg == nil { + return false, maxBytes + } + cfg := s.settingService.cfg.Gateway + if cfg.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = cfg.LogUpstreamErrorBodyMaxBytes + } + return cfg.LogUpstreamErrorBody, maxBytes +} + +// getUpstreamErrorDetail 获取上游错误详情(用于日志记录) +func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string { + logBody, maxBytes := s.getLogConfig() + if !logBody { + return "" + } + return truncateString(string(body), maxBytes) +} + +// mapAntigravityModel 获取映射后的模型名 +// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping) +// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号 +func mapAntigravityModel(account *Account, requestedModel string) string { + if account == nil { + return "" + } + + // 获取映射表(未配置时自动使用 DefaultAntigravityModelMapping) + mapping := account.GetModelMapping() + if len(mapping) == 0 { + return "" // 无映射配置(非 Antigravity 平台) + } + + // 通过映射表查询(支持精确匹配 + 通配符) + mapped := account.GetMappedModel(requestedModel) + + // 判断是否映射成功(mapped != requestedModel 说明找到了映射规则) + if mapped != requestedModel { return mapped } - // 2. 直接支持的模型透传 - if antigravitySupportedModels[requestedModel] { + // 如果 mapped == requestedModel,检查是否在映射表中配置(精确或通配符) + // 这区分两种情况: + // 1. 映射表中有 "model-a": "model-a"(显式透传)→ 返回 model-a + // 2. 通配符匹配 "claude-*": "claude-sonnet-4-5" 恰好目标等于请求名 → 返回 model-a + // 3. 映射表中没有 model-a 的配置 → 返回空(不支持) + if account.IsModelSupported(requestedModel) { return requestedModel } - // 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview) - for _, pm := range antigravityPrefixMapping { - if strings.HasPrefix(requestedModel, pm.prefix) { - return pm.target - } - } + // 未在映射表中配置的模型,返回空字符串(不支持) + return "" +} - // 4. Gemini 模型透传(未匹配到前缀的 gemini 模型) - if strings.HasPrefix(requestedModel, "gemini-") { - return requestedModel - } +// getMappedModel 获取映射后的模型名 +// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底 +func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { + return mapAntigravityModel(account, requestedModel) +} - // 5. 默认值 - return "claude-sonnet-4-5" +// applyThinkingModelSuffix 根据 thinking 配置调整模型名 +// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking +func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string { + if !thinkingEnabled { + return mappedModel + } + if mappedModel == "claude-sonnet-4-5" { + return "claude-sonnet-4-5-thinking" + } + return mappedModel } // IsModelSupported 检查模型是否被支持 @@ -404,6 +664,9 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 模型映射 mappedModel := s.getMappedModel(account, modelID) + if mappedModel == "" { + return nil, fmt.Errorf("model %s not in whitelist", modelID) + } // 构建请求体 var requestBody []byte @@ -701,7 +964,7 @@ func isModelNotFoundError(statusCode int, body []byte) bool { } // Forward 转发 Claude 协议请求(Claude → Gemini 转换) -func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { +func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -709,23 +972,30 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 解析 Claude 请求 var claudeReq antigravity.ClaudeRequest if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, fmt.Errorf("parse claude request: %w", err) + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") } if strings.TrimSpace(claudeReq.Model) == "" { - return nil, fmt.Errorf("missing model") + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") } originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) + if mappedModel == "" { + return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) + } + loadModel := mappedModel + // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 + thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 获取 access_token if s.tokenProvider == nil { - return nil, errors.New("antigravity token provider not configured") + return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Antigravity token provider not configured") } accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) if err != nil { - return nil, fmt.Errorf("获取 access_token 失败: %w", err) + return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token") } // 获取 project_id(部分账户类型可能没有) @@ -745,29 +1015,46 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 转换 Claude 请求为 Gemini 格式 geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts) if err != nil { - return nil, fmt.Errorf("transform request: %w", err) + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request") } // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // 统计模型调用次数(包括粘性会话,用于负载均衡调度) + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + } + // 执行带重试的请求 - result, err := antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, - prefix: prefix, - account: account, - proxyURL: proxyURL, - accessToken: accessToken, - action: action, - body: geminiBody, - quotaScope: quotaScope, - c: c, - httpUpstream: s.httpUpstream, - settingService: s.settingService, - handleError: s.handleUpstreamError, + result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: geminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, // Forward 由上层判断粘性会话 + groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 + sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 }) if err != nil { + // 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号 + if switchErr, ok := IsAntigravityAccountSwitchError(err); ok { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: switchErr.IsStickySession, + } + } return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } resp := result.resp @@ -782,15 +1069,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } + logBody, maxBytes := s.getLogConfig() + upstreamDetail := s.getUpstreamErrorDetail(respBody) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -829,19 +1109,24 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if txErr != nil { continue } - retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, - prefix: prefix, - account: account, - proxyURL: proxyURL, - accessToken: accessToken, - action: action, - body: retryGeminiBody, - quotaScope: quotaScope, - c: c, - httpUpstream: s.httpUpstream, - settingService: s.settingService, - handleError: s.handleUpstreamError, + retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: retryGeminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, + groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 + sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -917,20 +1202,38 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) + // 检测 prompt too long 错误,返回特殊错误类型供上层 fallback + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := s.getUpstreamErrorDetail(respBody) + logBody, maxBytes := s.getLogConfig() + if logBody { + log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "prompt_too_long", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } + upstreamDetail := s.getUpstreamErrorDetail(respBody) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -941,7 +1244,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) @@ -1006,6 +1309,37 @@ func isSignatureRelatedError(respBody []byte) bool { return false } +// isPromptTooLongError 检测是否为 prompt too long 错误 +func isPromptTooLongError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if msg == "" { + msg = strings.ToLower(string(respBody)) + } + return strings.Contains(msg, "prompt is too long") || + strings.Contains(msg, "request is too long") || + strings.Contains(msg, "context length exceeded") || + strings.Contains(msg, "max_tokens") +} + +// isPassthroughErrorMessage 检查错误消息是否在透传白名单中 +func isPassthroughErrorMessage(msg string) bool { + lower := strings.ToLower(msg) + for _, pattern := range antigravityPassthroughErrorMessages { + if strings.Contains(lower, pattern) { + return true + } + } + return false +} + +// getPassthroughOrDefault 若消息在白名单内则返回原始消息,否则返回默认消息 +func getPassthroughOrDefault(upstreamMsg, defaultMsg string) string { + if isPassthroughErrorMessage(upstreamMsg) { + return upstreamMsg + } + return defaultMsg +} + func extractAntigravityErrorMessage(body []byte) string { var payload map[string]any if err := json.Unmarshal(body, &payload); err != nil { @@ -1249,7 +1583,7 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque } // ForwardGemini 转发 Gemini 协议请求 -func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { +func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1287,14 +1621,17 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } mappedModel := s.getMappedModel(account, originalModel) + if mappedModel == "" { + return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) + } // 获取 access_token if s.tokenProvider == nil { - return nil, errors.New("antigravity token provider not configured") + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Antigravity token provider not configured") } accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) if err != nil { - return nil, fmt.Errorf("获取 access_token 失败: %w", err) + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token") } // 获取 project_id(部分账户类型可能没有) @@ -1309,7 +1646,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // Antigravity 上游要求必须包含身份提示词,注入到请求中 injectedBody, err := injectIdentityPatchToGeminiRequest(body) if err != nil { - return nil, err + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Invalid request body") } // 清理 Schema @@ -1323,29 +1660,46 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 包装请求 wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody) if err != nil { - return nil, err + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request") } // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" + // 统计模型调用次数(包括粘性会话,用于负载均衡调度) + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) + } + // 执行带重试的请求 - result, err := antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, - prefix: prefix, - account: account, - proxyURL: proxyURL, - accessToken: accessToken, - action: upstreamAction, - body: wrappedBody, - quotaScope: quotaScope, - c: c, - httpUpstream: s.httpUpstream, - settingService: s.settingService, - handleError: s.handleUpstreamError, + result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: upstreamAction, + body: wrappedBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, // ForwardGemini 由上层判断粘性会话 + groupID: 0, // ForwardGemini 方法没有 groupID,由上层处理粘性会话清除 + sessionHash: "", // ForwardGemini 方法没有 sessionHash,由上层处理粘性会话清除 }) if err != nil { + // 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号 + if switchErr, ok := IsAntigravityAccountSwitchError(err); ok { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: switchErr.IsStickySession, + } + } return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } resp := result.resp @@ -1358,6 +1712,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + contentType := resp.Header.Get("Content-Type") // 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。 _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) @@ -1400,19 +1755,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if unwrapErr != nil || len(unwrappedForOps) == 0 { unwrappedForOps = respBody } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(unwrappedForOps), maxBytes) - } + upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps) // Always record upstream context for Ops error logs, even when we will failover. setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) @@ -1428,10 +1774,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps} } - - contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" } @@ -1535,27 +1879,348 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { } } -func antigravityUseScopeRateLimit() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv))) - // 默认开启按配额域限流,只有明确设置为禁用值时才关闭 - if v == "0" || v == "false" || v == "no" || v == "off" { +// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流 +// 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key +// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false) +func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, accountID int64, modelName, prefix string, statusCode int, resetAt time.Time, afterSmartRetry bool) bool { + if repo == nil || modelName == "" { return false } + // 直接使用官方模型 ID 作为 key,不再转换为 scope + if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil { + log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + return false + } + if afterSmartRetry { + log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + } else { + log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + } return true } -func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { +func antigravityFallbackCooldownSeconds() (time.Duration, bool) { + raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv)) + if raw == "" { + return 0, false + } + seconds, err := strconv.Atoi(raw) + if err != nil || seconds <= 0 { + return 0, false + } + return time.Duration(seconds) * time.Second, true +} + +// antigravitySmartRetryInfo 智能重试所需的信息 +type antigravitySmartRetryInfo struct { + RetryDelay time.Duration // 重试延迟时间 + ModelName string // 限流的模型名称(如 "claude-sonnet-4-5") +} + +// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息 +// 返回解析结果,如果解析失败或不满足条件返回 nil +// +// 支持两种情况: +// 1. 429 RESOURCE_EXHAUSTED + RATE_LIMIT_EXCEEDED: +// - error.status == "RESOURCE_EXHAUSTED" +// - error.details[].reason == "RATE_LIMIT_EXCEEDED" +// +// 2. 503 UNAVAILABLE + MODEL_CAPACITY_EXHAUSTED: +// - error.status == "UNAVAILABLE" +// - error.details[].reason == "MODEL_CAPACITY_EXHAUSTED" +// +// 必须满足以下条件才会返回有效值: +// - error.details[] 中存在 @type == "type.googleapis.com/google.rpc.RetryInfo" 的元素 +// - 该元素包含 retryDelay 字段,格式为 "数字s"(如 "0.201506475s") +func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { + var parsed map[string]any + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + + errObj, ok := parsed["error"].(map[string]any) + if !ok { + return nil + } + + // 检查 status 是否符合条件 + // 情况1: 429 RESOURCE_EXHAUSTED (需要进一步检查 reason == RATE_LIMIT_EXCEEDED) + // 情况2: 503 UNAVAILABLE (需要进一步检查 reason == MODEL_CAPACITY_EXHAUSTED) + status, _ := errObj["status"].(string) + isResourceExhausted := status == googleRPCStatusResourceExhausted + isUnavailable := status == googleRPCStatusUnavailable + + if !isResourceExhausted && !isUnavailable { + return nil + } + + details, ok := errObj["details"].([]any) + if !ok { + return nil + } + + var retryDelay time.Duration + var modelName string + var hasRateLimitExceeded bool // 429 需要此 reason + var hasModelCapacityExhausted bool // 503 需要此 reason + + for _, d := range details { + dm, ok := d.(map[string]any) + if !ok { + continue + } + + atType, _ := dm["@type"].(string) + + // 从 ErrorInfo 提取模型名称和 reason + if atType == googleRPCTypeErrorInfo { + if meta, ok := dm["metadata"].(map[string]any); ok { + if model, ok := meta["model"].(string); ok { + modelName = model + } + } + // 检查 reason + if reason, ok := dm["reason"].(string); ok { + if reason == googleRPCReasonModelCapacityExhausted { + hasModelCapacityExhausted = true + } + if reason == googleRPCReasonRateLimitExceeded { + hasRateLimitExceeded = true + } + } + continue + } + + // 从 RetryInfo 提取重试延迟 + if atType == googleRPCTypeRetryInfo { + delay, ok := dm["retryDelay"].(string) + if !ok || delay == "" { + continue + } + // 使用 time.ParseDuration 解析,支持所有 Go duration 格式 + // 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等 + dur, err := time.ParseDuration(delay) + if err != nil { + log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) + continue + } + retryDelay = dur + } + } + + // 验证条件 + // 情况1: RESOURCE_EXHAUSTED 需要有 RATE_LIMIT_EXCEEDED reason + // 情况2: UNAVAILABLE 需要有 MODEL_CAPACITY_EXHAUSTED reason + if isResourceExhausted && !hasRateLimitExceeded { + return nil + } + if isUnavailable && !hasModelCapacityExhausted { + return nil + } + + // 必须有模型名才返回有效结果 + if modelName == "" { + return nil + } + + // 如果上游未提供 retryDelay,使用默认限流时间 + if retryDelay <= 0 { + retryDelay = antigravityDefaultRateLimitDuration + } + + return &antigravitySmartRetryInfo{ + RetryDelay: retryDelay, + ModelName: modelName, + } +} + +// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试 +// 返回: +// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold) +// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold) +// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0) +// - modelName: 限流的模型名称 +func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) { + if account.Platform != PlatformAntigravity { + return false, false, 0, "" + } + + info := parseAntigravitySmartRetryInfo(respBody) + if info == nil { + return false, false, 0, "" + } + + // retryDelay >= 阈值:直接限流模型,不重试 + // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟 + if info.RetryDelay >= antigravityRateLimitThreshold { + return false, true, 0, info.ModelName + } + + // retryDelay < 阈值:智能重试 + waitDuration = info.RetryDelay + if waitDuration < antigravitySmartRetryMinWait { + waitDuration = antigravitySmartRetryMinWait + } + + return true, false, waitDuration, info.ModelName +} + +// handleModelRateLimitParams 模型级限流处理参数 +type handleModelRateLimitParams struct { + ctx context.Context + prefix string + account *Account + statusCode int + body []byte + cache GatewayCache + groupID int64 + sessionHash string + isStickySession bool +} + +// handleModelRateLimitResult 模型级限流处理结果 +type handleModelRateLimitResult struct { + Handled bool // 是否已处理 + ShouldRetry bool // 是否等待后重试 + WaitDuration time.Duration // 等待时间 + SwitchError *AntigravityAccountSwitchError // 账号切换错误 +} + +// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用) +// 仅处理 429/503,解析模型名和 retryDelay +// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试 +// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError +func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult { + if p.statusCode != 429 && p.statusCode != 503 { + return &handleModelRateLimitResult{Handled: false} + } + + info := parseAntigravitySmartRetryInfo(p.body) + if info == nil || info.ModelName == "" { + return &handleModelRateLimitResult{Handled: false} + } + + // < antigravityRateLimitThreshold: 等待后重试 + if info.RetryDelay < antigravityRateLimitThreshold { + log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", + p.prefix, p.statusCode, info.ModelName, info.RetryDelay) + return &handleModelRateLimitResult{ + Handled: true, + ShouldRetry: true, + WaitDuration: info.RetryDelay, + } + } + + // >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 + s.setModelRateLimitAndClearSession(p, info) + + return &handleModelRateLimitResult{ + Handled: true, + SwitchError: &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: info.ModelName, + IsStickySession: p.isStickySession, + }, + } +} + +// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话 +func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) { + resetAt := time.Now().Add(info.RetryDelay) + log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", + p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay) + + // 设置模型限流状态(数据库) + if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil { + log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) + } + + // 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中 + s.updateAccountModelRateLimitInCache(p.ctx, p.account, info.ModelName, resetAt) + + // 清除粘性会话绑定 + if p.cache != nil && p.sessionHash != "" { + _ = p.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash) + } +} + +// updateAccountModelRateLimitInCache 立即更新 Redis 中账号的模型限流状态 +func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx context.Context, account *Account, modelKey string, resetAt time.Time) { + if s.schedulerSnapshot == nil || account == nil || modelKey == "" { + return + } + + // 更新账号对象的 Extra 字段 + if account.Extra == nil { + account.Extra = make(map[string]any) + } + + limits, _ := account.Extra["model_rate_limits"].(map[string]any) + if limits == nil { + limits = make(map[string]any) + account.Extra["model_rate_limits"] = limits + } + + limits[modelKey] = map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), + } + + // 更新 Redis 快照 + if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil { + log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) + } +} + +func (s *AntigravityGatewayService) handleUpstreamError( + ctx context.Context, prefix string, account *Account, + statusCode int, headers http.Header, body []byte, + quotaScope AntigravityQuotaScope, + groupID int64, sessionHash string, isStickySession bool, +) *handleModelRateLimitResult { + // ✨ 模型级限流处理(在原有逻辑之前) + result := s.handleModelRateLimit(&handleModelRateLimitParams{ + ctx: ctx, + prefix: prefix, + account: account, + statusCode: statusCode, + body: body, + cache: s.cache, + groupID: groupID, + sessionHash: sessionHash, + isStickySession: isStickySession, + }) + if result.Handled { + return result + } + + // 503 仅处理模型限流(MODEL_CAPACITY_EXHAUSTED),非模型限流不做额外处理 + // 避免将普通的 503 错误误判为账号问题 + if statusCode == 503 { + return nil + } + + // ========== 原有逻辑,保持不变 ========== // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { - useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != "" + // 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。 + if logBody, maxBytes := s.getLogConfig(); logBody { + log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) + } + + useScopeLimit := quotaScope != "" resetAt := ParseGeminiRateLimitResetTime(body) if resetAt == nil { - // 解析失败:使用配置的 fallback 时间,直接限流整个账户 - fallbackMinutes := 5 + // 解析失败:使用默认限流时间(与临时限流保持一致) + // 可通过配置或环境变量覆盖 + defaultDur := antigravityDefaultRateLimitDuration if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { - fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes + defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute + } + // 秒级环境变量优先级最高 + if override, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = override } - defaultDur := time.Duration(fallbackMinutes) * time.Minute ra := time.Now().Add(defaultDur) if useScopeLimit { log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) @@ -1568,7 +2233,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } } - return + return nil } resetTime := time.Unix(*resetAt, 0) if useScopeLimit { @@ -1582,16 +2247,17 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } } - return + return nil } // 其他错误码继续使用 rateLimitService if s.rateLimitService == nil { - return + return nil } shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) if shouldDisable { log.Printf("%s status=%d marked_error", prefix, statusCode) } + return nil } type antigravityStreamResult struct { @@ -2122,20 +2788,16 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, return fmt.Errorf("%s", message) } +// WriteMappedClaudeError 导出版本,供 handler 层使用(如 fallback 错误处理) +func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { + return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body) +} + func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(body), maxBytes) - } + logBody, maxBytes := s.getLogConfig() + upstreamDetail := s.getUpstreamErrorDetail(body) setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -2160,7 +2822,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou case 400: statusCode = http.StatusBadRequest errType = "invalid_request_error" - errMsg = "Invalid request" + errMsg = getPassthroughOrDefault(upstreamMsg, "Invalid request") case 401: statusCode = http.StatusBadGateway errType = "authentication_error" @@ -2618,3 +3280,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) { return json.Marshal(payload) } + +// filterEmptyPartsFromGeminiRequest 过滤掉 parts 为空的消息 +// Gemini API 不接受空 parts,需要在请求前过滤 +func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + + contents, ok := payload["contents"].([]any) + if !ok || len(contents) == 0 { + return body, nil + } + + filtered := make([]any, 0, len(contents)) + modified := false + + for _, c := range contents { + contentMap, ok := c.(map[string]any) + if !ok { + filtered = append(filtered, c) + continue + } + + parts, hasParts := contentMap["parts"] + if !hasParts { + filtered = append(filtered, c) + continue + } + + partsSlice, ok := parts.([]any) + if !ok { + filtered = append(filtered, c) + continue + } + + // 跳过 parts 为空数组的消息 + if len(partsSlice) == 0 { + modified = true + continue + } + + filtered = append(filtered, c) + } + + if !modified { + return body, nil + } + + payload["contents"] = filtered + return json.Marshal(payload) +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 05ad9bbd..91cefc28 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -1,10 +1,17 @@ package service import ( + "bytes" + "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -81,3 +88,306 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "tool_use", blocks[1]["type"]) } + +func TestIsPromptTooLongError(t *testing.T) { + require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`))) + require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`))) + require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`))) +} + +type httpUpstreamStub struct { + resp *http.Response + err error +} + +func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return s.resp, s.err +} + +func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + return s.resp, s.err +} + +func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-6", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "max_tokens": 1, + "stream": false, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + respBody := []byte(`{"error":{"message":"Prompt is too long"}}`) + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"X-Request-Id": []string{"req-1"}}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.Nil(t, result) + + var promptErr *PromptTooLongError + require.ErrorAs(t, err, &promptErr) + require.Equal(t, http.StatusBadRequest, promptErr.StatusCode) + require.Equal(t, "req-1", promptErr.RequestID) + require.NotEmpty(t, promptErr.Body) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "prompt_too_long", events[0].Kind) +} + +// TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover +// 验证:当账号存在模型限流且剩余时间 >= antigravityRateLimitThreshold 时, +// Forward 方法应返回 UpstreamFailoverError,触发 Handler 切换账号 +func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-6", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "max_tokens": 1, + "stream": false, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + // 不需要真正调用上游,因为预检查会直接返回切换信号 + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: nil, err: nil}, + } + + // 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s) + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account := &Account{ + ID: 1, + Name: "acc-rate-limited", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-opus-4-6-thinking": map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.Nil(t, result, "Forward should not return result when model rate limited") + require.NotNil(t, err, "Forward should return error") + + // 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误 + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + // 非粘性会话请求,ForceCacheBilling 应为 false + require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session") +} + +// TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover +// 验证:ForwardGemini 方法同样能正确将 AntigravityAccountSwitchError 转换为 UpstreamFailoverError +func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + // 不需要真正调用上游,因为预检查会直接返回切换信号 + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: nil, err: nil}, + } + + // 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s) + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account := &Account{ + ID: 2, + Name: "acc-gemini-rate-limited", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-2.5-flash": map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false) + require.Nil(t, result, "ForwardGemini should not return result when model rate limited") + require.NotNil(t, err, "ForwardGemini should return error") + + // 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误 + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + // 非粘性会话请求,ForceCacheBilling 应为 false + require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session") +} + +// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling +// 验证:粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-6", + "messages": []map[string]string{{"role": "user", "content": "hello"}}, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: nil, err: nil}, + } + + // 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s) + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account := &Account{ + ID: 3, + Name: "acc-sticky-rate-limited", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-opus-4-6-thinking": map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + }, + } + + // 传入 isStickySession = true + result, err := svc.Forward(context.Background(), c, account, body, true) + require.Nil(t, result, "Forward should not return result when model rate limited") + require.NotNil(t, err, "Forward should return error") + + // 核心验证:粘性会话切换时,ForceCacheBilling 应为 true + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") +} + +// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling +// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: nil, err: nil}, + } + + // 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s) + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account := &Account{ + ID: 4, + Name: "acc-gemini-sticky-rate-limited", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-2.5-flash": map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + }, + } + + // 传入 isStickySession = true + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true) + require.Nil(t, result, "ForwardGemini should not return result when model rate limited") + require.NotNil(t, err, "ForwardGemini should return error") + + // 核心验证:粘性会话切换时,ForceCacheBilling 应为 true + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index e269103a..f3621555 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -8,53 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestIsAntigravityModelSupported(t *testing.T) { - tests := []struct { - name string - model string - expected bool - }{ - // 直接支持的模型 - {"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, - {"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true}, - {"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true}, - {"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true}, - {"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true}, - {"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true}, - - // 可映射的模型 - {"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true}, - {"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true}, - {"可映射 - claude-opus-4", "claude-opus-4", true}, - {"可映射 - claude-haiku-4", "claude-haiku-4", true}, - {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, - - // Gemini 前缀透传 - {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true}, - {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, - {"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, - - // Claude 前缀兜底 - {"Claude前缀 - claude-unknown-model", "claude-unknown-model", true}, - {"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true}, - {"Claude前缀 - claude-future-version", "claude-future-version", true}, - - // 不支持的模型 - {"不支持 - gpt-4", "gpt-4", false}, - {"不支持 - gpt-4o", "gpt-4o", false}, - {"不支持 - llama-3", "llama-3", false}, - {"不支持 - mistral-7b", "mistral-7b", false}, - {"不支持 - 空字符串", "", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := IsAntigravityModelSupported(tt.model) - require.Equal(t, tt.expected, got, "model: %s", tt.model) - }) - } -} - func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { svc := &AntigravityGatewayService{} @@ -64,7 +17,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { accountMapping map[string]string expected string }{ - // 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any) + // 1. 账户级映射优先 { name: "账户映射优先", requestedModel: "claude-3-5-sonnet-20241022", @@ -72,120 +25,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "custom-model", }, { - name: "账户映射覆盖系统映射", + name: "账户映射 - 可覆盖默认映射的模型", + requestedModel: "claude-sonnet-4-5", + accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"}, + expected: "my-custom-sonnet", + }, + { + name: "账户映射 - 可覆盖未知模型", requestedModel: "claude-opus-4", accountMapping: map[string]string{"claude-opus-4": "my-opus"}, expected: "my-opus", }, - // 2. 系统默认映射 + // 2. 默认映射(DefaultAntigravityModelMapping) { - name: "系统映射 - claude-3-5-sonnet-20241022", - requestedModel: "claude-3-5-sonnet-20241022", + name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking", + requestedModel: "claude-opus-4-6", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-opus-4-6-thinking", }, { - name: "系统映射 - claude-3-5-sonnet-20240620", - requestedModel: "claude-3-5-sonnet-20240620", - accountMapping: nil, - expected: "claude-sonnet-4-5", - }, - { - name: "系统映射 - claude-opus-4", - requestedModel: "claude-opus-4", - accountMapping: nil, - expected: "claude-opus-4-5-thinking", - }, - { - name: "系统映射 - claude-opus-4-5-20251101", + name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking", requestedModel: "claude-opus-4-5-20251101", accountMapping: nil, - expected: "claude-opus-4-5-thinking", + expected: "claude-opus-4-6-thinking", }, { - name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5", - requestedModel: "claude-haiku-4", + name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking", + requestedModel: "claude-opus-4-5-thinking", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-opus-4-6-thinking", }, { - name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5", requestedModel: "claude-haiku-4-5", accountMapping: nil, expected: "claude-sonnet-4-5", }, { - name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5", - requestedModel: "claude-3-haiku-20240307", - accountMapping: nil, - expected: "claude-sonnet-4-5", - }, - { - name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", requestedModel: "claude-haiku-4-5-20251001", accountMapping: nil, expected: "claude-sonnet-4-5", }, { - name: "系统映射 - claude-sonnet-4-5-20250929", + name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5-20250929", accountMapping: nil, expected: "claude-sonnet-4-5", }, - // 3. Gemini 2.5 → 3 映射 + // 3. 默认映射中的透传(映射到自己) { - name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash", - requestedModel: "gemini-2.5-flash", - accountMapping: nil, - expected: "gemini-3-flash", - }, - { - name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high", - requestedModel: "gemini-2.5-pro", - accountMapping: nil, - expected: "gemini-3-pro-high", - }, - { - name: "Gemini透传 - gemini-future-model", - requestedModel: "gemini-future-model", - accountMapping: nil, - expected: "gemini-future-model", - }, - - // 4. 直接支持的模型 - { - name: "直接支持 - claude-sonnet-4-5", + name: "默认映射透传 - claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5", accountMapping: nil, expected: "claude-sonnet-4-5", }, { - name: "直接支持 - claude-opus-4-5-thinking", - requestedModel: "claude-opus-4-5-thinking", + name: "默认映射透传 - claude-opus-4-6-thinking", + requestedModel: "claude-opus-4-6-thinking", accountMapping: nil, - expected: "claude-opus-4-5-thinking", + expected: "claude-opus-4-6-thinking", }, { - name: "直接支持 - claude-sonnet-4-5-thinking", + name: "默认映射透传 - claude-sonnet-4-5-thinking", requestedModel: "claude-sonnet-4-5-thinking", accountMapping: nil, expected: "claude-sonnet-4-5-thinking", }, - - // 5. 默认值 fallback(未知 claude 模型) { - name: "默认值 - claude-unknown", - requestedModel: "claude-unknown", + name: "默认映射透传 - gemini-2.5-flash", + requestedModel: "gemini-2.5-flash", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "gemini-2.5-flash", }, { - name: "默认值 - claude-3-opus-20240229", + name: "默认映射透传 - gemini-2.5-pro", + requestedModel: "gemini-2.5-pro", + accountMapping: nil, + expected: "gemini-2.5-pro", + }, + { + name: "默认映射透传 - gemini-3-flash", + requestedModel: "gemini-3-flash", + accountMapping: nil, + expected: "gemini-3-flash", + }, + + // 4. 未在默认映射中的模型返回空字符串(不支持) + { + name: "未知模型 - claude-unknown 返回空", + requestedModel: "claude-unknown", + accountMapping: nil, + expected: "", + }, + { + name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)", + requestedModel: "claude-3-5-sonnet-20241022", + accountMapping: nil, + expected: "", + }, + { + name: "未知模型 - claude-3-opus-20240229 返回空", requestedModel: "claude-3-opus-20240229", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "", + }, + { + name: "未知模型 - claude-opus-4 返回空", + requestedModel: "claude-opus-4", + accountMapping: nil, + expected: "", + }, + { + name: "未知模型 - gemini-future-model 返回空", + requestedModel: "gemini-future-model", + accountMapping: nil, + expected: "", }, } @@ -219,12 +176,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) { requestedModel string expected string }{ - // 空字符串回退到默认值 - {"空字符串", "", "claude-sonnet-4-5"}, - - // 非 claude/gemini 前缀回退到默认值 - {"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"}, - {"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"}, + // 空字符串和非 claude/gemini 前缀返回空字符串 + {"空字符串", "", ""}, + {"非claude/gemini前缀 - gpt", "gpt-4", ""}, + {"非claude/gemini前缀 - llama", "llama-3", ""}, } for _, tt := range tests { @@ -248,10 +203,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) { {"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, {"直接支持 - gemini-3-flash", "gemini-3-flash", true}, - // 可映射 - {"可映射 - claude-opus-4", "claude-opus-4", true}, + // 可映射(有明确前缀映射) + {"可映射 - claude-opus-4-6", "claude-opus-4-6", true}, - // 前缀透传 + // 前缀透传(claude 和 gemini 前缀) {"Gemini前缀", "gemini-unknown", true}, {"Claude前缀", "claude-unknown", true}, @@ -267,3 +222,58 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) { }) } } + +// TestMapAntigravityModel_WildcardTargetEqualsRequest 测试通配符映射目标恰好等于请求模型名的 edge case +// 例如 {"claude-*": "claude-sonnet-4-5"},请求 "claude-sonnet-4-5" 时应该通过 +func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { + tests := []struct { + name string + modelMapping map[string]any + requestedModel string + expected string + }{ + { + name: "wildcard target equals request model", + modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"}, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5", + }, + { + name: "wildcard target differs from request model", + modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"}, + requestedModel: "claude-opus-4-6", + expected: "claude-sonnet-4-5", + }, + { + name: "wildcard no match", + modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"}, + requestedModel: "gpt-4o", + expected: "", + }, + { + name: "explicit passthrough same name", + modelMapping: map[string]any{"claude-sonnet-4-5": "claude-sonnet-4-5"}, + requestedModel: "claude-sonnet-4-5", + expected: "claude-sonnet-4-5", + }, + { + name: "multiple wildcards target equals one request", + modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5", "gemini-*": "gemini-2.5-flash"}, + requestedModel: "gemini-2.5-flash", + expected: "gemini-2.5-flash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": tt.modelMapping, + }, + } + got := mapAntigravityModel(account, tt.requestedModel) + require.Equal(t, tt.expected, got, "mapAntigravityModel(%q) = %q, want %q", tt.requestedModel, got, tt.expected) + }) + } +} diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 34cd9a4c..43ac6c2f 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -1,6 +1,8 @@ package service import ( + "context" + "slices" "strings" "time" ) @@ -16,6 +18,21 @@ const ( AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" ) +// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 +func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { + if len(supportedScopes) == 0 { + // 未配置时默认全部支持 + return true + } + supported := slices.Contains(supportedScopes, string(scope)) + return supported +} + +// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) +func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { + return resolveAntigravityQuotaScope(requestedModel) +} + // resolveAntigravityQuotaScope 根据模型名称解析配额域 func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { model := normalizeAntigravityModelName(requestedModel) @@ -41,15 +58,20 @@ func normalizeAntigravityModelName(model string) string { return normalized } -// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度 +// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。 +// 保持旧签名以兼容既有调用方;默认使用 context.Background()。 func (a *Account) IsSchedulableForModel(requestedModel string) bool { + return a.IsSchedulableForModelWithContext(context.Background(), requestedModel) +} + +func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool { if a == nil { return false } if !a.IsSchedulable() { return false } - if a.isModelRateLimited(requestedModel) { + if a.isModelRateLimitedWithContext(ctx, requestedModel) { return false } if a.Platform != PlatformAntigravity { @@ -116,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 { } return result } + +// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间 +// 返回 0 表示未限流或已过期 +func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration { + if a == nil || a.Platform != PlatformAntigravity { + return 0 + } + scope, ok := resolveAntigravityQuotaScope(requestedModel) + if !ok { + return 0 + } + resetAt := a.antigravityQuotaScopeResetAt(scope) + if resetAt == nil { + return 0 + } + if remaining := time.Until(*resetAt); remaining > 0 { + return remaining + } + return 0 +} + +// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值) +// 返回 0 表示未限流或已过期 +func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration { + return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel) +} + +// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值) +// 返回 0 表示未限流或已过期 +func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration { + if a == nil { + return 0 + } + modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) + scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel) + if modelRemaining > scopeRemaining { + return modelRemaining + } + return scopeRemaining +} diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 9535948c..20936356 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -21,6 +21,23 @@ type stubAntigravityUpstream struct { calls []string } +type recordingOKUpstream struct { + calls int +} + +func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + r.calls++ + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil +} + +func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return r.Do(req, proxyURL, accountID, accountConcurrency) +} + func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { url := req.URL.String() s.calls = append(s.calls, url) @@ -53,10 +70,17 @@ type rateLimitCall struct { resetAt time.Time } +type modelRateLimitCall struct { + accountID int64 + modelKey string // 存储的 key(应该是官方模型 ID,如 "claude-sonnet-4-5") + resetAt time.Time +} + type stubAntigravityAccountRepo struct { AccountRepository - scopeCalls []scopeLimitCall - rateCalls []rateLimitCall + scopeCalls []scopeLimitCall + rateCalls []rateLimitCall + modelRateLimitCalls []modelRateLimitCall } func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { @@ -69,6 +93,11 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6 return nil } +func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error { + s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt}) + return nil +} + func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldAvailability := antigravity.DefaultURLAvailability @@ -93,18 +122,21 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { } var handleErrorCalled bool - result, err := antigravityRetryLoop(antigravityRetryLoopParams{ - prefix: "[test]", - ctx: context.Background(), - account: account, - proxyURL: "", - accessToken: "token", - action: "generateContent", - body: []byte(`{"input":"test"}`), - quotaScope: AntigravityQuotaScopeClaude, - httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + prefix: "[test]", + ctx: context.Background(), + account: account, + proxyURL: "", + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + quotaScope: AntigravityQuotaScopeClaude, + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleErrorCalled = true + return nil }, }) @@ -123,14 +155,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { require.Equal(t, base2, available[0]) } -func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) { - t.Setenv(antigravityScopeRateLimitEnv, "true") +func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) { + // 分区限流始终开启,不再支持通过环境变量关闭 repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} body := buildGeminiRateLimitBody("3s") - svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) require.Len(t, repo.scopeCalls, 1) require.Empty(t, repo.rateCalls) @@ -140,20 +172,122 @@ func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) } -func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) { - t.Setenv(antigravityScopeRateLimitEnv, "false") +// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 +func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} - account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity} + account := &Account{ID: 1, Name: "acc-1", Platform: PlatformAntigravity} - body := buildGeminiRateLimitBody("2s") - svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + // 429 + RATE_LIMIT_EXCEEDED + 模型名 → 模型限流 + body := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) - require.Len(t, repo.rateCalls, 1) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + + // 应该触发模型限流 + require.NotNil(t, result) + require.True(t, result.Handled) + require.NotNil(t, result.SwitchError) + require.Equal(t, "claude-sonnet-4-5", result.SwitchError.RateLimitedModel) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流) +func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity} + + // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流 + body := buildGeminiRateLimitBody("5s") + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + + // 不应该触发模型限流,应该走 scope 限流 + require.Nil(t, result) + require.Empty(t, repo.modelRateLimitCalls) + require.Len(t, repo.scopeCalls, 1) + require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope) +} + +// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 +func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity} + + // 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流 + body := []byte(`{ + "error": { + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"} + ] + } + }`) + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + + // 应该触发模型限流 + require.NotNil(t, result) + require.True(t, result.Handled) + require.NotNil(t, result.SwitchError) + require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理) +func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 4, Name: "acc-4", Platform: PlatformAntigravity} + + // 503 + 普通错误(非 MODEL_CAPACITY_EXHAUSTED)→ 不做任何处理 + body := []byte(`{ + "error": { + "status": "UNAVAILABLE", + "message": "Service temporarily unavailable", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "SERVICE_UNAVAILABLE"} + ] + } + }`) + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + + // 503 非模型限流不应该做任何处理 + require.Nil(t, result) + require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit") + require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit") + require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit") +} + +// TestHandleUpstreamError_503_EmptyBody 测试 503 空响应体(不处理) +func TestHandleUpstreamError_503_EmptyBody(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 5, Name: "acc-5", Platform: PlatformAntigravity} + + // 503 + 空响应体 → 不做任何处理 + body := []byte(`{}`) + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + + // 503 空响应不应该做任何处理 + require.Nil(t, result) + require.Empty(t, repo.modelRateLimitCalls) require.Empty(t, repo.scopeCalls) - call := repo.rateCalls[0] - require.Equal(t, account.ID, call.accountID) - require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second) + require.Empty(t, repo.rateCalls) } func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { @@ -188,3 +322,771 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { func buildGeminiRateLimitBody(delay string) []byte { return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay)) } + +func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) { + // Avoid flakiness around Unix second boundaries. + for { + now := time.Now() + if now.Nanosecond() < 800*1e6 { + break + } + time.Sleep(5 * time.Millisecond) + } + + baseUnix := time.Now().Unix() + ts := ParseGeminiRateLimitResetTime(buildGeminiRateLimitBody("0.1s")) + require.NotNil(t, ts) + require.Equal(t, baseUnix+1, *ts, "fractional seconds should be rounded up to the next second") +} + +func TestParseAntigravitySmartRetryInfo(t *testing.T) { + tests := []struct { + name string + body string + expectedDelay time.Duration + expectedModel string + expectedNil bool + }{ + { + name: "valid complete response with RATE_LIMIT_EXCEEDED", + body: `{ + "error": { + "code": 429, + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "domain": "cloudcode-pa.googleapis.com", + "metadata": { + "model": "claude-sonnet-4-5", + "quotaResetDelay": "201.506475ms" + }, + "reason": "RATE_LIMIT_EXCEEDED" + }, + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "0.201506475s" + } + ], + "message": "You have exhausted your capacity on this model.", + "status": "RESOURCE_EXHAUSTED" + } + }`, + expectedDelay: 201506475 * time.Nanosecond, + expectedModel: "claude-sonnet-4-5", + }, + { + name: "429 RESOURCE_EXHAUSTED without RATE_LIMIT_EXCEEDED - should return nil", + body: `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "metadata": {"model": "claude-sonnet-4-5"}, + "reason": "QUOTA_EXCEEDED" + }, + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "3s" + } + ] + } + }`, + expectedNil: true, + }, + { + name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay", + body: `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ], + "message": "No capacity available for model gemini-3-pro-high on the server" + } + }`, + expectedDelay: 39 * time.Second, + expectedModel: "gemini-3-pro-high", + }, + { + name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil", + body: `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "SERVICE_UNAVAILABLE"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"} + ] + } + }`, + expectedNil: true, + }, + { + name: "wrong status - should return nil", + body: `{ + "error": { + "code": 429, + "status": "INVALID_ARGUMENT", + "details": [ + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"} + ] + } + }`, + expectedNil: true, + }, + { + name: "missing status - should return nil", + body: `{ + "error": { + "code": 429, + "details": [ + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"} + ] + } + }`, + expectedNil: true, + }, + { + name: "milliseconds format is now supported", + body: `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test-model"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "500ms"} + ] + } + }`, + expectedDelay: 500 * time.Millisecond, + expectedModel: "test-model", + }, + { + name: "minutes format is supported", + body: `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "4m50s"} + ] + } + }`, + expectedDelay: 4*time.Minute + 50*time.Second, + expectedModel: "gemini-3-pro", + }, + { + name: "missing model name - should return nil", + body: `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"} + ] + } + }`, + expectedNil: true, + }, + { + name: "invalid JSON", + body: `not json`, + expectedNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseAntigravitySmartRetryInfo([]byte(tt.body)) + if tt.expectedNil { + if result != nil { + t.Errorf("expected nil, got %+v", result) + } + return + } + if result == nil { + t.Errorf("expected non-nil result") + return + } + if result.RetryDelay != tt.expectedDelay { + t.Errorf("RetryDelay = %v, want %v", result.RetryDelay, tt.expectedDelay) + } + if result.ModelName != tt.expectedModel { + t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel) + } + }) + } +} + +func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { + oauthAccount := &Account{Type: AccountTypeOAuth, Platform: PlatformAntigravity} + setupTokenAccount := &Account{Type: AccountTypeSetupToken, Platform: PlatformAntigravity} + upstreamAccount := &Account{Type: AccountTypeUpstream, Platform: PlatformAntigravity} + apiKeyAccount := &Account{Type: AccountTypeAPIKey} + + tests := []struct { + name string + account *Account + body string + expectedShouldRetry bool + expectedShouldRateLimit bool + minWait time.Duration + modelName string + }{ + { + name: "OAuth account with short delay (< 7s) - smart retry", + account: oauthAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`, + expectedShouldRetry: true, + expectedShouldRateLimit: false, + minWait: 1 * time.Second, // 0.5s < 1s, 使用最小等待时间 1s + modelName: "claude-opus-4", + }, + { + name: "SetupToken account with short delay - smart retry", + account: setupTokenAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"} + ] + } + }`, + expectedShouldRetry: true, + expectedShouldRateLimit: false, + minWait: 3 * time.Second, + modelName: "gemini-3-flash", + }, + { + name: "OAuth account with long delay (>= 7s) - direct rate limit", + account: oauthAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: true, + modelName: "claude-sonnet-4-5", + }, + { + name: "Upstream account with short delay - smart retry", + account: upstreamAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "2s"} + ] + } + }`, + expectedShouldRetry: true, + expectedShouldRateLimit: false, + minWait: 2 * time.Second, + modelName: "claude-sonnet-4-5", + }, + { + name: "API Key account - should not trigger", + account: apiKeyAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: false, + }, + { + name: "OAuth account with exactly 7s delay - direct rate limit", + account: oauthAccount, + body: `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"} + ] + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: true, + modelName: "gemini-pro", + }, + { + name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay", + account: oauthAccount, + body: `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ] + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: true, + modelName: "gemini-3-pro-high", + }, + { + name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit", + account: oauthAccount, + body: `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-2.5-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"} + ], + "message": "No capacity available for model gemini-2.5-flash on the server" + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: true, + modelName: "gemini-2.5-flash", + }, + { + name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit", + account: oauthAccount, + body: `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"} + ], + "message": "You have exhausted your capacity on this model." + } + }`, + expectedShouldRetry: false, + expectedShouldRateLimit: true, + modelName: "claude-sonnet-4-5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) + if shouldRetry != tt.expectedShouldRetry { + t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry) + } + if shouldRateLimit != tt.expectedShouldRateLimit { + t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit) + } + if shouldRetry { + if wait < tt.minWait { + t.Errorf("wait = %v, want >= %v", wait, tt.minWait) + } + } + if (shouldRetry || shouldRateLimit) && model != tt.modelName { + t.Errorf("modelName = %q, want %q", model, tt.modelName) + } + }) + } +} + +// TestSetModelRateLimitByModelName_UsesOfficialModelID 验证写入端使用官方模型 ID +func TestSetModelRateLimitByModelName_UsesOfficialModelID(t *testing.T) { + tests := []struct { + name string + modelName string + expectedModelKey string + expectedSuccess bool + }{ + { + name: "claude-sonnet-4-5 should be stored as-is", + modelName: "claude-sonnet-4-5", + expectedModelKey: "claude-sonnet-4-5", + expectedSuccess: true, + }, + { + name: "gemini-3-pro-high should be stored as-is", + modelName: "gemini-3-pro-high", + expectedModelKey: "gemini-3-pro-high", + expectedSuccess: true, + }, + { + name: "gemini-3-flash should be stored as-is", + modelName: "gemini-3-flash", + expectedModelKey: "gemini-3-flash", + expectedSuccess: true, + }, + { + name: "empty model name should fail", + modelName: "", + expectedModelKey: "", + expectedSuccess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + resetAt := time.Now().Add(30 * time.Second) + + success := setModelRateLimitByModelName( + context.Background(), + repo, + 123, // accountID + tt.modelName, + "[test]", + 429, + resetAt, + false, // afterSmartRetry + ) + + require.Equal(t, tt.expectedSuccess, success) + + if tt.expectedSuccess { + require.Len(t, repo.modelRateLimitCalls, 1) + call := repo.modelRateLimitCalls[0] + require.Equal(t, int64(123), call.accountID) + // 关键断言:存储的 key 应该是官方模型 ID,而不是 scope + require.Equal(t, tt.expectedModelKey, call.modelKey, "should store official model ID, not scope") + require.WithinDuration(t, resetAt, call.resetAt, time.Second) + } else { + require.Empty(t, repo.modelRateLimitCalls) + } + }) + } +} + +// TestSetModelRateLimitByModelName_NotConvertToScope 验证不会将模型名转换为 scope +func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + resetAt := time.Now().Add(30 * time.Second) + + // 调用 setModelRateLimitByModelName,传入官方模型 ID + success := setModelRateLimitByModelName( + context.Background(), + repo, + 456, + "claude-sonnet-4-5", // 官方模型 ID + "[test]", + 429, + resetAt, + true, // afterSmartRetry + ) + + require.True(t, success) + require.Len(t, repo.modelRateLimitCalls, 1) + + call := repo.modelRateLimitCalls[0] + // 关键断言:存储的应该是 "claude-sonnet-4-5",而不是 "claude_sonnet" + require.Equal(t, "claude-sonnet-4-5", call.modelKey, "should NOT convert to scope like claude_sonnet") + require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope") +} + +func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) { + upstream := &recordingOKUpstream{} + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + // RFC3339 here is second-precision; keep it safely in the future. + "rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + requestedModel: "claude-sonnet-4-5", + httpUpstream: upstream, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, result) + require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check") +} + +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) { + upstream := &recordingOKUpstream{} + account := &Account{ + ID: 2, + Name: "acc-2", + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": time.Now().Add(11 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + requestedModel: "claude-sonnet-4-5", + httpUpstream: upstream, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession) + require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check") +} + +func TestIsAntigravityAccountSwitchError(t *testing.T) { + tests := []struct { + name string + err error + expectedOK bool + expectedID int64 + expectedModel string + }{ + { + name: "nil error", + err: nil, + expectedOK: false, + }, + { + name: "generic error", + err: fmt.Errorf("some error"), + expectedOK: false, + }, + { + name: "account switch error", + err: &AntigravityAccountSwitchError{ + OriginalAccountID: 123, + RateLimitedModel: "claude-sonnet-4-5", + IsStickySession: true, + }, + expectedOK: true, + expectedID: 123, + expectedModel: "claude-sonnet-4-5", + }, + { + name: "wrapped account switch error", + err: fmt.Errorf("wrapped: %w", &AntigravityAccountSwitchError{ + OriginalAccountID: 456, + RateLimitedModel: "gemini-3-flash", + IsStickySession: false, + }), + expectedOK: true, + expectedID: 456, + expectedModel: "gemini-3-flash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switchErr, ok := IsAntigravityAccountSwitchError(tt.err) + require.Equal(t, tt.expectedOK, ok) + if tt.expectedOK { + require.NotNil(t, switchErr) + require.Equal(t, tt.expectedID, switchErr.OriginalAccountID) + require.Equal(t, tt.expectedModel, switchErr.RateLimitedModel) + } else { + require.Nil(t, switchErr) + } + }) + } +} + +func TestAntigravityAccountSwitchError_Error(t *testing.T) { + err := &AntigravityAccountSwitchError{ + OriginalAccountID: 789, + RateLimitedModel: "claude-opus-4-5", + IsStickySession: true, + } + msg := err.Error() + require.Contains(t, msg, "789") + require.Contains(t, msg, "claude-opus-4-5") +} + +// stubSchedulerCache 用于测试的 SchedulerCache 实现 +type stubSchedulerCache struct { + SchedulerCache + setAccountCalls []*Account + setAccountErr error +} + +func (s *stubSchedulerCache) SetAccount(ctx context.Context, account *Account) error { + s.setAccountCalls = append(s.setAccountCalls, account) + return s.setAccountErr +} + +// TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache 测试模型限流后更新缓存 +func TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache(t *testing.T) { + cache := &stubSchedulerCache{} + snapshotService := &SchedulerSnapshotService{cache: cache} + svc := &AntigravityGatewayService{ + schedulerSnapshot: snapshotService, + } + + account := &Account{ + ID: 100, + Name: "test-account", + Platform: PlatformAntigravity, + } + modelKey := "claude-sonnet-4-5" + resetAt := time.Now().Add(30 * time.Second) + + svc.updateAccountModelRateLimitInCache(context.Background(), account, modelKey, resetAt) + + // 验证 Extra 字段被正确更新 + require.NotNil(t, account.Extra) + limits, ok := account.Extra["model_rate_limits"].(map[string]any) + require.True(t, ok) + modelLimit, ok := limits[modelKey].(map[string]any) + require.True(t, ok) + require.NotEmpty(t, modelLimit["rate_limited_at"]) + require.NotEmpty(t, modelLimit["rate_limit_reset_at"]) + + // 验证 cache.SetAccount 被调用 + require.Len(t, cache.setAccountCalls, 1) + require.Equal(t, account.ID, cache.setAccountCalls[0].ID) +} + +// TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot 测试 schedulerSnapshot 为 nil 时不 panic +func TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot(t *testing.T) { + svc := &AntigravityGatewayService{ + schedulerSnapshot: nil, + } + + account := &Account{ID: 1, Name: "test"} + + // 不应 panic + svc.updateAccountModelRateLimitInCache(context.Background(), account, "claude-sonnet-4-5", time.Now().Add(30*time.Second)) + + // Extra 不应被更新(因为函数提前返回) + require.Nil(t, account.Extra) +} + +// TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra 测试保留已有的 Extra 数据 +func TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra(t *testing.T) { + cache := &stubSchedulerCache{} + snapshotService := &SchedulerSnapshotService{cache: cache} + svc := &AntigravityGatewayService{ + schedulerSnapshot: snapshotService, + } + + account := &Account{ + ID: 200, + Name: "test-account", + Platform: PlatformAntigravity, + Extra: map[string]any{ + "existing_key": "existing_value", + "model_rate_limits": map[string]any{ + "gemini-3-flash": map[string]any{ + "rate_limited_at": "2024-01-01T00:00:00Z", + "rate_limit_reset_at": "2024-01-01T00:05:00Z", + }, + }, + }, + } + + svc.updateAccountModelRateLimitInCache(context.Background(), account, "claude-sonnet-4-5", time.Now().Add(30*time.Second)) + + // 验证已有数据被保留 + require.Equal(t, "existing_value", account.Extra["existing_key"]) + limits := account.Extra["model_rate_limits"].(map[string]any) + require.NotNil(t, limits["gemini-3-flash"]) + require.NotNil(t, limits["claude-sonnet-4-5"]) +} + +// TestSchedulerSnapshotService_UpdateAccountInCache 测试 UpdateAccountInCache 方法 +func TestSchedulerSnapshotService_UpdateAccountInCache(t *testing.T) { + t.Run("calls cache.SetAccount", func(t *testing.T) { + cache := &stubSchedulerCache{} + svc := &SchedulerSnapshotService{cache: cache} + + account := &Account{ID: 123, Name: "test"} + err := svc.UpdateAccountInCache(context.Background(), account) + + require.NoError(t, err) + require.Len(t, cache.setAccountCalls, 1) + require.Equal(t, int64(123), cache.setAccountCalls[0].ID) + }) + + t.Run("returns nil when cache is nil", func(t *testing.T) { + svc := &SchedulerSnapshotService{cache: nil} + + err := svc.UpdateAccountInCache(context.Background(), &Account{ID: 1}) + + require.NoError(t, err) + }) + + t.Run("returns nil when account is nil", func(t *testing.T) { + cache := &stubSchedulerCache{} + svc := &SchedulerSnapshotService{cache: cache} + + err := svc.UpdateAccountInCache(context.Background(), nil) + + require.NoError(t, err) + require.Empty(t, cache.setAccountCalls) + }) + + t.Run("propagates cache error", func(t *testing.T) { + expectedErr := fmt.Errorf("cache error") + cache := &stubSchedulerCache{setAccountErr: expectedErr} + svc := &SchedulerSnapshotService{cache: cache} + + err := svc.UpdateAccountInCache(context.Background(), &Account{ID: 1}) + + require.ErrorIs(t, err, expectedErr) + }) +} diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go new file mode 100644 index 00000000..623dfec5 --- /dev/null +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -0,0 +1,676 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream +type mockSmartRetryUpstream struct { + responses []*http.Response + errors []error + callIdx int + calls []string +} + +func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + idx := m.callIdx + m.calls = append(m.calls, req.URL.String()) + m.callIdx++ + if idx < len(m.responses) { + return m.responses[idx], m.errors[idx] + } + return nil, nil +} + +func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return m.Do(req, proxyURL, accountID, accountConcurrency) +} + +// TestHandleSmartRetry_URLLevelRateLimit 测试 URL 级别限流切换 +func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) { + account := &Account{ + ID: 1, + Name: "acc-1", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{"error":{"message":"Resource has been exhausted"}}`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test", "https://ag-2.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionContinueURL, result.action) + require.Nil(t, result.resp) + require.Nil(t, result.err) + require.Nil(t, result.switchError) +} + +// TestHandleSmartRetry_LongDelay_ReturnsSwitchError 测试 retryDelay >= 阈值时返回 switchError +func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 1, + Name: "acc-1", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 15s >= 7s 阈值,应该返回 switchError + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.Nil(t, result.err) + require.NotNil(t, result.switchError, "should return switchError for long delay") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel) + require.True(t, result.switchError.IsStickySession) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleSmartRetry_ShortDelay_SmartRetrySuccess 测试智能重试成功 +func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + account := &Account{ + ID: 1, + Name: "acc-1", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 0.5s < 7s 阈值,应该触发智能重试 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.err) + require.Nil(t, result.switchError, "should not return switchError on success") + require.Len(t, upstream.calls, 1, "should have made one retry call") +} + +// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError +func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) { + // 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次) + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp1 := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + failResp2 := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + failResp3 := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp1, failResp2, failResp3}, + errors: []error{nil, nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 2, + Name: "acc-2", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 3s < 7s 阈值,应该触发智能重试(最多 3 次) + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: false, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.Nil(t, result.err) + require.NotNil(t, result.switchError, "should return switchError after smart retry failed") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "gemini-3-flash", result.switchError.RateLimitedModel) + require.False(t, result.switchError.IsStickySession) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey) + require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)") +} + +// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError +func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 3, + Name: "acc-3", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值 + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ], + "message": "No capacity available for model gemini-3-pro-high on the server" + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Nil(t, result.resp) + require.Nil(t, result.err) + require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel) + require.True(t, result.switchError.IsStickySession) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑 +func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing.T) { + account := &Account{ + ID: 4, + Name: "acc-4", + Type: AccountTypeAPIKey, // 非 Antigravity 平台账号 + Platform: PlatformAnthropic, + } + + // 即使是模型限流响应,非 OAuth 账号也应该走默认逻辑 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionContinue, result.action, "non-Antigravity platform account should continue default logic") + require.Nil(t, result.resp) + require.Nil(t, result.err) + require.Nil(t, result.switchError) +} + +// TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic 测试非模型限流响应走默认逻辑 +func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) { + account := &Account{ + ID: 5, + Name: "acc-5", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 429 但没有 RATE_LIMIT_EXCEEDED 或 MODEL_CAPACITY_EXHAUSTED + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"} + ], + "message": "Quota exceeded" + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionContinue, result.action, "non-model rate limit should continue default logic") + require.Nil(t, result.resp) + require.Nil(t, result.err) + require.Nil(t, result.switchError) +} + +// TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError 测试刚好等于阈值时返回 switchError +func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 6, + Name: "acc-6", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 刚好 7s = 7s 阈值,应该返回 switchError + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Nil(t, result.resp) + require.NotNil(t, result.switchError, "exactly at threshold should return switchError") + require.Equal(t, "gemini-pro", result.switchError.RateLimitedModel) +} + +// TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates 测试 switchError 正确传播到上层 +func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing.T) { + // 模拟 429 + 长延迟的响应 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"} + ] + } + }`) + rateLimitResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{rateLimitResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 7, + Name: "acc-7", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result when switchError") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession) +} + +// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试 +func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { + // 第一次网络错误,第二次成功 + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误) + errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发 + } + + account := &Account{ + ID: 8, + Name: "acc-8", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 0.1s < 7s 阈值,应该触发智能重试 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response after network error recovery") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not return switchError on success") + require.Len(t, upstream.calls, 2, "should have made two retry calls") +} + +// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流 +func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 9, + Name: "acc-9", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 429 + RATE_LIMIT_EXCEEDED + 无 retryDelay → 使用默认 1 分钟限流 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"} + ], + "message": "You have exhausted your capacity on this model." + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.NotNil(t, result.switchError, "should return switchError for no retryDelay") + require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel) + require.True(t, result.switchError.IsStickySession) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} diff --git a/backend/internal/service/antigravity_thinking_test.go b/backend/internal/service/antigravity_thinking_test.go new file mode 100644 index 00000000..b3952ee4 --- /dev/null +++ b/backend/internal/service/antigravity_thinking_test.go @@ -0,0 +1,68 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestApplyThinkingModelSuffix(t *testing.T) { + tests := []struct { + name string + mappedModel string + thinkingEnabled bool + expected string + }{ + // Thinking 未开启:保持原样 + { + name: "thinking disabled - claude-sonnet-4-5 unchanged", + mappedModel: "claude-sonnet-4-5", + thinkingEnabled: false, + expected: "claude-sonnet-4-5", + }, + { + name: "thinking disabled - other model unchanged", + mappedModel: "claude-opus-4-6-thinking", + thinkingEnabled: false, + expected: "claude-opus-4-6-thinking", + }, + + // Thinking 开启 + claude-sonnet-4-5:自动添加后缀 + { + name: "thinking enabled - claude-sonnet-4-5 becomes thinking version", + mappedModel: "claude-sonnet-4-5", + thinkingEnabled: true, + expected: "claude-sonnet-4-5-thinking", + }, + + // Thinking 开启 + 其他模型:保持原样 + { + name: "thinking enabled - claude-sonnet-4-5-thinking unchanged", + mappedModel: "claude-sonnet-4-5-thinking", + thinkingEnabled: true, + expected: "claude-sonnet-4-5-thinking", + }, + { + name: "thinking enabled - claude-opus-4-6-thinking unchanged", + mappedModel: "claude-opus-4-6-thinking", + thinkingEnabled: true, + expected: "claude-opus-4-6-thinking", + }, + { + name: "thinking enabled - gemini model unchanged", + mappedModel: "gemini-3-flash", + thinkingEnabled: true, + expected: "gemini-3-flash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := applyThinkingModelSuffix(tt.mappedModel, tt.thinkingEnabled) + if result != tt.expected { + t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q", + tt.mappedModel, tt.thinkingEnabled, result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 94eca94d..1eb740f9 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -42,7 +42,18 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { + if account.Platform != PlatformAntigravity { + return "", errors.New("not an antigravity account") + } + // upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程 + if account.Type == AccountTypeUpstream { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return "", errors.New("upstream account missing api_key in credentials") + } + return apiKey, nil + } + if account.Type != AccountTypeOAuth { return "", errors.New("not an antigravity oauth account") } diff --git a/backend/internal/service/antigravity_token_provider_test.go b/backend/internal/service/antigravity_token_provider_test.go new file mode 100644 index 00000000..c9d38cf6 --- /dev/null +++ b/backend/internal/service/antigravity_token_provider_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAntigravityTokenProvider_GetAccessToken_Upstream(t *testing.T) { + provider := &AntigravityTokenProvider{} + + t.Run("upstream account with valid api_key", func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Credentials: map[string]any{ + "api_key": "sk-test-key-12345", + }, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sk-test-key-12345", token) + }) + + t.Run("upstream account missing api_key", func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Credentials: map[string]any{}, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream account missing api_key") + require.Empty(t, token) + }) + + t.Run("upstream account with empty api_key", func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Credentials: map[string]any{ + "api_key": "", + }, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream account missing api_key") + require.Empty(t, token) + }) + + t.Run("upstream account with nil credentials", func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream account missing api_key") + require.Empty(t, token) + }) +} + +func TestAntigravityTokenProvider_GetAccessToken_Guards(t *testing.T) { + provider := &AntigravityTokenProvider{} + + t.Run("nil account", func(t *testing.T) { + token, err := provider.GetAccessToken(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "account is nil") + require.Empty(t, token) + }) + + t.Run("non-antigravity platform", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an antigravity account") + require.Empty(t, token) + }) + + t.Run("unsupported account type", func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeAPIKey, + } + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an antigravity oauth account") + require.Empty(t, token) + }) +} diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 8c692d09..d66059dd 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -2,6 +2,14 @@ package service import "time" +// API Key status constants +const ( + StatusAPIKeyActive = "active" + StatusAPIKeyDisabled = "disabled" + StatusAPIKeyQuotaExhausted = "quota_exhausted" + StatusAPIKeyExpired = "expired" +) + type APIKey struct { ID int64 UserID int64 @@ -15,8 +23,53 @@ type APIKey struct { UpdatedAt time.Time User *User Group *Group + + // Quota fields + Quota float64 // Quota limit in USD (0 = unlimited) + QuotaUsed float64 // Used quota amount + ExpiresAt *time.Time // Expiration time (nil = never expires) } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } + +// IsExpired checks if the API key has expired +func (k *APIKey) IsExpired() bool { + if k.ExpiresAt == nil { + return false + } + return time.Now().After(*k.ExpiresAt) +} + +// IsQuotaExhausted checks if the API key quota is exhausted +func (k *APIKey) IsQuotaExhausted() bool { + if k.Quota <= 0 { + return false // unlimited + } + return k.QuotaUsed >= k.Quota +} + +// GetQuotaRemaining returns remaining quota (-1 for unlimited) +func (k *APIKey) GetQuotaRemaining() float64 { + if k.Quota <= 0 { + return -1 // unlimited + } + remaining := k.Quota - k.QuotaUsed + if remaining < 0 { + return 0 + } + return remaining +} + +// GetDaysUntilExpiry returns days until expiry (-1 for never expires) +func (k *APIKey) GetDaysUntilExpiry() int { + if k.ExpiresAt == nil { + return -1 // never expires + } + duration := time.Until(*k.ExpiresAt) + if duration < 0 { + return 0 + } + return int(duration.Hours() / 24) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 5b476dbc..d15b5817 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -1,5 +1,7 @@ package service +import "time" + // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` @@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct { IPBlacklist []string `json:"ip_blacklist,omitempty"` User APIKeyAuthUserSnapshot `json:"user"` Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` + + // Quota fields for API Key independent quota feature + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount + + // Expiration field for API Key expiration feature + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) } // APIKeyAuthUserSnapshot 用户快照 @@ -23,25 +32,30 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. ModelRouting map[string][]int64 `json:"model_routing,omitempty"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index eb5c7534..f5bba7d0 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, + Quota: apiKey.Quota, + QuotaUsed: apiKey.QuotaUsed, + ExpiresAt: apiKey.ExpiresAt, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -223,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + MCPXMLInject: apiKey.Group.MCPXMLInject, + SupportedModelScopes: apiKey.Group.SupportedModelScopes, } } return snapshot @@ -256,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, + Quota: snapshot.Quota, + QuotaUsed: snapshot.QuotaUsed, + ExpiresAt: snapshot.ExpiresAt, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, @@ -266,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + MCPXMLInject: snapshot.Group.MCPXMLInject, + SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } return apiKey diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ef1ff990..cb1dd60a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -24,6 +24,10 @@ var ( ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") + // ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired") + ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") + // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") + ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") ) const ( @@ -51,6 +55,9 @@ type APIKeyRepository interface { CountByGroupID(ctx context.Context, groupID int64) (int64, error) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) + + // Quota methods + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) } // APIKeyCache defines cache operations for API key service @@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct { CustomKey *string `json:"custom_key"` // 可选的自定义key IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + + // Quota fields + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) } // UpdateAPIKeyRequest 更新API Key请求 @@ -94,19 +105,26 @@ type UpdateAPIKeyRequest struct { Status *string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) + + // Quota fields + Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited) + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) + ClearExpiration bool `json:"-"` // Clear expiration (internal use) + ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 } // APIKeyService API Key服务 type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -115,16 +133,18 @@ func NewAPIKeyService( userRepo UserRepository, groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache APIKeyCache, cfg *config.Config, ) *APIKeyService { svc := &APIKeyService{ - apiKeyRepo: apiKeyRepo, - userRepo: userRepo, - groupRepo: groupRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, + apiKeyRepo: apiKeyRepo, + userRepo: userRepo, + groupRepo: groupRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + cfg: cfg, } svc.initAuthCache(cfg) return svc @@ -289,6 +309,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK Status: StatusActive, IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + QuotaUsed: 0, + } + + // Set expiration time if specified + if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 { + expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays) + apiKey.ExpiresAt = &expiresAt } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -436,6 +464,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } } + // Update quota fields + if req.Quota != nil { + apiKey.Quota = *req.Quota + // If quota is increased and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed { + apiKey.Status = StatusActive + } + } + if req.ResetQuota != nil && *req.ResetQuota { + apiKey.QuotaUsed = 0 + // If resetting quota and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted { + apiKey.Status = StatusActive + } + } + if req.ClearExpiration { + apiKey.ExpiresAt = nil + // If clearing expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired { + apiKey.Status = StatusActive + } + } else if req.ExpiresAt != nil { + apiKey.ExpiresAt = req.ExpiresAt + // If extending expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) { + apiKey.Status = StatusActive + } + } + // 更新 IP 限制(空数组会清空设置) apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist @@ -572,3 +629,64 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword } return keys, nil } + +// GetUserGroupRates 获取用户的专属分组倍率配置 +// 返回 map[groupID]rateMultiplier +func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user group rates: %w", err) + } + return rates, nil +} + +// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) +// Returns nil if valid, error if invalid +func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { + // Check expiration + if apiKey.IsExpired() { + return ErrAPIKeyExpired + } + + // Check quota + if apiKey.IsQuotaExhausted() { + return ErrAPIKeyQuotaExhausted + } + + return nil +} + +// UpdateQuotaUsed updates the quota_used field after a request +// Also checks if quota is exhausted and updates status accordingly +func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + + // Use repository to atomically increment quota_used + newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + + // Check if quota is now exhausted and update status if needed + apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID) + if err != nil { + return nil // Don't fail the request, just log + } + + // If quota is set and now exhausted, update status + if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota { + apiKey.Status = StatusAPIKeyQuotaExhausted + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil // Don't fail the request + } + // Invalidate cache so next request sees the new status + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + return nil +} diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index c5e9cd47..14ecbf39 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([] return s.listKeysByGroupID(ctx, groupID) } +func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string @@ -163,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) groupID := int64(9) cacheEntry := &APIKeyAuthCacheEntry{ @@ -219,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return &APIKeyAuthCacheEntry{NotFound: true}, nil } @@ -252,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -289,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { L1TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) require.NotNil(t, svc.authCacheL1) _, err := svc.GetByKey(context.Background(), "k-l1") @@ -316,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByUserID(context.Background(), 7) require.Len(t, cache.deleteAuthKeys, 2) @@ -334,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByGroupID(context.Background(), 9) require.Len(t, cache.deleteAuthKeys, 2) @@ -352,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByKey(context.Background(), "k1") require.Len(t, cache.deleteAuthKeys, 1) @@ -371,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -407,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { Singleflight: true, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) start := make(chan struct{}) wg := sync.WaitGroup{} diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 092b7fce..d4d12144 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ( panic("unexpected ListKeysByGroupID call") } +func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index c824ec1e..fb8aaf9c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -3,6 +3,7 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/hex" "errors" "fmt" @@ -25,8 +26,12 @@ var ( ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") @@ -37,6 +42,9 @@ var ( // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 +// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens. +const refreshTokenPrefix = "rt_" + // JWTClaims JWT载荷数据 type JWTClaims struct { UserID int64 `json:"user_id"` @@ -50,6 +58,7 @@ type JWTClaims struct { type AuthService struct { userRepo UserRepository redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache cfg *config.Config settingService *SettingService emailService *EmailService @@ -62,6 +71,7 @@ type AuthService struct { func NewAuthService( userRepo UserRepository, redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, cfg *config.Config, settingService *SettingService, emailService *EmailService, @@ -72,6 +82,7 @@ func NewAuthService( return &AuthService{ userRepo: userRepo, redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, cfg: cfg, settingService: settingService, emailService: emailService, @@ -185,7 +196,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } - // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { @@ -482,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, nil, errors.New("refresh token cache not configured") + } + + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册 + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return nil, nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return nil, nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return nil, nil, ErrUserNotActive + } + + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -540,10 +644,17 @@ func isReservedEmail(email string) bool { return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) } -// GenerateToken 生成JWT token +// GenerateToken 生成JWT access token +// 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() - expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + var expiresAt time.Time + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute) + } else { + // 向后兼容:使用旧的expire_hour配置 + expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + } claims := &JWTClaims{ UserID: user.ID, @@ -566,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { return tokenString, nil } +// GetAccessTokenExpiresIn 返回Access Token的有效期(秒) +// 用于前端设置刷新定时器 +func (s *AuthService) GetAccessTokenExpiresIn() int { + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + return s.cfg.JWT.AccessTokenExpireMinutes * 60 + } + return s.cfg.JWT.ExpireHour * 3600 +} + // HashPassword 使用bcrypt加密密码 func (s *AuthService) HashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) @@ -756,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo return ErrServiceUnavailable } + // Also revoke all refresh tokens for this user + if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { + log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + // Don't return error - password was already changed successfully + } + log.Printf("[Auth] Password reset successful for user: %s", email) return nil } + +// ==================== Refresh Token Methods ==================== + +// TokenPair 包含Access Token和Refresh Token +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) +} + +// GenerateTokenPair 生成Access Token和Refresh Token对 +// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 +func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, errors.New("refresh token cache not configured") + } + + // 生成Access Token + accessToken, err := s.GenerateToken(user) + if err != nil { + return nil, fmt.Errorf("generate access token: %w", err) + } + + // 生成Refresh Token + refreshToken, err := s.generateRefreshToken(ctx, user, familyID) + if err != nil { + return nil, fmt.Errorf("generate refresh token: %w", err) + } + + return &TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: s.GetAccessTokenExpiresIn(), + }, nil +} + +// generateRefreshToken 生成并存储Refresh Token +func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) { + // 生成随机Token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes) + + // 计算Token哈希(存储哈希而非原始Token) + tokenHash := hashToken(rawToken) + + // 如果没有提供familyID,生成新的 + if familyID == "" { + familyBytes := make([]byte, 16) + if _, err := rand.Read(familyBytes); err != nil { + return "", fmt.Errorf("generate family id: %w", err) + } + familyID = hex.EncodeToString(familyBytes) + } + + now := time.Now() + ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour + + data := &RefreshTokenData{ + UserID: user.ID, + TokenVersion: user.TokenVersion, + FamilyID: familyID, + CreatedAt: now, + ExpiresAt: now.Add(ttl), + } + + // 存储Token数据 + if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil { + return "", fmt.Errorf("store refresh token: %w", err) + } + + // 添加到用户Token集合 + if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to user set: %v", err) + // 不影响主流程 + } + + // 添加到家族Token集合 + if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to family set: %v", err) + // 不影响主流程 + } + + return rawToken, nil +} + +// RefreshTokenPair 使用Refresh Token刷新Token对 +// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 +func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, ErrRefreshTokenInvalid + } + + // 验证Token格式 + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return nil, ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + + // 获取Token数据 + data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash) + if err != nil { + if errors.Is(err, ErrRefreshTokenNotFound) { + // Token不存在,可能是已被使用(Token轮转)或已过期 + log.Printf("[Auth] Refresh token not found, possible reuse attack") + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Error getting refresh token: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查Token是否过期 + if time.Now().After(data.ExpiresAt) { + // 删除过期Token + _ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) + return nil, ErrRefreshTokenExpired + } + + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, data.UserID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // 用户已删除,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Database error getting user for token refresh: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查用户状态 + if !user.IsActive() { + // 用户被禁用,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrUserNotActive + } + + // 检查TokenVersion(密码更改后所有Token失效) + if data.TokenVersion != user.TokenVersion { + // TokenVersion不匹配,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrTokenRevoked + } + + // Token轮转:立即使旧Token失效 + if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { + log.Printf("[Auth] Failed to delete old refresh token: %v", err) + // 继续处理,不影响主流程 + } + + // 生成新的Token对,保持同一个家族ID + return s.GenerateTokenPair(ctx, user, data.FamilyID) +} + +// RevokeRefreshToken 撤销单个Refresh Token +func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) +} + +// RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token) +// 用于密码更改或用户主动登出所有设备 +func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) +} + +// hashToken 计算Token的SHA256哈希 +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index aa3c769e..f1685be5 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E return NewAuthService( repo, nil, // redeemRepo + nil, // refreshTokenCache cfg, settingService, emailService, diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index ab86f1e8..6d06c83e 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -56,7 +56,8 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator { // // Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x // Step 2: 对于非 messages 路径,只要 UA 匹配就通过 -// Step 3: 对于 messages 路径,进行严格验证: +// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证) +// Step 4: 对于 messages 路径,进行严格验证: // - System prompt 相似度检查 // - X-App header 检查 // - anthropic-beta header 检查 @@ -75,14 +76,20 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return true } - // Step 3: messages 路径,进行严格验证 + // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 + // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt + if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 + } - // 3.1 检查 system prompt 相似度 + // Step 4: messages 路径,进行严格验证 + + // 4.1 检查 system prompt 相似度 if !v.hasClaudeCodeSystemPrompt(body) { return false } - // 3.2 检查必需的 headers(值不为空即可) + // 4.2 检查必需的 headers(值不为空即可) xApp := r.Header.Get("X-App") if xApp == "" { return false @@ -98,7 +105,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return false } - // 3.3 验证 metadata.user_id + // 4.3 验证 metadata.user_id if body == nil { return false } diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go new file mode 100644 index 00000000..a4cd1886 --- /dev/null +++ b/backend/internal/service/claude_code_validator_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestClaudeCodeValidator_ProbeBypass(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)) + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.True(t, ok) +} + +func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "curl/8.0.0") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)) + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.False(t, ok) +} + +func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.False(t, ok) +} + +func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + + ok := validator.Validate(req, nil) + require.True(t, ok) +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 65ef16db..d5cb2025 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -35,6 +35,7 @@ type ConcurrencyCache interface { // 批量负载查询(只读) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) + GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) // 清理过期槽位(后台任务) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error @@ -77,6 +78,11 @@ type AccountWithConcurrency struct { MaxConcurrency int } +type UserWithConcurrency struct { + ID int64 + MaxConcurrency int +} + type AccountLoadInfo struct { AccountID int64 CurrentConcurrency int @@ -84,6 +90,13 @@ type AccountLoadInfo struct { LoadRate int // 0-100+ (percent) } +type UserLoadInfo struct { + UserID int64 + CurrentConcurrency int + WaitingCount int + LoadRate int // 0-100+ (percent) +} + // AcquireAccountSlot attempts to acquire a concurrency slot for an account. // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. @@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts return s.cache.GetAccountsLoadBatch(ctx, accounts) } +// GetUsersLoadBatch returns load info for multiple users. +func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + if s.cache == nil { + return map[int64]*UserLoadInfo{}, nil + } + return s.cache.GetUsersLoadBatch(ctx, users) +} + // CleanupExpiredAccountSlots removes expired slots for one account (background task). func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { if s.cache == nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2db72825..0295c23b 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -31,6 +31,7 @@ const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go new file mode 100644 index 00000000..65085d6f --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime.go @@ -0,0 +1,67 @@ +package service + +import "github.com/gin-gonic/gin" + +const errorPassthroughServiceContextKey = "error_passthrough_service" + +// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。 +func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) { + if c == nil || svc == nil { + return + } + c.Set(errorPassthroughServiceContextKey, svc) +} + +func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService { + if c == nil { + return nil + } + v, ok := c.Get(errorPassthroughServiceContextKey) + if !ok { + return nil + } + svc, ok := v.(*ErrorPassthroughService) + if !ok { + return nil + } + return svc +} + +// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。 +func applyErrorPassthroughRule( + c *gin.Context, + platform string, + upstreamStatus int, + responseBody []byte, + defaultStatus int, + defaultErrType string, + defaultErrMsg string, +) (status int, errType string, errMsg string, matched bool) { + status = defaultStatus + errType = defaultErrType + errMsg = defaultErrMsg + + svc := getBoundErrorPassthroughService(c) + if svc == nil { + return status, errType, errMsg, false + } + + rule := svc.MatchRule(platform, upstreamStatus, responseBody) + if rule == nil { + return status, errType, errMsg, false + } + + status = upstreamStatus + if !rule.PassthroughCode && rule.ResponseCode != nil { + status = *rule.ResponseCode + } + + errMsg = ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + errMsg = *rule.CustomMessage + } + + // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 + errType = "upstream_error" + return status, errType, errMsg, true +} diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go new file mode 100644 index 00000000..393e6e59 --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -0,0 +1,211 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusUnprocessableEntity, + []byte(`{"error":{"message":"invalid schema"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.False(t, matched) + assert.Equal(t, http.StatusBadGateway, status) + assert.Equal(t, "upstream_error", errType) + assert.Equal(t, "Upstream request failed", errMsg) +} + +func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-2", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "invalid_request_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "上游请求失败", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "OpenAI上游失败", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-1", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Gemini上游失败", errField["message"]) +} + +func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { + return &model.ErrorPassthroughRule{ + ID: 1, + Name: "non-failover-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{statusCode}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &respCode, + PassthroughBody: false, + CustomMessage: &customMessage, + } +} diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go new file mode 100644 index 00000000..c3e0f630 --- /dev/null +++ b/backend/internal/service/error_passthrough_service.go @@ -0,0 +1,336 @@ +package service + +import ( + "context" + "log" + "sort" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" +) + +// ErrorPassthroughRepository 定义错误透传规则的数据访问接口 +type ErrorPassthroughRepository interface { + // List 获取所有规则 + List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) + // GetByID 根据 ID 获取规则 + GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) + // Create 创建规则 + Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Update 更新规则 + Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Delete 删除规则 + Delete(ctx context.Context, id int64) error +} + +// ErrorPassthroughCache 定义错误透传规则的缓存接口 +type ErrorPassthroughCache interface { + // Get 从缓存获取规则列表 + Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) + // Set 设置缓存 + Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error + // Invalidate 使缓存失效 + Invalidate(ctx context.Context) error + // NotifyUpdate 通知其他实例刷新缓存 + NotifyUpdate(ctx context.Context) error + // SubscribeUpdates 订阅缓存更新通知 + SubscribeUpdates(ctx context.Context, handler func()) +} + +// ErrorPassthroughService 错误透传规则服务 +type ErrorPassthroughService struct { + repo ErrorPassthroughRepository + cache ErrorPassthroughCache + + // 本地内存缓存,用于快速匹配 + localCache []*model.ErrorPassthroughRule + localCacheMu sync.RWMutex +} + +// NewErrorPassthroughService 创建错误透传规则服务 +func NewErrorPassthroughService( + repo ErrorPassthroughRepository, + cache ErrorPassthroughCache, +) *ErrorPassthroughService { + svc := &ErrorPassthroughService{ + repo: repo, + cache: cache, + } + + // 启动时加载规则到本地缓存 + ctx := context.Background() + if err := svc.reloadRulesFromDB(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + } + } + + // 订阅缓存更新通知 + if cache != nil { + cache.SubscribeUpdates(ctx, func() { + if err := svc.refreshLocalCache(context.Background()); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + } + }) + } + + return svc +} + +// List 获取所有规则 +func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + return s.repo.List(ctx) +} + +// GetByID 根据 ID 获取规则 +func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + return s.repo.GetByID(ctx, id) +} + +// Create 创建规则 +func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + created, err := s.repo.Create(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return created, nil +} + +// Update 更新规则 +func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + updated, err := s.repo.Update(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return updated, nil +} + +// Delete 删除规则 +func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { + if err := s.repo.Delete(ctx, id); err != nil { + return err + } + + // 刷新缓存 + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return nil +} + +// MatchRule 匹配透传规则 +// 返回第一个匹配的规则,如果没有匹配则返回 nil +func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule { + rules := s.getCachedRules() + if len(rules) == 0 { + return nil + } + + bodyStr := strings.ToLower(string(body)) + + for _, rule := range rules { + if !rule.Enabled { + continue + } + if !s.platformMatches(rule, platform) { + continue + } + if s.ruleMatches(rule, statusCode, bodyStr) { + return rule + } + } + + return nil +} + +// getCachedRules 获取缓存的规则列表(按优先级排序) +func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { + s.localCacheMu.RLock() + rules := s.localCache + s.localCacheMu.RUnlock() + + if rules != nil { + return rules + } + + // 如果本地缓存为空,尝试刷新 + ctx := context.Background() + if err := s.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + return nil + } + + s.localCacheMu.RLock() + defer s.localCacheMu.RUnlock() + return s.localCache +} + +// refreshLocalCache 刷新本地缓存 +func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { + // 先尝试从 Redis 缓存获取 + if s.cache != nil { + if rules, ok := s.cache.Get(ctx); ok { + s.setLocalCache(rules) + return nil + } + } + + return s.reloadRulesFromDB(ctx) +} + +// 从数据库加载(repo.List 已按 priority 排序) +// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。 +func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { + rules, err := s.repo.List(ctx) + if err != nil { + return err + } + + // 更新 Redis 缓存 + if s.cache != nil { + if err := s.cache.Set(ctx, rules); err != nil { + log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + } + } + + // 更新本地缓存(setLocalCache 内部会确保排序) + s.setLocalCache(rules) + + return nil +} + +// setLocalCache 设置本地缓存 +func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + // 按优先级排序 + sorted := make([]*model.ErrorPassthroughRule, len(rules)) + copy(sorted, rules) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Priority < sorted[j].Priority + }) + + s.localCacheMu.Lock() + s.localCache = sorted + s.localCacheMu.Unlock() +} + +// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。 +func (s *ErrorPassthroughService) clearLocalCache() { + s.localCacheMu.Lock() + s.localCache = nil + s.localCacheMu.Unlock() +} + +// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。 +func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 3*time.Second) +} + +// invalidateAndNotify 使缓存失效并通知其他实例 +func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { + // 先失效缓存,避免后续刷新读到陈旧规则。 + if s.cache != nil { + if err := s.cache.Invalidate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + } + } + + // 刷新本地缓存 + if err := s.reloadRulesFromDB(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 + s.clearLocalCache() + } + + // 通知其他实例 + if s.cache != nil { + if err := s.cache.NotifyUpdate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + } + } +} + +// platformMatches 检查平台是否匹配 +func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { + // 如果没有配置平台限制,则匹配所有平台 + if len(rule.Platforms) == 0 { + return true + } + + platform = strings.ToLower(platform) + for _, p := range rule.Platforms { + if strings.ToLower(p) == platform { + return true + } + } + + return false +} + +// ruleMatches 检查规则是否匹配 +func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { + hasErrorCodes := len(rule.ErrorCodes) > 0 + hasKeywords := len(rule.Keywords) > 0 + + // 如果没有配置任何条件,不匹配 + if !hasErrorCodes && !hasKeywords { + return false + } + + codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) + keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + + if rule.MatchMode == model.MatchModeAll { + // "all" 模式:所有配置的条件都必须满足 + return codeMatch && keywordMatch + } + + // "any" 模式:任一条件满足即可 + if hasErrorCodes && hasKeywords { + return codeMatch || keywordMatch + } + return codeMatch && keywordMatch +} + +// containsInt 检查切片是否包含指定整数 +func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { + for _, v := range slice { + if v == val { + return true + } + } + return false +} + +// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) +func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool { + for _, kw := range keywords { + if strings.Contains(bodyLower, strings.ToLower(kw)) { + return true + } + } + return false +} diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go new file mode 100644 index 00000000..74c98d86 --- /dev/null +++ b/backend/internal/service/error_passthrough_service_test.go @@ -0,0 +1,984 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockErrorPassthroughRepo 用于测试的 mock repository +type mockErrorPassthroughRepo struct { + rules []*model.ErrorPassthroughRule + listErr error + getErr error + createErr error + updateErr error + deleteErr error +} + +type mockErrorPassthroughCache struct { + rules []*model.ErrorPassthroughRule + hasData bool + getCalled int + setCalled int + invalidateCalled int + notifyCalled int +} + +func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache { + return &mockErrorPassthroughCache{ + rules: cloneRules(rules), + hasData: hasData, + } +} + +func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + m.getCalled++ + if !m.hasData { + return nil, false + } + return cloneRules(m.rules), true +} + +func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + m.setCalled++ + m.rules = cloneRules(rules) + m.hasData = true + return nil +} + +func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error { + m.invalidateCalled++ + m.rules = nil + m.hasData = false + return nil +} + +func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error { + m.notifyCalled++ + return nil +} + +func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + // 单测中无需订阅行为 +} + +func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule { + if rules == nil { + return nil + } + out := make([]*model.ErrorPassthroughRule, len(rules)) + copy(out, rules) + return out +} + +func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + if m.listErr != nil { + return nil, m.listErr + } + return m.rules, nil +} + +func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + if m.getErr != nil { + return nil, m.getErr + } + for _, r := range m.rules { + if r.ID == id { + return r, nil + } + } + return nil, nil +} + +func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.createErr != nil { + return nil, m.createErr + } + rule.ID = int64(len(m.rules) + 1) + m.rules = append(m.rules, rule) + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.updateErr != nil { + return nil, m.updateErr + } + for i, r := range m.rules { + if r.ID == rule.ID { + m.rules[i] = rule + return rule, nil + } + } + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { + if m.deleteErr != nil { + return m.deleteErr + } + for i, r := range m.rules { + if r.ID == id { + m.rules = append(m.rules[:i], m.rules[i+1:]...) + return nil + } + } + return nil +} + +// newTestService 创建测试用的服务实例 +func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService { + repo := &mockErrorPassthroughRepo{rules: rules} + svc := &ErrorPassthroughService{ + repo: repo, + cache: nil, // 不使用缓存 + } + // 直接设置本地缓存,避免调用 refreshLocalCache + svc.setLocalCache(rules) + return svc +} + +// ============================================================================= +// 测试 ruleMatches 核心匹配逻辑 +// ============================================================================= + +func TestRuleMatches_NoConditions(t *testing.T) { + // 没有配置任何条件时,不应该匹配 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + "没有配置条件时不应该匹配") +} + +func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"状态码匹配 422", 422, "any message", true}, + {"状态码匹配 400", 400, "any message", true}, + {"状态码不匹配 500", 500, "any message", false}, + {"状态码不匹配 429", 429, "any message", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{"context limit", "model not supported"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"关键词匹配 context limit", 500, "error: context limit reached", true}, + {"关键词匹配 model not supported", 400, "the model not supported here", true}, + {"关键词不匹配", 422, "some other error", false}, + // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 + // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches + {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 MatchRule 的行为:先转换为小写 + bodyLower := strings.ToLower(tt.body) + result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { + // any 模式:错误码 OR 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: true, + reason: "code matches, keyword doesn't - OR mode should match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: true, + reason: "keyword matches, code doesn't - OR mode should match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +func TestRuleMatches_BothConditions_AllMode(t *testing.T) { + // all 模式:错误码 AND 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match - AND mode should match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: false, + reason: "code matches but keyword doesn't - AND mode should NOT match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: false, + reason: "keyword matches but code doesn't - AND mode should NOT match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +// ============================================================================= +// 测试 platformMatches 平台匹配逻辑 +// ============================================================================= + +func TestPlatformMatches(t *testing.T) { + svc := newTestService(nil) + + tests := []struct { + name string + rulePlatforms []string + requestPlatform string + expected bool + }{ + { + name: "空平台列表匹配所有", + rulePlatforms: []string{}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "nil平台列表匹配所有", + rulePlatforms: nil, + requestPlatform: "openai", + expected: true, + }, + { + name: "精确匹配 anthropic", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "精确匹配 openai", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "openai", + expected: true, + }, + { + name: "不匹配 gemini", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "gemini", + expected: false, + }, + { + name: "大小写不敏感", + rulePlatforms: []string{"Anthropic", "OpenAI"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "匹配 antigravity", + rulePlatforms: []string{"antigravity"}, + requestPlatform: "antigravity", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := &model.ErrorPassthroughRule{ + Platforms: tt.rulePlatforms, + } + result := svc.platformMatches(rule, tt.requestPlatform) + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============================================================================= +// 测试 MatchRule 完整匹配流程 +// ============================================================================= + +func TestMatchRule_Priority(t *testing.T) { + // 测试规则按优先级排序,优先级小的先匹配 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Low Priority", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "High Priority", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则") + assert.Equal(t, "High Priority", matched.Name) +} + +func TestMatchRule_DisabledRule(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Disabled Rule", + Enabled: false, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "Enabled Rule", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则") +} + +func TestMatchRule_PlatformFilter(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Anthropic Only", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Platforms: []string{"anthropic"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "OpenAI Only", + Enabled: true, + Priority: 2, + ErrorCodes: []int{422}, + Platforms: []string{"openai"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 3, + Name: "All Platforms", + Enabled: true, + Priority: 3, + ErrorCodes: []int{422}, + Platforms: []string{}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) { + matched := svc.MatchRule("anthropic", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(1), matched.ID) + }) + + t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) { + matched := svc.MatchRule("openai", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID) + }) + + t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("gemini", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) + + t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("antigravity", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) +} + +func TestMatchRule_NoMatch(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Rule for 422", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 500, []byte("error")) + + assert.Nil(t, matched, "不匹配任何规则时应返回 nil") +} + +func TestMatchRule_EmptyRules(t *testing.T) { + svc := newTestService([]*model.ErrorPassthroughRule{}) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + assert.Nil(t, matched, "没有规则时应返回 nil") +} + +func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit", + Enabled: true, + Priority: 1, + Keywords: []string{"Context Limit"}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + tests := []struct { + name string + body string + expected bool + }{ + {"完全匹配", "Context Limit reached", true}, + {"小写匹配", "context limit reached", true}, + {"大写匹配", "CONTEXT LIMIT REACHED", true}, + {"混合大小写", "ConTeXt LiMiT error", true}, + {"不匹配", "some other error", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matched := svc.MatchRule("anthropic", 500, []byte(tt.body)) + if tt.expected { + assert.NotNil(t, matched) + } else { + assert.Nil(t, matched) + } + }) + } +} + +// ============================================================================= +// 测试真实场景 +// ============================================================================= + +func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { + // 场景:上游返回 422 + "context limit has been reached",需要透传给客户端 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit Passthrough", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, // 必须同时满足 + Platforms: []string{"anthropic", "antigravity"}, + PassthroughCode: true, + PassthroughBody: true, + }, + } + + svc := newTestService(rules) + + // 测试 Anthropic 平台 + t.Run("Anthropic 422 with context limit", func(t *testing.T) { + body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`) + matched := svc.MatchRule("anthropic", 422, body) + require.NotNil(t, matched) + assert.True(t, matched.PassthroughCode) + assert.True(t, matched.PassthroughBody) + }) + + // 测试 Antigravity 平台 + t.Run("Antigravity 422 with context limit", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("antigravity", 422, body) + require.NotNil(t, matched) + }) + + // 测试 OpenAI 平台(不在规则的平台列表中) + t.Run("OpenAI should not match", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("openai", 422, body) + assert.Nil(t, matched, "OpenAI 不在规则的平台列表中") + }) + + // 测试状态码不匹配 + t.Run("Wrong status code", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("anthropic", 400, body) + assert.Nil(t, matched, "状态码不匹配") + }) + + // 测试关键词不匹配 + t.Run("Wrong keyword", func(t *testing.T) { + body := []byte(`{"error":"rate limit exceeded"}`) + matched := svc.MatchRule("anthropic", 422, body) + assert.Nil(t, matched, "关键词不匹配") + }) +} + +func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) { + // 场景:某些错误需要返回自定义消息,隐藏上游详细信息 + customMsg := "Service temporarily unavailable, please try again later" + responseCode := 503 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Hide Internal Errors", + Enabled: true, + Priority: 1, + ErrorCodes: []int{500, 502, 503}, + MatchMode: model.MatchModeAny, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + }, + } + + svc := newTestService(rules) + + matched := svc.MatchRule("anthropic", 500, []byte("internal server error")) + require.NotNil(t, matched) + assert.False(t, matched.PassthroughCode) + assert.Equal(t, 503, *matched.ResponseCode) + assert.False(t, matched.PassthroughBody) + assert.Equal(t, customMsg, *matched.CustomMessage) +} + +// ============================================================================= +// 测试 model.Validate +// ============================================================================= + +func TestErrorPassthroughRule_Validate(t *testing.T) { + tests := []struct { + name string + rule *model.ErrorPassthroughRule + expectError bool + errorField string + }{ + { + name: "有效规则 - 透传模式(含错误码)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 透传模式(含关键词)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + Keywords: []string{"context limit"}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 自定义响应", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAll, + ErrorCodes: []int{500}, + Keywords: []string{"internal error"}, + PassthroughCode: false, + ResponseCode: testIntPtr(503), + PassthroughBody: false, + CustomMessage: testStrPtr("Custom error"), + }, + expectError: false, + }, + { + name: "缺少名称", + rule: &model.ErrorPassthroughRule{ + Name: "", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "name", + }, + { + name: "无效的匹配模式", + rule: &model.ErrorPassthroughRule{ + Name: "Invalid Mode", + MatchMode: "invalid", + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "match_mode", + }, + { + name: "缺少匹配条件(错误码和关键词都为空)", + rule: &model.ErrorPassthroughRule{ + Name: "No Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{}, + Keywords: []string{}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "缺少匹配条件(nil切片)", + rule: &model.ErrorPassthroughRule{ + Name: "Nil Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: nil, + Keywords: nil, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "自定义状态码但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Code", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: false, + ResponseCode: nil, + PassthroughBody: true, + }, + expectError: true, + errorField: "response_code", + }, + { + name: "自定义消息但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: nil, + }, + expectError: true, + errorField: "custom_message", + }, + { + name: "自定义消息为空字符串", + rule: &model.ErrorPassthroughRule{ + Name: "Empty Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: testStrPtr(""), + }, + expectError: true, + errorField: "custom_message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.rule.Validate() + if tt.expectError { + require.Error(t, err) + validationErr, ok := err.(*model.ValidationError) + require.True(t, ok, "应该返回 ValidationError") + assert.Equal(t, tt.errorField, validationErr.Field) + } else { + assert.NoError(t, err) + } + }) + } +} + +// ============================================================================= +// 测试写路径缓存刷新(Create/Update/Delete) +// ============================================================================= + +func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败") + created, err := svc.Create(ctx, newRule) + require.NoError(t, err) + require.NotNil(t, created) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + require.NotNil(t, matched) + assert.Equal(t, created.ID, matched.ID) + if assert.NotNil(t, matched.CustomMessage) { + assert.Equal(t, "上游请求失败", *matched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule}) + + updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息") + _, err := svc.Update(ctx, updatedRule) + require.NoError(t, err) + + oldBody := []byte(`{"message":"old keyword"}`) + oldMatched := svc.MatchRule("anthropic", 503, oldBody) + assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中") + + newBody := []byte(`{"message":"new keyword"}`) + newMatched := svc.MatchRule("anthropic", 503, newBody) + require.NotNil(t, newMatched) + if assert.NotNil(t, newMatched.CustomMessage) { + assert.Equal(t, "新消息", *newMatched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + + err := svc.Delete(ctx, 1) + require.NoError(t, err) + + body := []byte(`{"message":"to be deleted"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "删除后规则不应再命中") + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) { + staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息") + latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息") + + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := NewErrorPassthroughService(repo, cache) + + matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`)) + require.NotNil(t, matchedFresh) + assert.Equal(t, int64(1), matchedFresh.ID) + + matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`)) + assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存") + + assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存") +} + +func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{ + rules: []*model.ErrorPassthroughRule{staleRule}, + listErr: errors.New("db list failed"), + } + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + disabledRule := *staleRule + disabledRule.Enabled = false + _, err := svc.Update(ctx, &disabledRule) + require.NoError(t, err) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则") + + svc.localCacheMu.RLock() + assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中") + svc.localCacheMu.RUnlock() +} + +func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule { + responseCode := 503 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: "write-path-cache-refresh", + Enabled: true, + Priority: 1, + ErrorCodes: []int{503}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + } + return rule +} + +// Helper functions +func testIntPtr(i int) *int { return &i } +func testStrPtr(s string) *string { return &s } diff --git a/backend/internal/service/force_cache_billing_test.go b/backend/internal/service/force_cache_billing_test.go new file mode 100644 index 00000000..073b1345 --- /dev/null +++ b/backend/internal/service/force_cache_billing_test.go @@ -0,0 +1,133 @@ +//go:build unit + +package service + +import ( + "context" + "testing" +) + +func TestIsForceCacheBilling(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expected bool + }{ + { + name: "context without force cache billing", + ctx: context.Background(), + expected: false, + }, + { + name: "context with force cache billing set to true", + ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, true), + expected: true, + }, + { + name: "context with force cache billing set to false", + ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, false), + expected: false, + }, + { + name: "context with wrong type value", + ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, "true"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsForceCacheBilling(tt.ctx) + if result != tt.expected { + t.Errorf("IsForceCacheBilling() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWithForceCacheBilling(t *testing.T) { + ctx := context.Background() + + // 原始上下文没有标记 + if IsForceCacheBilling(ctx) { + t.Error("original context should not have force cache billing") + } + + // 使用 WithForceCacheBilling 后应该有标记 + newCtx := WithForceCacheBilling(ctx) + if !IsForceCacheBilling(newCtx) { + t.Error("new context should have force cache billing") + } + + // 原始上下文应该不受影响 + if IsForceCacheBilling(ctx) { + t.Error("original context should still not have force cache billing") + } +} + +func TestForceCacheBilling_TokenConversion(t *testing.T) { + tests := []struct { + name string + forceCacheBilling bool + inputTokens int + cacheReadInputTokens int + expectedInputTokens int + expectedCacheReadTokens int + }{ + { + name: "force cache billing converts input to cache_read", + forceCacheBilling: true, + inputTokens: 1000, + cacheReadInputTokens: 500, + expectedInputTokens: 0, + expectedCacheReadTokens: 1500, // 500 + 1000 + }, + { + name: "no force cache billing keeps tokens unchanged", + forceCacheBilling: false, + inputTokens: 1000, + cacheReadInputTokens: 500, + expectedInputTokens: 1000, + expectedCacheReadTokens: 500, + }, + { + name: "force cache billing with zero input tokens does nothing", + forceCacheBilling: true, + inputTokens: 0, + cacheReadInputTokens: 500, + expectedInputTokens: 0, + expectedCacheReadTokens: 500, + }, + { + name: "force cache billing with zero cache_read tokens", + forceCacheBilling: true, + inputTokens: 1000, + cacheReadInputTokens: 0, + expectedInputTokens: 0, + expectedCacheReadTokens: 1000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 RecordUsage 中的 ForceCacheBilling 逻辑 + usage := ClaudeUsage{ + InputTokens: tt.inputTokens, + CacheReadInputTokens: tt.cacheReadInputTokens, + } + + // 这是 RecordUsage 中的实际逻辑 + if tt.forceCacheBilling && usage.InputTokens > 0 { + usage.CacheReadInputTokens += usage.InputTokens + usage.InputTokens = 0 + } + + if usage.InputTokens != tt.expectedInputTokens { + t.Errorf("InputTokens = %d, want %d", usage.InputTokens, tt.expectedInputTokens) + } + if usage.CacheReadInputTokens != tt.expectedCacheReadTokens { + t.Errorf("CacheReadInputTokens = %d, want %d", usage.CacheReadInputTokens, tt.expectedCacheReadTokens) + } + }) + } +} diff --git a/backend/internal/service/gateway_cached_tokens_test.go b/backend/internal/service/gateway_cached_tokens_test.go new file mode 100644 index 00000000..f886c855 --- /dev/null +++ b/backend/internal/service/gateway_cached_tokens_test.go @@ -0,0 +1,288 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ---------- reconcileCachedTokens 单元测试 ---------- + +func TestReconcileCachedTokens_NilUsage(t *testing.T) { + assert.False(t, reconcileCachedTokens(nil)) +} + +func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) { + // 已有标准字段,不应覆盖 + usage := map[string]any{ + "cache_read_input_tokens": float64(100), + "cached_tokens": float64(50), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(100), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_KimiStyle(t *testing.T) { + // Kimi 风格:cache_read_input_tokens=0,cached_tokens>0 + usage := map[string]any{ + "input_tokens": float64(23), + "cache_creation_input_tokens": float64(0), + "cache_read_input_tokens": float64(0), + "cached_tokens": float64(23), + } + assert.True(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(23), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) { + // 无 cached_tokens 字段(原生 Claude) + usage := map[string]any{ + "input_tokens": float64(100), + "cache_read_input_tokens": float64(0), + "cache_creation_input_tokens": float64(0), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(0), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) { + // cached_tokens 为 0,不应覆盖 + usage := map[string]any{ + "cache_read_input_tokens": float64(0), + "cached_tokens": float64(0), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(0), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) { + // cache_read_input_tokens 字段完全不存在,cached_tokens > 0 + usage := map[string]any{ + "cached_tokens": float64(42), + } + assert.True(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(42), usage["cache_read_input_tokens"]) +} + +// ---------- 流式 message_start 事件 reconcile 测试 ---------- + +func TestStreamingReconcile_MessageStart(t *testing.T) { + // 模拟 Kimi 返回的 message_start SSE 事件 + eventJSON := `{ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "model": "kimi", + "usage": { + "input_tokens": 23, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cached_tokens": 23 + } + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + eventType, _ := event["type"].(string) + require.Equal(t, "message_start", eventType) + + // 模拟 processSSEEvent 中的 reconcile 逻辑 + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + + // 验证 cache_read_input_tokens 已被填充 + msg, ok := event["message"].(map[string]any) + require.True(t, ok) + usage, ok := msg["usage"].(map[string]any) + require.True(t, ok) + assert.Equal(t, float64(23), usage["cache_read_input_tokens"]) + + // 验证重新序列化后 JSON 也包含正确值 + data, err := json.Marshal(event) + require.NoError(t, err) + assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int()) +} + +func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) { + // 原生 Claude 不返回 cached_tokens,reconcile 不应改变任何值 + eventJSON := `{ + "type": "message_start", + "message": { + "usage": { + "input_tokens": 100, + "cache_creation_input_tokens": 50, + "cache_read_input_tokens": 30 + } + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + + msg, ok := event["message"].(map[string]any) + require.True(t, ok) + usage, ok := msg["usage"].(map[string]any) + require.True(t, ok) + assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) +} + +// ---------- 流式 message_delta 事件 reconcile 测试 ---------- + +func TestStreamingReconcile_MessageDelta(t *testing.T) { + // 模拟 Kimi 返回的 message_delta SSE 事件 + eventJSON := `{ + "type": "message_delta", + "usage": { + "output_tokens": 7, + "cache_read_input_tokens": 0, + "cached_tokens": 15 + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + eventType, _ := event["type"].(string) + require.Equal(t, "message_delta", eventType) + + // 模拟 processSSEEvent 中的 reconcile 逻辑 + usage, ok := event["usage"].(map[string]any) + require.True(t, ok) + reconcileCachedTokens(usage) + assert.Equal(t, float64(15), usage["cache_read_input_tokens"]) +} + +func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) { + // 原生 Claude 的 message_delta 通常没有 cached_tokens + eventJSON := `{ + "type": "message_delta", + "usage": { + "output_tokens": 50 + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + usage, ok := event["usage"].(map[string]any) + require.True(t, ok) + reconcileCachedTokens(usage) + _, hasCacheRead := usage["cache_read_input_tokens"] + assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens") +} + +// ---------- 非流式响应 reconcile 测试 ---------- + +func TestNonStreamingReconcile_KimiResponse(t *testing.T) { + // 模拟 Kimi 非流式响应 + body := []byte(`{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi", + "usage": { + "input_tokens": 23, + "output_tokens": 7, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cached_tokens": 23, + "prompt_tokens": 23, + "completion_tokens": 7 + } + }`) + + // 模拟 handleNonStreamingResponse 中的逻辑 + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + // reconcile + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // 验证内部 usage(计费用) + assert.Equal(t, 23, response.Usage.CacheReadInputTokens) + assert.Equal(t, 23, response.Usage.InputTokens) + assert.Equal(t, 7, response.Usage.OutputTokens) + + // 验证返回给客户端的 JSON body + assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int()) +} + +func TestNonStreamingReconcile_NativeClaude(t *testing.T) { + // 原生 Claude 响应:cache_read_input_tokens 已有值 + body := []byte(`{ + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 20, + "cache_read_input_tokens": 30 + } + }`) + + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + // CacheReadInputTokens == 30,条件不成立,整个 reconcile 分支不会执行 + assert.NotZero(t, response.Usage.CacheReadInputTokens) + assert.Equal(t, 30, response.Usage.CacheReadInputTokens) +} + +func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) { + // 没有 cached_tokens 字段 + body := []byte(`{ + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0 + } + }`) + + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // cache_read_input_tokens 应保持为 0 + assert.Equal(t, 0, response.Usage.CacheReadInputTokens) + assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int()) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 4bfa23d1..b3e60c21 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -216,6 +216,22 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context return nil } +func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { + return 0, nil +} + +func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { + return nil, nil +} + +func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int @@ -332,7 +348,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing cfg: testConfig(), } - acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity) + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAntigravity) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, int64(2), acc.ID) @@ -670,7 +686,7 @@ func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *tes cfg: testConfig(), } - acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-sonnet-4-5", nil) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, int64(2), acc.ID) @@ -1014,10 +1030,16 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { expected bool }{ { - name: "Antigravity平台-支持claude模型", + name: "Antigravity平台-支持默认映射中的claude模型", + account: &Account{Platform: PlatformAntigravity}, + model: "claude-sonnet-4-5", + expected: true, + }, + { + name: "Antigravity平台-不支持非默认映射中的claude模型", account: &Account{Platform: PlatformAntigravity}, model: "claude-3-5-sonnet-20241022", - expected: true, + expected: false, }, { name: "Antigravity平台-支持gemini模型", @@ -1115,7 +1137,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { cfg: testConfig(), } - acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)") @@ -1123,7 +1145,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) { groupID := int64(30) - requestedModel := "claude-3-5-sonnet-20241022" + requestedModel := "claude-sonnet-4-5" repo := &mockAccountRepoForPlatform{ accounts: []Account{ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, @@ -1168,7 +1190,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { t.Run("混合调度-路由粘性命中", func(t *testing.T) { groupID := int64(31) - requestedModel := "claude-3-5-sonnet-20241022" + requestedModel := "claude-sonnet-4-5" repo := &mockAccountRepoForPlatform{ accounts: []Account{ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, @@ -1320,7 +1342,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { Schedulable: true, Extra: map[string]any{ "model_rate_limits": map[string]any{ - "claude_sonnet": map[string]any{ + "claude-3-5-sonnet-20241022": map[string]any{ "rate_limit_reset_at": resetAt.Format(time.RFC3339), }, }, @@ -1465,7 +1487,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { cfg: testConfig(), } - acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-sonnet-4-5", nil, PlatformAnthropic) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户") @@ -1597,7 +1619,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { cfg: testConfig(), } - acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, int64(1), acc.ID) @@ -1870,6 +1892,19 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + result := make(map[int64]*UserLoadInfo, len(users)) + for _, user := range users { + result[user.ID] = &UserLoadInfo{ + UserID: user.ID, + CurrentConcurrency: 0, + WaitingCount: 0, + LoadRate: 0, + } + } + return result, nil +} + // TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ctx := context.Background() @@ -2747,7 +2782,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { Concurrency: 5, Extra: map[string]any{ "model_rate_limits": map[string]any{ - "claude_sonnet": map[string]any{ + "claude-3-5-sonnet-20241022": map[string]any{ "rate_limit_reset_at": now.Format(time.RFC3339), }, }, diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index aa48d880..0ecd18aa 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "math" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) // ParsedRequest 保存网关请求的预解析结果 @@ -19,13 +22,15 @@ import ( // 2. 将解析结果 ParsedRequest 传递给 Service 层 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 type ParsedRequest struct { - Body []byte // 原始请求体(保留用于转发) - Model string // 请求的模型名称 - Stream bool // 是否为流式请求 - MetadataUserID string // metadata.user_id(用于会话亲和) - System any // system 字段内容 - Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) + ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + MaxTokens int // max_tokens 值(用于探测请求拦截) } // ParseGatewayRequest 解析网关请求体并返回结构化结果 @@ -69,9 +74,62 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.Messages = messages } + // thinking: {type: "enabled"} + if rawThinking, ok := req["thinking"].(map[string]any); ok { + if t, ok := rawThinking["type"].(string); ok && t == "enabled" { + parsed.ThinkingEnabled = true + } + } + + // max_tokens + if rawMaxTokens, exists := req["max_tokens"]; exists { + if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok { + parsed.MaxTokens = maxTokens + } + } + return parsed, nil } +// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 +// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 +func parseIntegralNumber(raw any) (int, bool) { + switch v := raw.(type) { + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return 0, false + } + if v > float64(math.MaxInt) || v < float64(math.MinInt) { + return 0, false + } + return int(v), true + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + if v > int64(math.MaxInt) || v < int64(math.MinInt) { + return 0, false + } + return int(v), true + case json.Number: + i64, err := v.Int64() + if err != nil { + return 0, false + } + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return 0, false + } + return int(i64), true + default: + return 0, false + } +} + // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures @@ -466,7 +524,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { // only keep thinking blocks with valid signatures if thinkingEnabled && role == "assistant" { signature, _ := blockMap["signature"].(string) - if signature != "" && signature != "skip_thought_signature_validator" { + if signature != "" && signature != antigravity.DummyThoughtSignature { newContent = append(newContent, block) continue } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index f92496fb..4e390b0a 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -17,6 +17,29 @@ func TestParseGatewayRequest(t *testing.T) { require.True(t, parsed.HasSystem) require.NotNil(t, parsed.System) require.Len(t, parsed.Messages, 1) + require.False(t, parsed.ThinkingEnabled) +} + +func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, "claude-sonnet-4-5", parsed.Model) + require.True(t, parsed.ThinkingEnabled) +} + +func TestParseGatewayRequest_MaxTokens(t *testing.T) { + body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, 1, parsed.MaxTokens) +} + +func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { + body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { diff --git a/backend/internal/service/gateway_sanitize_test.go b/backend/internal/service/gateway_sanitize_test.go index 8fa971ca..a62bc8c7 100644 --- a/backend/internal/service/gateway_sanitize_test.go +++ b/backend/internal/service/gateway_sanitize_test.go @@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) { got := sanitizeSystemText(in) require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got) } - -func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) { - in := "OpenCode and opencode are mentioned." - got := sanitizeToolDescription(in) - // We no longer rewrite tool descriptions; only redact obvious path leaks. - require.Equal(t, in, got) -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f52cd2d8..32646b11 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,7 +20,6 @@ import ( "strings" "sync/atomic" "time" - "unicode" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -50,6 +49,29 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +// ForceCacheBillingContextKey 强制缓存计费上下文键 +// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 +type forceCacheBillingKeyType struct{} + +// accountWithLoad 账号与负载信息的组合,用于负载感知调度 +type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo +} + +var ForceCacheBillingContextKey = forceCacheBillingKeyType{} + +// IsForceCacheBilling 检查是否启用强制缓存计费 +func IsForceCacheBilling(ctx context.Context) bool { + v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) + return v +} + +// WithForceCacheBilling 返回带有强制缓存计费标记的上下文 +func WithForceCacheBilling(ctx context.Context) context.Context { + return context.WithValue(ctx, ForceCacheBillingContextKey, true) +} + func (s *GatewayService) debugModelRoutingEnabled() bool { v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) return v == "1" || v == "true" || v == "yes" || v == "on" @@ -208,40 +230,6 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) - toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`) - toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`) - toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`) - toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`) - modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`) - toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`) - toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`) - - claudeToolNameOverrides = map[string]string{ - "bash": "Bash", - "read": "Read", - "edit": "Edit", - "write": "Write", - "task": "Task", - "glob": "Glob", - "grep": "Grep", - "webfetch": "WebFetch", - "websearch": "WebSearch", - "todowrite": "TodoWrite", - "question": "AskUserQuestion", - } - openCodeToolOverrides = map[string]string{ - "Bash": "bash", - "Read": "read", - "Edit": "edit", - "Write": "write", - "Task": "task", - "Glob": "glob", - "Grep": "grep", - "WebFetch": "webfetch", - "WebSearch": "websearch", - "TodoWrite": "todowrite", - "AskUserQuestion": "question", - } // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -257,6 +245,9 @@ var ( // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") +// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 +var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -282,6 +273,13 @@ var allowedHeaders = map[string]bool{ // GatewayCache 定义网关服务的缓存操作接口。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // +// ModelLoadInfo 模型负载信息(用于 Antigravity 调度) +// Model load info for Antigravity scheduling +type ModelLoadInfo struct { + CallCount int64 // 当前分钟调用次数 / Call count in current minute + LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled) +} + // GatewayCache defines cache operations for gateway service. // Provides sticky session storage, retrieval, refresh and deletion capabilities. type GatewayCache interface { @@ -297,6 +295,24 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error + + // IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用) + // Increment model call count and update last scheduling time (Antigravity only) + // 返回更新后的调用次数 + IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) + + // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) + // Batch get model load info for accounts (Antigravity only) + GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) + + // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) + // Find Gemini session using MGET reverse order matching + // 返回最长匹配的会话信息(uuid, accountID) + FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) + + // SaveGeminiSession 保存 Gemini 会话 + // Save Gemini session binding + SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -307,16 +323,23 @@ func derefGroupID(groupID *int64) int64 { return *groupID } +// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。 +// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。 +// 低于此阈值时保持粘性会话,等待短暂限流结束。 +const stickySessionRateLimitThreshold = 10 * time.Second + // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 -// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。 +// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, +// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 // 这确保后续请求不会继续使用不可用的账号。 // // shouldClearStickySession checks if an account is in an unschedulable state // and the sticky session binding should be cleared. // Returns true when account status is error/disabled, schedulable is false, -// or within temporary unschedulable period. +// within temporary unschedulable period, or model rate limit remaining time +// exceeds stickySessionRateLimitThreshold. // This ensures subsequent requests won't continue using unavailable accounts. -func shouldClearStickySession(account *Account) bool { +func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { return false } @@ -326,6 +349,10 @@ func shouldClearStickySession(account *Account) bool { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { return true } + // 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 + if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { + return true + } return false } @@ -368,7 +395,9 @@ type ForwardResult struct { // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { - StatusCode int + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true } func (e *UpstreamFailoverError) Error() string { @@ -382,6 +411,7 @@ type GatewayService struct { usageLogRepo UsageLogRepository userRepo UserRepository userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository cache GatewayCache cfg *config.Config schedulerSnapshot *SchedulerSnapshotService @@ -403,6 +433,7 @@ func NewGatewayService( usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -422,6 +453,7 @@ func NewGatewayService( usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, cache: cache, cfg: cfg, schedulerSnapshot: schedulerSnapshot, @@ -498,6 +530,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID return accountID, nil } +// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) +// 返回最长匹配的会话信息(uuid, accountID) +func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + if digestChain == "" || s.cache == nil { + return "", 0, false + } + return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) +} + +// SaveGeminiSession 保存 Gemini 会话 +func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + if digestChain == "" || s.cache == nil { + return nil + } + return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -585,12 +634,18 @@ func (s *GatewayService) hashContent(content string) string { } // replaceModelInBody 替换请求体中的model字段 +// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - var req map[string]any + var req map[string]json.RawMessage if err := json.Unmarshal(body, &req); err != nil { return body } - req["model"] = newModel + // 只序列化 model 字段 + modelBytes, err := json.Marshal(newModel) + if err != nil { + return body + } + req["model"] = modelBytes newBody, err := json.Marshal(req) if err != nil { return body @@ -604,98 +659,6 @@ type claudeOAuthNormalizeOptions struct { stripSystemCacheControl bool } -func stripToolPrefix(value string) string { - if value == "" { - return value - } - return toolPrefixRe.ReplaceAllString(value, "") -} - -func toPascalCase(value string) string { - if value == "" { - return value - } - normalized := toolNameBoundaryRe.ReplaceAllString(value, " ") - tokens := make([]string, 0) - for _, token := range strings.Fields(normalized) { - expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2") - parts := strings.Fields(expanded) - if len(parts) > 0 { - tokens = append(tokens, parts...) - } - } - if len(tokens) == 0 { - return value - } - var builder strings.Builder - for _, token := range tokens { - lower := strings.ToLower(token) - if lower == "" { - continue - } - runes := []rune(lower) - runes[0] = unicode.ToUpper(runes[0]) - _, _ = builder.WriteString(string(runes)) - } - return builder.String() -} - -func toSnakeCase(value string) string { - if value == "" { - return value - } - output := toolNameCamelRe.ReplaceAllString(value, "$1_$2") - output = toolNameBoundaryRe.ReplaceAllString(output, "_") - output = strings.Trim(output, "_") - return strings.ToLower(output) -} - -func normalizeToolNameForClaude(name string, cache map[string]string) string { - if name == "" { - return name - } - stripped := stripToolPrefix(name) - mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] - if !ok { - mapped = toPascalCase(stripped) - } - if mapped != "" && cache != nil && mapped != stripped { - cache[mapped] = stripped - } - if mapped == "" { - return stripped - } - return mapped -} - -func normalizeToolNameForOpenCode(name string, cache map[string]string) string { - if name == "" { - return name - } - stripped := stripToolPrefix(name) - if cache != nil { - if mapped, ok := cache[stripped]; ok { - return mapped - } - } - if mapped, ok := openCodeToolOverrides[stripped]; ok { - return mapped - } - return toSnakeCase(stripped) -} - -func normalizeParamNameForOpenCode(name string, cache map[string]string) string { - if name == "" { - return name - } - if cache != nil { - if mapped, ok := cache[name]; ok { - return mapped - } - } - return name -} - // sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). // We intentionally avoid broad keyword replacement in system prompts to prevent // accidentally changing user-provided instructions. @@ -714,55 +677,6 @@ func sanitizeSystemText(text string) string { return text } -func sanitizeToolDescription(description string) string { - if description == "" { - return description - } - description = toolDescAbsPathRe.ReplaceAllString(description, "[path]") - description = toolDescWinPathRe.ReplaceAllString(description, "[path]") - // Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings). - // Tool names/skill names may rely on exact wording, and rewriting can be misleading. - return description -} - -func normalizeToolInputSchema(inputSchema any, cache map[string]string) { - schema, ok := inputSchema.(map[string]any) - if !ok { - return - } - properties, ok := schema["properties"].(map[string]any) - if !ok { - return - } - - newProperties := make(map[string]any, len(properties)) - for key, value := range properties { - snakeKey := toSnakeCase(key) - newProperties[snakeKey] = value - if snakeKey != key && cache != nil { - cache[snakeKey] = key - } - } - schema["properties"] = newProperties - - if required, ok := schema["required"].([]any); ok { - newRequired := make([]any, 0, len(required)) - for _, item := range required { - name, ok := item.(string) - if !ok { - newRequired = append(newRequired, item) - continue - } - snakeName := toSnakeCase(name) - newRequired = append(newRequired, snakeName) - if snakeName != name && cache != nil { - cache[snakeName] = name - } - } - schema["required"] = newRequired - } -} - func stripCacheControlFromSystemBlocks(system any) bool { blocks, ok := system.([]any) if !ok { @@ -783,16 +697,18 @@ func stripCacheControlFromSystemBlocks(system any) bool { return changed } -func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) { +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { if len(body) == 0 { - return body, modelID, nil - } - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return body, modelID, nil + return body, modelID } - toolNameMap := make(map[string]string) + // 解析为 map[string]any 用于修改字段 + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return body, modelID + } + + modified := false if system, ok := req["system"]; ok { switch v := system.(type) { @@ -800,6 +716,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu sanitized := sanitizeSystemText(v) if sanitized != v { req["system"] = sanitized + modified = true } case []any: for _, item := range v { @@ -817,6 +734,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu sanitized := sanitizeSystemText(text) if sanitized != text { block["text"] = sanitized + modified = true } } } @@ -827,95 +745,20 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu if normalized != rawModel { req["model"] = normalized modelID = normalized + modified = true } } - if rawTools, exists := req["tools"]; exists { - switch tools := rawTools.(type) { - case []any: - for idx, tool := range tools { - toolMap, ok := tool.(map[string]any) - if !ok { - continue - } - if name, ok := toolMap["name"].(string); ok { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized != "" && normalized != name { - toolMap["name"] = normalized - } - } - if desc, ok := toolMap["description"].(string); ok { - sanitized := sanitizeToolDescription(desc) - if sanitized != desc { - toolMap["description"] = sanitized - } - } - if schema, ok := toolMap["input_schema"]; ok { - normalizeToolInputSchema(schema, toolNameMap) - } - tools[idx] = toolMap - } - req["tools"] = tools - case map[string]any: - normalizedTools := make(map[string]any, len(tools)) - for name, value := range tools { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized == "" { - normalized = name - } - if toolMap, ok := value.(map[string]any); ok { - toolMap["name"] = normalized - if desc, ok := toolMap["description"].(string); ok { - sanitized := sanitizeToolDescription(desc) - if sanitized != desc { - toolMap["description"] = sanitized - } - } - if schema, ok := toolMap["input_schema"]; ok { - normalizeToolInputSchema(schema, toolNameMap) - } - normalizedTools[normalized] = toolMap - continue - } - normalizedTools[normalized] = value - } - req["tools"] = normalizedTools - } - } else { + // 确保 tools 字段存在(即使为空数组) + if _, exists := req["tools"]; !exists { req["tools"] = []any{} - } - - if messages, ok := req["messages"].([]any); ok { - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) - if !ok { - continue - } - content, ok := msgMap["content"].([]any) - if !ok { - continue - } - for _, block := range content { - blockMap, ok := block.(map[string]any) - if !ok { - continue - } - if blockType, _ := blockMap["type"].(string); blockType != "tool_use" { - continue - } - if name, ok := blockMap["name"].(string); ok { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized != "" && normalized != name { - blockMap["name"] = normalized - } - } - } - } + modified = true } if opts.stripSystemCacheControl { if system, ok := req["system"]; ok { _ = stripCacheControlFromSystemBlocks(system) + modified = true } } @@ -927,17 +770,28 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu } if existing, ok := metadata["user_id"].(string); !ok || existing == "" { metadata["user_id"] = opts.metadataUserID + modified = true } } - delete(req, "temperature") - delete(req, "tool_choice") + if _, hasTemp := req["temperature"]; hasTemp { + delete(req, "temperature") + modified = true + } + if _, hasChoice := req["tool_choice"]; hasChoice { + delete(req, "tool_choice") + modified = true + } + + if !modified { + return body, modelID + } newBody, err := json.Marshal(req) if err != nil { - return body, modelID, toolNameMap + return body, modelID } - return newBody, modelID, toolNameMap + return newBody, modelID } func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { @@ -1135,6 +989,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } + // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, err @@ -1184,6 +1045,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 1. 过滤出路由列表中可调度的账号 var routingCandidates []*Account var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int + var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID for _, routingAccountID := range routingAccountIDs { if isExcluded(routingAccountID) { filteredExcluded++ @@ -1202,12 +1064,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredPlatform++ continue } - if !account.IsSchedulableForModel(requestedModel) { - filteredModelScope++ + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) { + filteredModelMapping++ continue } - if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) { - filteredModelMapping++ + if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + filteredModelScope++ + modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue } // 窗口费用检查(非粘性会话路径) @@ -1222,6 +1085,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) + if len(modelScopeSkippedIDs) > 0 { + log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", + derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) + } } if len(routingCandidates) > 0 { @@ -1233,8 +1100,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount.IsSchedulable() && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && - stickyAccount.IsSchedulableForModel(requestedModel) && - (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) && + (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && + stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { @@ -1291,10 +1158,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) // 3. 按负载感知排序 - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } var routingAvailable []accountWithLoad for _, acc := range routingCandidates { loadInfo := routingLoadMap[acc.ID] @@ -1385,14 +1248,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if ok { // 检查账户是否需要清理粘性会话绑定 // Check if the account needs sticky session cleanup - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulableForModel(requestedModel) && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) && + (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && + account.IsSchedulableForModelWithContext(ctx, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1450,10 +1313,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { continue } - if !acc.IsSchedulableForModel(requestedModel) { + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) @@ -1481,10 +1344,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return result, nil } } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } + // Antigravity 平台:获取模型负载信息 + var modelLoadMap map[int64]*ModelLoadInfo + isAntigravity := platform == PlatformAntigravity + var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -1499,47 +1362,108 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - if preferOAuth && a.account.Type != b.account.Type { - return a.account.Type == AccountTypeOAuth - } - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - + // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) + if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { + modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) + modelToAccountIDs := make(map[string][]int64) for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + mappedModel := mapAntigravityModel(item.account, requestedModel) + if mappedModel == "" { + continue + } + modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) + } + for model, ids := range modelToAccountIDs { + batch, err := s.cache.GetModelLoadBatch(ctx, ids, model) + if err != nil { + continue + } + for id, info := range batch { + modelLoadMap[id] = info + } + } + if len(modelLoadMap) == 0 { + modelLoadMap = nil + } + } + + // Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值) + // 其他平台:分层过滤选择:优先级 → 负载率 → LRU + if isAntigravity { + for len(available) > 0 { + // 1. 取优先级最小的集合(硬过滤) + candidates := filterByMinPriority(available) + // 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值) + selected := selectByCallCount(candidates, modelLoadMap, preferOAuth) + if selected == nil { + break + } + + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - continue + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil } + + // 移除已尝试的账号,重新选择 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) + } + } + available = newAvailable + } + } else { + for len(available) > 0 { + // 1. 取优先级最小的集合 + candidates := filterByMinPriority(available) + // 2. 取负载率最低的集合 + candidates = filterByMinLoadRate(candidates) + // 3. LRU 选择最久未用的账号 + selected := selectByLRU(candidates, preferOAuth) + if selected == nil { + break + } + + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + // 移除已尝试的账号,重新进行分层过滤 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) + } + } + available = newAvailable } } } @@ -1632,6 +1556,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* return group, nil } +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) +} + func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { return nil @@ -1697,7 +1625,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID } // 强制平台模式不检查 Claude Code 限制 - if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { return nil, groupID, nil } @@ -1952,6 +1880,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } +// filterByMinPriority 过滤出优先级最小的账号集合 +func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minPriority := accounts[0].account.Priority + for _, acc := range accounts[1:] { + if acc.account.Priority < minPriority { + minPriority = acc.account.Priority + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.account.Priority == minPriority { + result = append(result, acc) + } + } + return result +} + +// filterByMinLoadRate 过滤出负载率最低的账号集合 +func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minLoadRate := accounts[0].loadInfo.LoadRate + for _, acc := range accounts[1:] { + if acc.loadInfo.LoadRate < minLoadRate { + minLoadRate = acc.loadInfo.LoadRate + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.loadInfo.LoadRate == minLoadRate { + result = append(result, acc) + } + } + return result +} + +// selectByLRU 从集合中选择最久未用的账号 +// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 +func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { + if len(accounts) == 0 { + return nil + } + if len(accounts) == 1 { + return &accounts[0] + } + + // 1. 找到最小的 LastUsedAt(nil 被视为最小) + var minTime *time.Time + hasNil := false + for _, acc := range accounts { + if acc.account.LastUsedAt == nil { + hasNil = true + break + } + if minTime == nil || acc.account.LastUsedAt.Before(*minTime) { + minTime = acc.account.LastUsedAt + } + } + + // 2. 收集所有具有最小 LastUsedAt 的账号索引 + var candidateIdxs []int + for i, acc := range accounts { + if hasNil { + if acc.account.LastUsedAt == nil { + candidateIdxs = append(candidateIdxs, i) + } + } else { + if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) { + candidateIdxs = append(candidateIdxs, i) + } + } + } + + // 3. 如果只有一个候选,直接返回 + if len(candidateIdxs) == 1 { + return &accounts[candidateIdxs[0]] + } + + // 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型 + if preferOAuth { + var oauthIdxs []int + for _, idx := range candidateIdxs { + if accounts[idx].account.Type == AccountTypeOAuth { + oauthIdxs = append(oauthIdxs, idx) + } + } + if len(oauthIdxs) > 0 { + candidateIdxs = oauthIdxs + } + } + + // 5. 随机选择一个 + selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))] + return &accounts[selectedIdx] +} + func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { sort.SliceStable(accounts, func(i, j int) bool { a, b := accounts[i], accounts[j] @@ -1974,6 +2002,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { }) } +// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) +// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 +// 如果有多个账号具有相同的最小调用次数,则随机选择一个 +func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { + if len(accounts) == 0 { + return nil + } + if len(accounts) == 1 { + return &accounts[0] + } + + // 如果没有负载信息,回退到 LRU + if modelLoadMap == nil { + return selectByLRU(accounts, preferOAuth) + } + + // 1. 计算平均调用次数(用于新账号冷启动) + var totalCallCount int64 + var countWithCalls int + for _, acc := range accounts { + if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { + totalCallCount += info.CallCount + countWithCalls++ + } + } + + var avgCallCount int64 + if countWithCalls > 0 { + avgCallCount = totalCallCount / int64(countWithCalls) + } + + // 2. 获取每个账号的有效调用次数 + getEffectiveCallCount := func(acc accountWithLoad) int64 { + if acc.account == nil { + return 0 + } + info := modelLoadMap[acc.account.ID] + if info == nil || info.CallCount == 0 { + return avgCallCount // 新账号使用平均值 + } + return info.CallCount + } + + // 3. 找到最小调用次数 + minCount := getEffectiveCallCount(accounts[0]) + for _, acc := range accounts[1:] { + if c := getEffectiveCallCount(acc); c < minCount { + minCount = c + } + } + + // 4. 收集所有具有最小调用次数的账号 + var candidateIdxs []int + for i, acc := range accounts { + if getEffectiveCallCount(acc) == minCount { + candidateIdxs = append(candidateIdxs, i) + } + } + + // 5. 如果只有一个候选,直接返回 + if len(candidateIdxs) == 1 { + return &accounts[candidateIdxs[0]] + } + + // 6. preferOAuth 处理 + if preferOAuth { + var oauthIdxs []int + for _, idx := range candidateIdxs { + if accounts[idx].account.Type == AccountTypeOAuth { + oauthIdxs = append(oauthIdxs, idx) + } + } + if len(oauthIdxs) > 0 { + candidateIdxs = oauthIdxs + } + } + + // 7. 随机选择 + return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]] +} + // sortCandidatesForFallback 根据配置选择排序策略 // mode: "last_used"(按最后使用时间) 或 "random"(随机) func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { @@ -2026,6 +2135,13 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -2048,11 +2164,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil { - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } @@ -2099,10 +2215,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !acc.IsSchedulable() { continue } - if !acc.IsSchedulableForModel(requestedModel) { + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { @@ -2151,11 +2267,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil { - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } @@ -2191,10 +2307,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !acc.IsSchedulable() { continue } - if !acc.IsSchedulableForModel(requestedModel) { + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { @@ -2261,11 +2377,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil { - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) @@ -2314,10 +2430,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } - if !acc.IsSchedulableForModel(requestedModel) { + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { @@ -2366,11 +2482,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil { - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) @@ -2408,10 +2524,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } - if !acc.IsSchedulableForModel(requestedModel) { + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { @@ -2455,11 +2571,42 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g return selected, nil } -// isModelSupportedByAccount 根据账户平台检查模型支持 +// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) +// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 +func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 + mapped := mapAntigravityModel(account, requestedModel) + if mapped == "" { + return false + } + // 应用 thinking 后缀后检查最终模型是否在账号映射中 + if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + finalModel := applyThinkingModelSuffix(mapped, enabled) + if finalModel == mapped { + return true // thinking 后缀未改变模型名,映射已通过 + } + return account.IsModelSupported(finalModel) + } + return true + } + return s.isModelSupportedByAccount(account, requestedModel) +} + +// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { - // Antigravity 平台使用专门的模型支持检查 - return IsAntigravityModelSupported(requestedModel) + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" + } + // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) + if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + requestedModel = claude.NormalizeModelID(requestedModel) } // Gemini API Key 账户直接透传,由上游判断模型是否支持 if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { @@ -2469,13 +2616,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } -// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 -// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 -func IsAntigravityModelSupported(requestedModel string) bool { - return strings.HasPrefix(requestedModel, "claude-") || - strings.HasPrefix(requestedModel, "gemini-") -} - // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -2880,7 +3020,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A reqModel := parsed.Model reqStream := parsed.Stream originalModel := reqModel - var toolNameMap map[string]string isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -2904,22 +3043,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } } - body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) - // 应用模型映射(仅对apikey类型账号) + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + mappedModel := reqModel + mappingSource := "" if account.Type == AccountTypeAPIKey { - mappedModel := account.GetMappedModel(reqModel) + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - // 替换请求体中的模型名 - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) @@ -3191,7 +3344,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -3221,10 +3374,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - - // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { @@ -3268,7 +3419,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A log.Printf("Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } } return s.handleErrorResponse(ctx, resp, c, account) @@ -3279,7 +3430,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var firstTokenMs *int var clientDisconnect bool if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -3292,7 +3443,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A firstTokenMs = streamResult.firstTokenMs clientDisconnect = streamResult.clientDisconnect } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { return nil, err } @@ -3621,6 +3772,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return true } + // 检测 thinking block 被修改的错误 + // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" + if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + log.Printf("[SignatureCheck] Detected thinking block modification error") + return true + } + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { @@ -3658,6 +3816,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { return false } +// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 +// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} +func ExtractUpstreamErrorMessage(body []byte) string { + return extractUpstreamErrorMessage(body) +} + func extractUpstreamErrorMessage(body []byte) string { // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { @@ -3725,7 +3889,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) @@ -3740,6 +3904,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ) } + // 非 failover 错误也支持错误透传规则匹配。 + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) var errType, errMsg string var statusCode int @@ -3871,6 +4063,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed after retries", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 返回统一的重试耗尽错误响应 c.JSON(http.StatusBadGateway, gin.H{ "type": "error", @@ -3893,7 +4112,7 @@ type streamingResult struct { clientDisconnect bool // 客户端是否在流式传输过程中断开 } -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) { +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3989,33 +4208,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage pendingEventLines := make([]string, 0, 4) - var toolInputBuffers map[int]string - if mimicClaudeCode { - toolInputBuffers = make(map[int]string) - } - - transformToolInputJSON := func(raw string) string { - if !mimicClaudeCode { - return raw - } - raw = strings.TrimSpace(raw) - if raw == "" { - return raw - } - - var parsed any - if err := json.Unmarshal([]byte(raw), &parsed); err != nil { - return replaceToolNamesInText(raw, toolNameMap) - } - - rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap) - if changed { - if bytes, err := json.Marshal(rewritten); err == nil { - return string(bytes) - } - } - return raw - } processSSEEvent := func(lines []string) ([]string, string, error) { if len(lines) == 0 { @@ -4054,16 +4246,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http var event map[string]any if err := json.Unmarshal([]byte(dataLine), &event); err != nil { - replaced := dataLine - if mimicClaudeCode { - replaced = replaceToolNamesInText(dataLine, toolNameMap) - } + // JSON 解析失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } - block += "data: " + replaced + "\n\n" - return []string{block}, replaced, nil + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil } eventType, _ := event["type"].(string) @@ -4071,6 +4260,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http eventName = eventType } + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + if needModelReplace { if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { @@ -4079,70 +4282,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } - if mimicClaudeCode && eventType == "content_block_delta" { - if delta, ok := event["delta"].(map[string]any); ok { - if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { - if indexVal, ok := event["index"].(float64); ok { - index := int(indexVal) - if partial, ok := delta["partial_json"].(string); ok { - toolInputBuffers[index] += partial - } - } - return nil, dataLine, nil - } - } - } - - if mimicClaudeCode && eventType == "content_block_stop" { - if indexVal, ok := event["index"].(float64); ok { - index := int(indexVal) - if buffered := toolInputBuffers[index]; buffered != "" { - delete(toolInputBuffers, index) - - transformed := transformToolInputJSON(buffered) - synthetic := map[string]any{ - "type": "content_block_delta", - "index": index, - "delta": map[string]any{ - "type": "input_json_delta", - "partial_json": transformed, - }, - } - - synthBytes, synthErr := json.Marshal(synthetic) - if synthErr == nil { - synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n" - - rewriteToolNamesInValue(event, toolNameMap) - stopBytes, stopErr := json.Marshal(event) - if stopErr == nil { - stopBlock := "" - if eventName != "" { - stopBlock = "event: " + eventName + "\n" - } - stopBlock += "data: " + string(stopBytes) + "\n\n" - return []string{synthBlock, stopBlock}, string(stopBytes), nil - } - } - } - } - } - - if mimicClaudeCode { - rewriteToolNamesInValue(event, toolNameMap) - } newData, err := json.Marshal(event) if err != nil { - replaced := dataLine - if mimicClaudeCode { - replaced = replaceToolNamesInText(dataLine, toolNameMap) - } + // 序列化失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } - block += "data: " + replaced + "\n\n" - return []string{block}, replaced, nil + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil } block := "" @@ -4241,126 +4389,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } -func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) { - switch v := value.(type) { - case map[string]any: - changed := false - rewritten := make(map[string]any, len(v)) - for key, item := range v { - newKey := normalizeParamNameForOpenCode(key, cache) - newItem, childChanged := rewriteParamKeysInValue(item, cache) - if childChanged { - changed = true - } - if newKey != key { - changed = true - } - rewritten[newKey] = newItem - } - if !changed { - return value, false - } - return rewritten, true - case []any: - changed := false - rewritten := make([]any, len(v)) - for idx, item := range v { - newItem, childChanged := rewriteParamKeysInValue(item, cache) - if childChanged { - changed = true - } - rewritten[idx] = newItem - } - if !changed { - return value, false - } - return rewritten, true - default: - return value, false - } -} - -func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool { - switch v := value.(type) { - case map[string]any: - changed := false - if blockType, _ := v["type"].(string); blockType == "tool_use" { - if name, ok := v["name"].(string); ok { - mapped := normalizeToolNameForOpenCode(name, toolNameMap) - if mapped != name { - v["name"] = mapped - changed = true - } - } - if input, ok := v["input"].(map[string]any); ok { - rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap) - if inputChanged { - if m, ok := rewrittenInput.(map[string]any); ok { - v["input"] = m - changed = true - } - } - } - } - for _, item := range v { - if rewriteToolNamesInValue(item, toolNameMap) { - changed = true - } - } - return changed - case []any: - changed := false - for _, item := range v { - if rewriteToolNamesInValue(item, toolNameMap) { - changed = true - } - } - return changed - default: - return false - } -} - -func replaceToolNamesInText(text string, toolNameMap map[string]string) string { - if text == "" { - return text - } - output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string { - submatches := toolNameFieldRe.FindStringSubmatch(match) - if len(submatches) < 2 { - return match - } - name := submatches[1] - mapped := normalizeToolNameForOpenCode(name, toolNameMap) - if mapped == name { - return match - } - return strings.Replace(match, name, mapped, 1) - }) - output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string { - submatches := modelFieldRe.FindStringSubmatch(match) - if len(submatches) < 2 { - return match - } - model := submatches[1] - mapped := claude.DenormalizeModelID(model) - if mapped == model { - return match - } - return strings.Replace(match, model, mapped, 1) - }) - - for mapped, original := range toolNameMap { - if mapped == "" || original == "" || mapped == original { - continue - } - output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":") - output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":") - } - - return output -} - func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { // 解析message_start获取input tokens(标准Claude API格式) var msgStart struct { @@ -4404,7 +4432,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { } } -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) { +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -4421,13 +4449,21 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - if mimicClaudeCode { - body = s.replaceToolNamesInResponseBody(body, toolNameMap) - } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -4465,37 +4501,22 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo return newBody } -func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte { - if len(body) == 0 { - return body - } - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - replaced := replaceToolNamesInText(string(body), toolNameMap) - if replaced == string(body) { - return body - } - return []byte(replaced) - } - if !rewriteToolNamesInValue(resp, toolNameMap) { - return body - } - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - return newBody -} - // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 +} + +// APIKeyQuotaUpdater defines the interface for updating API Key quota +type APIKeyQuotaUpdater interface { + UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -4506,10 +4527,26 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens + // 用于粘性会话切换时的特殊计费处理 + if input.ForceCacheBilling && result.Usage.InputTokens > 0 { + log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + result.Usage.InputTokens, account.ID) + result.Usage.CacheReadInputTokens += result.Usage.InputTokens + result.Usage.InputTokens = 0 + } + + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown @@ -4635,6 +4672,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } } + // 更新 API Key 配额(如果设置了配额限制) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -4652,6 +4696,8 @@ type RecordUsageLongContextInput struct { IPAddress string // 请求的客户端 IP 地址 LongContextThreshold int // 长上下文阈值(如 200000) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService *APIKeyService // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -4662,10 +4708,26 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens + // 用于粘性会话切换时的特殊计费处理 + if input.ForceCacheBilling && result.Usage.InputTokens > 0 { + log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + result.Usage.InputTokens, account.ID) + result.Usage.CacheReadInputTokens += result.Usage.InputTokens + result.Usage.InputTokens = 0 + } + + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown @@ -4788,6 +4850,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + // API Key 独立配额扣费 + if input.APIKeyService != nil && apiKey.Quota > 0 { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Add API key quota used failed: %v", err) + } + } } } @@ -4813,7 +4881,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, if shouldMimicClaudeCode { normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} - body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // Antigravity 账户不支持 count_tokens 转发,直接返回空值 @@ -4822,16 +4890,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } - // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeAPIKey { - if reqModel != "" { - mappedModel := account.GetMappedModel(reqModel) + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } } // 获取凭证 @@ -5083,6 +5165,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } +// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 +func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { + scope, ok := ResolveAntigravityQuotaScope(requestedModel) + if !ok { + return nil // 无法解析 scope,跳过检查 + } + + group, err := s.resolveGroupByID(ctx, groupID) + if err != nil { + return nil // 查询失败时放行 + } + if group == nil { + return nil // 分组不存在时放行 + } + + if !IsScopeSupported(group.SupportedModelScopes, scope) { + return ErrModelScopeNotSupported + } + return nil +} + // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { @@ -5137,3 +5240,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, return models } + +// reconcileCachedTokens 兼容 Kimi 等上游: +// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens +func reconcileCachedTokens(usage map[string]any) bool { + if usage == nil { + return false + } + cacheRead, _ := usage["cache_read_input_tokens"].(float64) + if cacheRead > 0 { + return false // 已有标准字段,无需处理 + } + cached, _ := usage["cached_tokens"].(float64) + if cached <= 0 { + return false + } + usage["cache_read_input_tokens"] = cached + return true +} diff --git a/backend/internal/service/gateway_service_antigravity_whitelist_test.go b/backend/internal/service/gateway_service_antigravity_whitelist_test.go new file mode 100644 index 00000000..c078be32 --- /dev/null +++ b/backend/internal/service/gateway_service_antigravity_whitelist_test.go @@ -0,0 +1,240 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestGatewayService_isModelSupportedByAccount_AntigravityModelMapping(t *testing.T) { + svc := &GatewayService{} + + // 使用 model_mapping 作为白名单(通配符匹配) + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-sonnet-4-5", + "gemini-3-*": "gemini-3-flash", + }, + }, + } + + // claude-* 通配符匹配 + require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5")) + require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5")) + require.True(t, svc.isModelSupportedByAccount(account, "claude-opus-4-6")) + + // gemini-3-* 通配符匹配 + require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash")) + require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-pro-high")) + + // gemini-2.5-* 不匹配(不在 model_mapping 中) + require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-flash")) + require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro")) + + // 其他平台模型不支持 + require.False(t, svc.isModelSupportedByAccount(account, "gpt-4")) + + // 空模型允许 + require.True(t, svc.isModelSupportedByAccount(account, "")) +} + +func TestGatewayService_isModelSupportedByAccount_AntigravityNoMapping(t *testing.T) { + svc := &GatewayService{} + + // 未配置 model_mapping 时,使用默认映射(domain.DefaultAntigravityModelMapping) + // 只有默认映射中的模型才被支持 + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + } + + // 默认映射中的模型应该被支持 + require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5")) + require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash")) + require.True(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro")) + require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5")) + + // 不在默认映射中的模型不被支持 + require.False(t, svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022")) + require.False(t, svc.isModelSupportedByAccount(account, "claude-unknown-model")) + + // 非 claude-/gemini- 前缀仍然不支持 + require.False(t, svc.isModelSupportedByAccount(account, "gpt-4")) +} + +// TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode 测试 thinking 模式下的模型支持检查 +// 验证调度时使用映射后的最终模型名(包括 thinking 后缀)来检查 model_mapping 支持 +func TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode(t *testing.T) { + svc := &GatewayService{} + + tests := []struct { + name string + modelMapping map[string]any + requestedModel string + thinkingEnabled bool + expected bool + }{ + // 场景 1: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=true + // mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false + { + name: "thinking_enabled_no_base_mapping_returns_false", + modelMapping: map[string]any{ + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: true, + expected: false, + }, + // 场景 2: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=false + // mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false + { + name: "thinking_disabled_no_base_mapping_returns_false", + modelMapping: map[string]any{ + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: false, + expected: false, + }, + // 场景 3: 配置 claude-sonnet-4-5(非 thinking),请求 claude-sonnet-4-5 + thinking=true + // 最终模型名 = claude-sonnet-4-5-thinking,不在 mapping 中,应该不匹配 + { + name: "thinking_enabled_no_match_non_thinking_mapping", + modelMapping: map[string]any{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: true, + expected: false, + }, + // 场景 4: 配置两种模型,请求 claude-sonnet-4-5 + thinking=true,应该匹配 thinking 版本 + { + name: "both_models_thinking_enabled_matches_thinking", + modelMapping: map[string]any{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: true, + expected: true, + }, + // 场景 5: 配置两种模型,请求 claude-sonnet-4-5 + thinking=false,应该匹配非 thinking 版本 + { + name: "both_models_thinking_disabled_matches_non_thinking", + modelMapping: map[string]any{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: false, + expected: true, + }, + // 场景 6: 通配符 claude-* 应该同时匹配 thinking 和非 thinking + { + name: "wildcard_matches_thinking", + modelMapping: map[string]any{ + "claude-*": "claude-sonnet-4-5", + }, + requestedModel: "claude-sonnet-4-5", + thinkingEnabled: true, + expected: true, // claude-sonnet-4-5-thinking 匹配 claude-* + }, + // 场景 7: 只配置 thinking 变体但没有基础模型映射 → 返回 false + // mapAntigravityModel 找不到 claude-opus-4-6 的映射 + { + name: "opus_thinking_no_base_mapping_returns_false", + modelMapping: map[string]any{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + }, + requestedModel: "claude-opus-4-6", + thinkingEnabled: true, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": tt.modelMapping, + }, + } + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, tt.thinkingEnabled) + result := svc.isModelSupportedByAccountWithContext(ctx, account, tt.requestedModel) + + require.Equal(t, tt.expected, result, + "isModelSupportedByAccountWithContext(ctx[thinking=%v], account, %q) = %v, want %v", + tt.thinkingEnabled, tt.requestedModel, result, tt.expected) + }) + } +} + +// TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault 测试自定义模型映射中 +// 不在 DefaultAntigravityModelMapping 中的模型能通过调度 +func TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault(t *testing.T) { + svc := &GatewayService{} + + // 自定义映射中包含不在默认映射中的模型 + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "my-custom-model": "actual-upstream-model", + "gpt-4o": "some-upstream-model", + "llama-3-70b": "llama-3-70b-upstream", + "claude-sonnet-4-5": "claude-sonnet-4-5", + }, + }, + } + + // 自定义模型应该通过(不在 DefaultAntigravityModelMapping 中也可以) + require.True(t, svc.isModelSupportedByAccount(account, "my-custom-model")) + require.True(t, svc.isModelSupportedByAccount(account, "gpt-4o")) + require.True(t, svc.isModelSupportedByAccount(account, "llama-3-70b")) + require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5")) + + // 不在自定义映射中的模型不通过 + require.False(t, svc.isModelSupportedByAccount(account, "gpt-3.5-turbo")) + require.False(t, svc.isModelSupportedByAccount(account, "unknown-model")) + + // 空模型允许 + require.True(t, svc.isModelSupportedByAccount(account, "")) +} + +// TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking +// 测试自定义映射 + thinking 模式的交互 +func TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking(t *testing.T) { + svc := &GatewayService{} + + // 自定义映射同时配置基础模型和 thinking 变体 + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "my-custom-model": "upstream-model", + }, + }, + } + + // thinking=true: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → +thinking → check IsModelSupported(claude-sonnet-4-5-thinking)=true + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5")) + + // thinking=false: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → check IsModelSupported(claude-sonnet-4-5)=true + ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false) + require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5")) + + // 自定义模型(非 claude)不受 thinking 后缀影响,mapped 成功即通过 + ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "my-custom-model")) +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2d2e86d5..0f156c2e 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -200,7 +200,7 @@ func (s *GeminiMessagesCompatService) tryStickySessionHit( // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared - if shouldClearStickySession(account) { + if shouldClearStickySession(account, requestedModel) { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) return nil } @@ -230,7 +230,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( ) bool { // 检查模型调度能力 // Check model scheduling capability - if !account.IsSchedulableForModel(requestedModel) { + if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { return false } @@ -362,7 +362,10 @@ func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current * // isModelSupportedByAccount 根据账户平台检查模型支持 func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { - return IsAntigravityModelSupported(requestedModel) + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" } return account.IsModelSupported(requestedModel) } @@ -864,7 +867,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) @@ -891,7 +894,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -977,6 +980,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil { + body = filteredBody + } + switch action { case "generateContent", "streamGenerateContent", "countTokens": // ok @@ -1296,7 +1304,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) @@ -1320,7 +1328,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} } respBody = unwrapIfNeeded(isOAuth, respBody) @@ -1493,6 +1501,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg) + } + var statusCode int var errType, errMsg string @@ -2631,7 +2661,9 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 { if meta, ok := dm["metadata"].(map[string]any); ok { if v, ok := meta["quotaResetDelay"].(string); ok { if dur, err := time.ParseDuration(v); err == nil { - ts := time.Now().Unix() + int64(dur.Seconds()) + // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), + // which can affect scheduling decisions around thresholds (like 10s). + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) return &ts } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index e7ed80fd..601e7e2c 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -265,6 +265,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, return nil } +func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { + return 0, nil +} + +func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { + return nil, nil +} + +func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() @@ -880,7 +896,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { { name: "Antigravity平台-支持claude模型", account: &Account{Platform: PlatformAntigravity}, - model: "claude-3-5-sonnet-20241022", + model: "claude-sonnet-4-5", expected: true, }, { @@ -889,6 +905,39 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { model: "gpt-4", expected: false, }, + { + name: "Antigravity平台-空模型允许", + account: &Account{Platform: PlatformAntigravity}, + model: "", + expected: true, + }, + { + name: "Antigravity平台-自定义映射-支持自定义模型", + account: &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "my-custom-model": "upstream-model", + "gpt-4o": "some-model", + }, + }, + }, + model: "my-custom-model", + expected: true, + }, + { + name: "Antigravity平台-自定义映射-不在映射中的模型不支持", + account: &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "my-custom-model": "upstream-model", + }, + }, + }, + model: "claude-sonnet-4-5", + expected: false, + }, { name: "Gemini平台-无映射配置-支持所有模型", account: &Account{Platform: PlatformGemini}, diff --git a/backend/internal/service/gemini_native_signature_cleaner.go b/backend/internal/service/gemini_native_signature_cleaner.go index b3352fb0..d43fb445 100644 --- a/backend/internal/service/gemini_native_signature_cleaner.go +++ b/backend/internal/service/gemini_native_signature_cleaner.go @@ -2,20 +2,22 @@ package service import ( "encoding/json" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段, +// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中替换 thoughtSignature 字段为 dummy 签名, // 以避免跨账号签名验证错误。 // // 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature -// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。 +// 会导致新账号的签名验证失败。通过替换为 dummy 签名,跳过签名验证。 // -// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests -// to avoid cross-account signature validation errors. +// CleanGeminiNativeThoughtSignatures replaces thoughtSignature fields with dummy signature +// in Gemini native API requests to avoid cross-account signature validation errors. // // When sticky session switches accounts (e.g., original account becomes unavailable), // thoughtSignatures from the old account will cause validation failures on the new account. -// By removing these signatures, we allow the new account to generate valid signatures. +// By replacing with dummy signature, we skip signature validation. func CleanGeminiNativeThoughtSignatures(body []byte) []byte { if len(body) == 0 { return body @@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { return body } - // 递归清理 thoughtSignature - cleaned := cleanThoughtSignaturesRecursive(data) + // 递归替换 thoughtSignature 为 dummy 签名 + replaced := replaceThoughtSignaturesRecursive(data) // 重新序列化 - result, err := json.Marshal(cleaned) + result, err := json.Marshal(replaced) if err != nil { // 如果序列化失败,返回原始 body return body @@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { return result } -// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段 -func cleanThoughtSignaturesRecursive(data any) any { +// replaceThoughtSignaturesRecursive 递归遍历数据结构,将所有 thoughtSignature 字段替换为 dummy 签名 +func replaceThoughtSignaturesRecursive(data any) any { switch v := data.(type) { case map[string]any: - // 创建新的 map,移除 thoughtSignature + // 创建新的 map,替换 thoughtSignature 为 dummy 签名 result := make(map[string]any, len(v)) for key, value := range v { - // 跳过 thoughtSignature 字段 + // 替换 thoughtSignature 字段为 dummy 签名 if key == "thoughtSignature" { + result[key] = antigravity.DummyThoughtSignature continue } // 递归处理嵌套结构 - result[key] = cleanThoughtSignaturesRecursive(value) + result[key] = replaceThoughtSignaturesRecursive(value) } return result @@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any { // 递归处理数组中的每个元素 result := make([]any, len(v)) for i, item := range v { - result[i] = cleanThoughtSignaturesRecursive(item) + result[i] = replaceThoughtSignaturesRecursive(item) } return result diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index bc84baeb..fd2932e6 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } + // 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。 + // 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时: + // - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT) + // - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id + if loadResp != nil { + registeredTierID := strings.TrimSpace(loadResp.GetTier()) + if registeredTierID != "" { + // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 + log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + + // Try to get project from Cloud Resource Manager + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + return strings.TrimSpace(fallback), tierID, nil + } + + // No project found - user must provide project_id manually + log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) + } + } + + // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser + log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go new file mode 100644 index 00000000..859ae9f3 --- /dev/null +++ b/backend/internal/service/gemini_session.go @@ -0,0 +1,164 @@ +package service + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/cespare/xxhash/v2" +) + +// Gemini 会话 ID Fallback 相关常量 +const ( + // geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟) + geminiSessionTTLSeconds = 300 + + // geminiSessionKeyPrefix Gemini 会话 Redis key 前缀 + geminiSessionKeyPrefix = "gemini:sess:" +) + +// GeminiSessionTTL 返回 Gemini 会话缓存 TTL +func GeminiSessionTTL() time.Duration { + return geminiSessionTTLSeconds * time.Second +} + +// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) +// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% +func shortHash(data []byte) string { + h := xxhash.Sum64(data) + return strconv.FormatUint(h, 36) +} + +// BuildGeminiDigestChain 根据 Gemini 请求生成摘要链 +// 格式: s:-u:-m:-u:-... +// s = systemInstruction, u = user, m = model +func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string { + if req == nil { + return "" + } + + var parts []string + + // 1. system instruction + if req.SystemInstruction != nil && len(req.SystemInstruction.Parts) > 0 { + partsData, _ := json.Marshal(req.SystemInstruction.Parts) + parts = append(parts, "s:"+shortHash(partsData)) + } + + // 2. contents + for _, c := range req.Contents { + prefix := "u" // user + if c.Role == "model" { + prefix = "m" + } + partsData, _ := json.Marshal(c.Parts) + parts = append(parts, prefix+":"+shortHash(partsData)) + } + + return strings.Join(parts, "-") +} + +// GenerateGeminiPrefixHash 生成前缀 hash(用于分区隔离) +// 组合: userID + apiKeyID + ip + userAgent + platform + model +// 返回 16 字符的 Base64 编码的 SHA256 前缀 +func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { + // 组合所有标识符 + combined := strconv.FormatInt(userID, 10) + ":" + + strconv.FormatInt(apiKeyID, 10) + ":" + + ip + ":" + + userAgent + ":" + + platform + ":" + + model + + hash := sha256.Sum256([]byte(combined)) + // 取前 12 字节,Base64 编码后正好 16 字符 + return base64.RawURLEncoding.EncodeToString(hash[:12]) +} + +// BuildGeminiSessionKey 构建 Gemini 会话 Redis key +// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain} +func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string { + return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain +} + +// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短) +// 用于 MGET 批量查询最长匹配 +func GenerateDigestChainPrefixes(chain string) []string { + if chain == "" { + return nil + } + + var prefixes []string + c := chain + + for c != "" { + prefixes = append(prefixes, c) + // 找到最后一个 "-" 的位置 + if i := strings.LastIndex(c, "-"); i > 0 { + c = c[:i] + } else { + break + } + } + + return prefixes +} + +// ParseGeminiSessionValue 解析 Gemini 会话缓存值 +// 格式: {uuid}:{accountID} +func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { + if value == "" { + return "", 0, false + } + + // 找到最后一个 ":" 的位置(因为 uuid 可能包含 ":") + i := strings.LastIndex(value, ":") + if i <= 0 || i >= len(value)-1 { + return "", 0, false + } + + uuid = value[:i] + accountID, err := strconv.ParseInt(value[i+1:], 10, 64) + if err != nil { + return "", 0, false + } + + return uuid, accountID, true +} + +// FormatGeminiSessionValue 格式化 Gemini 会话缓存值 +// 格式: {uuid}:{accountID} +func FormatGeminiSessionValue(uuid string, accountID int64) string { + return uuid + ":" + strconv.FormatInt(accountID, 10) +} + +// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 +const geminiDigestSessionKeyPrefix = "gemini:digest:" + +// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀 +const geminiTrieKeyPrefix = "gemini:trie:" + +// BuildGeminiTrieKey 构建 Gemini Trie Redis key +// 格式: gemini:trie:{groupID}:{prefixHash} +func BuildGeminiTrieKey(groupID int64, prefixHash string) string { + return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash +} + +// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey +// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey +// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 +func GenerateGeminiDigestSessionKey(prefixHash, uuid string) string { + prefix := prefixHash + if len(prefixHash) >= 8 { + prefix = prefixHash[:8] + } + uuidPart := uuid + if len(uuid) >= 8 { + uuidPart = uuid[:8] + } + return geminiDigestSessionKeyPrefix + prefix + ":" + uuidPart +} diff --git a/backend/internal/service/gemini_session_integration_test.go b/backend/internal/service/gemini_session_integration_test.go new file mode 100644 index 00000000..928c62cf --- /dev/null +++ b/backend/internal/service/gemini_session_integration_test.go @@ -0,0 +1,206 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// mockGeminiSessionCache 模拟 Redis 缓存 +type mockGeminiSessionCache struct { + sessions map[string]string // key -> value +} + +func newMockGeminiSessionCache() *mockGeminiSessionCache { + return &mockGeminiSessionCache{sessions: make(map[string]string)} +} + +func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) { + key := BuildGeminiSessionKey(groupID, prefixHash, digestChain) + value := FormatGeminiSessionValue(uuid, accountID) + m.sessions[key] = value +} + +func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + prefixes := GenerateDigestChainPrefixes(digestChain) + for _, p := range prefixes { + key := BuildGeminiSessionKey(groupID, prefixHash, p) + if val, ok := m.sessions[key]; ok { + return ParseGeminiSessionValue(val) + } + } + return "", 0, false +} + +// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 +func TestGeminiSessionContinuousConversation(t *testing.T) { + cache := newMockGeminiSessionCache() + groupID := int64(1) + prefixHash := "test_prefix_hash" + sessionUUID := "session-uuid-12345" + accountID := int64(100) + + // 模拟第一轮对话 + req1 := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}}, + }, + } + chain1 := BuildGeminiDigestChain(req1) + t.Logf("Round 1 chain: %s", chain1) + + // 第一轮:没有找到会话,创建新会话 + _, _, found := cache.Find(groupID, prefixHash, chain1) + if found { + t.Error("Round 1: should not find existing session") + } + + // 保存第一轮会话 + cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) + + // 模拟第二轮对话(用户继续对话) + req2 := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}}, + }, + } + chain2 := BuildGeminiDigestChain(req2) + t.Logf("Round 2 chain: %s", chain2) + + // 第二轮:应该能找到会话(通过前缀匹配) + foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) + if !found { + t.Error("Round 2: should find session via prefix matching") + } + if foundUUID != sessionUUID { + t.Errorf("Round 2: expected UUID %s, got %s", sessionUUID, foundUUID) + } + if foundAccID != accountID { + t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) + } + + // 保存第二轮会话 + cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) + + // 模拟第三轮对话 + req3 := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "I can help with coding, writing, and more!"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Great, help me write some Go code"}}}, + }, + } + chain3 := BuildGeminiDigestChain(req3) + t.Logf("Round 3 chain: %s", chain3) + + // 第三轮:应该能找到会话(通过第二轮的前缀匹配) + foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) + if !found { + t.Error("Round 3: should find session via prefix matching") + } + if foundUUID != sessionUUID { + t.Errorf("Round 3: expected UUID %s, got %s", sessionUUID, foundUUID) + } + if foundAccID != accountID { + t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) + } + + t.Log("✓ Continuous conversation session matching works correctly!") +} + +// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 +func TestGeminiSessionDifferentConversations(t *testing.T) { + cache := newMockGeminiSessionCache() + groupID := int64(1) + prefixHash := "test_prefix_hash" + + // 第一个会话 + req1 := &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Tell me about Go programming"}}}, + }, + } + chain1 := BuildGeminiDigestChain(req1) + cache.Save(groupID, prefixHash, chain1, "session-1", 100) + + // 第二个完全不同的会话 + req2 := &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "What's the weather today?"}}}, + }, + } + chain2 := BuildGeminiDigestChain(req2) + + // 不同会话不应该匹配 + _, _, found := cache.Find(groupID, prefixHash, chain2) + if found { + t.Error("Different conversations should not match") + } + + t.Log("✓ Different conversations are correctly isolated!") +} + +// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) +func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { + cache := newMockGeminiSessionCache() + groupID := int64(1) + prefixHash := "test_prefix_hash" + + // 创建一个三轮对话 + req := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}}, + }, + } + fullChain := BuildGeminiDigestChain(req) + prefixes := GenerateDigestChainPrefixes(fullChain) + + t.Logf("Full chain: %s", fullChain) + t.Logf("Prefixes (longest first): %v", prefixes) + + // 验证前缀生成顺序(从长到短) + if len(prefixes) != 4 { + t.Errorf("Expected 4 prefixes, got %d", len(prefixes)) + } + + // 保存不同轮次的会话到不同账号 + // 第一轮(最短前缀)-> 账号 1 + cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) + // 第二轮 -> 账号 2 + cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2) + // 第三轮(最长前缀,完整链)-> 账号 3 + cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3) + + // 查找应该返回最长匹配(账号 3) + _, accID, found := cache.Find(groupID, prefixHash, fullChain) + if !found { + t.Error("Should find session") + } + if accID != 3 { + t.Errorf("Should match longest prefix (account 3), got account %d", accID) + } + + t.Log("✓ Longest prefix matching works correctly!") +} + +// 确保 context 包被使用(避免未使用的导入警告) +var _ = context.Background diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go new file mode 100644 index 00000000..8c1908f7 --- /dev/null +++ b/backend/internal/service/gemini_session_test.go @@ -0,0 +1,481 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +func TestShortHash(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"empty", []byte{}}, + {"simple", []byte("hello world")}, + {"json", []byte(`{"role":"user","parts":[{"text":"hello"}]}`)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shortHash(tt.input) + // Base36 编码的 uint64 最长 13 个字符 + if len(result) > 13 { + t.Errorf("shortHash result too long: %d characters", len(result)) + } + // 相同输入应该产生相同输出 + result2 := shortHash(tt.input) + if result != result2 { + t.Errorf("shortHash not deterministic: %s vs %s", result, result2) + } + }) + } +} + +func TestBuildGeminiDigestChain(t *testing.T) { + tests := []struct { + name string + req *antigravity.GeminiRequest + wantLen int // 预期的分段数量 + hasEmpty bool // 是否应该是空字符串 + }{ + { + name: "nil request", + req: nil, + hasEmpty: true, + }, + { + name: "empty contents", + req: &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{}, + }, + hasEmpty: true, + }, + { + name: "single user message", + req: &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + }, + }, + wantLen: 1, // u: + }, + { + name: "user and model messages", + req: &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi there"}}}, + }, + }, + wantLen: 2, // u:-m: + }, + { + name: "with system instruction", + req: &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Role: "user", + Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + }, + }, + wantLen: 2, // s:-u: + }, + { + name: "conversation with system", + req: &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Role: "user", + Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "how are you?"}}}, + }, + }, + wantLen: 4, // s:-u:-m:-u: + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildGeminiDigestChain(tt.req) + + if tt.hasEmpty { + if result != "" { + t.Errorf("expected empty string, got: %s", result) + } + return + } + + // 检查分段数量 + parts := splitChain(result) + if len(parts) != tt.wantLen { + t.Errorf("expected %d parts, got %d: %s", tt.wantLen, len(parts), result) + } + + // 验证每个分段的格式 + for _, part := range parts { + if len(part) < 3 || part[1] != ':' { + t.Errorf("invalid part format: %s", part) + } + prefix := part[0] + if prefix != 's' && prefix != 'u' && prefix != 'm' { + t.Errorf("invalid prefix: %c", prefix) + } + } + }) + } +} + +func TestGenerateGeminiPrefixHash(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro") + hash3 := GenerateGeminiPrefixHash(2, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro") + + // 相同输入应该产生相同输出 + if hash1 != hash2 { + t.Errorf("GenerateGeminiPrefixHash not deterministic: %s vs %s", hash1, hash2) + } + + // 不同输入应该产生不同输出 + if hash1 == hash3 { + t.Errorf("GenerateGeminiPrefixHash collision for different inputs") + } + + // Base64 URL 编码的 12 字节正好是 16 字符 + if len(hash1) != 16 { + t.Errorf("expected 16 characters, got %d: %s", len(hash1), hash1) + } +} + +func TestGenerateDigestChainPrefixes(t *testing.T) { + tests := []struct { + name string + chain string + want []string + wantLen int + }{ + { + name: "empty", + chain: "", + wantLen: 0, + }, + { + name: "single part", + chain: "u:abc123", + want: []string{"u:abc123"}, + wantLen: 1, + }, + { + name: "two parts", + chain: "s:xyz-u:abc", + want: []string{"s:xyz-u:abc", "s:xyz"}, + wantLen: 2, + }, + { + name: "four parts", + chain: "s:a-u:b-m:c-u:d", + want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"}, + wantLen: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GenerateDigestChainPrefixes(tt.chain) + + if len(result) != tt.wantLen { + t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result) + } + + if tt.want != nil { + for i, want := range tt.want { + if i >= len(result) { + t.Errorf("missing prefix at index %d", i) + continue + } + if result[i] != want { + t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i]) + } + } + } + }) + } +} + +func TestParseGeminiSessionValue(t *testing.T) { + tests := []struct { + name string + value string + wantUUID string + wantAccID int64 + wantOK bool + }{ + { + name: "empty", + value: "", + wantOK: false, + }, + { + name: "no colon", + value: "abc123", + wantOK: false, + }, + { + name: "valid", + value: "uuid-1234:100", + wantUUID: "uuid-1234", + wantAccID: 100, + wantOK: true, + }, + { + name: "uuid with colon", + value: "a:b:c:123", + wantUUID: "a:b:c", + wantAccID: 123, + wantOK: true, + }, + { + name: "invalid account id", + value: "uuid:abc", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uuid, accID, ok := ParseGeminiSessionValue(tt.value) + + if ok != tt.wantOK { + t.Errorf("ok: expected %v, got %v", tt.wantOK, ok) + } + + if tt.wantOK { + if uuid != tt.wantUUID { + t.Errorf("uuid: expected %s, got %s", tt.wantUUID, uuid) + } + if accID != tt.wantAccID { + t.Errorf("accountID: expected %d, got %d", tt.wantAccID, accID) + } + } + }) + } +} + +func TestFormatGeminiSessionValue(t *testing.T) { + result := FormatGeminiSessionValue("test-uuid", 123) + expected := "test-uuid:123" + if result != expected { + t.Errorf("expected %s, got %s", expected, result) + } + + // 验证往返一致性 + uuid, accID, ok := ParseGeminiSessionValue(result) + if !ok { + t.Error("ParseGeminiSessionValue failed on formatted value") + } + if uuid != "test-uuid" || accID != 123 { + t.Errorf("round-trip failed: uuid=%s, accID=%d", uuid, accID) + } +} + +// splitChain 辅助函数:按 "-" 分割摘要链 +func splitChain(chain string) []string { + if chain == "" { + return nil + } + var parts []string + start := 0 + for i := 0; i < len(chain); i++ { + if chain[i] == '-' { + parts = append(parts, chain[start:i]) + start = i + 1 + } + } + if start < len(chain) { + parts = append(parts, chain[start:]) + } + return parts +} + +func TestDigestChainDifferentSysInstruction(t *testing.T) { + req1 := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "SYS_ORIGINAL"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + }, + } + + req2 := &antigravity.GeminiRequest{ + SystemInstruction: &antigravity.GeminiContent{ + Parts: []antigravity.GeminiPart{{Text: "SYS_MODIFIED"}}, + }, + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + }, + } + + chain1 := BuildGeminiDigestChain(req1) + chain2 := BuildGeminiDigestChain(req2) + + t.Logf("Chain1: %s", chain1) + t.Logf("Chain2: %s", chain2) + + if chain1 == chain2 { + t.Error("Different systemInstruction should produce different chains") + } +} + +func TestDigestChainTamperedMiddleContent(t *testing.T) { + req1 := &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "ORIGINAL_REPLY"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}}, + }, + } + + req2 := &antigravity.GeminiRequest{ + Contents: []antigravity.GeminiContent{ + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}}, + {Role: "model", Parts: []antigravity.GeminiPart{{Text: "TAMPERED_REPLY"}}}, + {Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}}, + }, + } + + chain1 := BuildGeminiDigestChain(req1) + chain2 := BuildGeminiDigestChain(req2) + + t.Logf("Chain1: %s", chain1) + t.Logf("Chain2: %s", chain2) + + if chain1 == chain2 { + t.Error("Tampered middle content should produce different chains") + } + + // 验证第一个 user 的 hash 相同 + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + + if parts1[0] != parts2[0] { + t.Error("First user message hash should be the same") + } + if parts1[1] == parts2[1] { + t.Error("Model reply hash should be different") + } +} + +func TestGenerateGeminiDigestSessionKey(t *testing.T) { + tests := []struct { + name string + prefixHash string + uuid string + want string + }{ + { + name: "normal 16 char hash with uuid", + prefixHash: "abcdefgh12345678", + uuid: "550e8400-e29b-41d4-a716-446655440000", + want: "gemini:digest:abcdefgh:550e8400", + }, + { + name: "exactly 8 chars prefix and uuid", + prefixHash: "12345678", + uuid: "abcdefgh", + want: "gemini:digest:12345678:abcdefgh", + }, + { + name: "short hash and short uuid (less than 8)", + prefixHash: "abc", + uuid: "xyz", + want: "gemini:digest:abc:xyz", + }, + { + name: "empty hash and uuid", + prefixHash: "", + uuid: "", + want: "gemini:digest::", + }, + { + name: "normal prefix with short uuid", + prefixHash: "abcdefgh12345678", + uuid: "short", + want: "gemini:digest:abcdefgh:short", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateGeminiDigestSessionKey(tt.prefixHash, tt.uuid) + if got != tt.want { + t.Errorf("GenerateGeminiDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want) + } + }) + } + + // 验证确定性:相同输入产生相同输出 + t.Run("deterministic", func(t *testing.T) { + hash := "testprefix123456" + uuid := "test-uuid-12345" + result1 := GenerateGeminiDigestSessionKey(hash, uuid) + result2 := GenerateGeminiDigestSessionKey(hash, uuid) + if result1 != result2 { + t.Errorf("GenerateGeminiDigestSessionKey not deterministic: %s vs %s", result1, result2) + } + }) + + // 验证不同 uuid 产生不同 sessionKey(负载均衡核心逻辑) + t.Run("different uuid different key", func(t *testing.T) { + hash := "sameprefix123456" + uuid1 := "uuid0001-session-a" + uuid2 := "uuid0002-session-b" + result1 := GenerateGeminiDigestSessionKey(hash, uuid1) + result2 := GenerateGeminiDigestSessionKey(hash, uuid2) + if result1 == result2 { + t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2) + } + }) +} + +func TestBuildGeminiTrieKey(t *testing.T) { + tests := []struct { + name string + groupID int64 + prefixHash string + want string + }{ + { + name: "normal", + groupID: 123, + prefixHash: "abcdef12", + want: "gemini:trie:123:abcdef12", + }, + { + name: "zero group", + groupID: 0, + prefixHash: "xyz", + want: "gemini:trie:0:xyz", + }, + { + name: "empty prefix", + groupID: 1, + prefixHash: "", + want: "gemini:trie:1:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash) + if got != tt.want { + t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d6d1269b..1302047a 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -29,6 +29,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 + // 无效请求兜底分组(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置 // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") @@ -36,6 +38,13 @@ type Group struct { ModelRouting map[string][]int64 ModelRoutingEnabled bool + // MCP XML 协议注入开关(仅 antigravity 平台使用) + MCPXMLInject bool + + // 支持的模型系列(仅 antigravity 平台使用) + // 可选值: claude, gemini_text, gemini_image + SupportedModelScopes []string + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index a620ac4d..261da0ef 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { // RewriteUserID 重写body中的metadata.user_id // 输入格式:user_{clientId}_account__session_{sessionUUID} // 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} +// +// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, +// 避免重新序列化导致 thinking 块等内容被修改。 func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { if len(body) == 0 || accountUUID == "" || cachedClientID == "" { return body, nil } - // 解析JSON - var reqMap map[string]any + // 使用 RawMessage 保留其他字段的原始字节 + var reqMap map[string]json.RawMessage if err := json.Unmarshal(body, &reqMap); err != nil { return body, nil } - metadata, ok := reqMap["metadata"].(map[string]any) + // 解析 metadata 字段 + metadataRaw, ok := reqMap["metadata"] if !ok { return body, nil } + var metadata map[string]any + if err := json.Unmarshal(metadataRaw, &metadata); err != nil { + return body, nil + } + userID, ok := metadata["user_id"].(string) if !ok || userID == "" { return body, nil @@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) metadata["user_id"] = newUserID - reqMap["metadata"] = metadata + + // 只重新序列化 metadata 字段 + newMetadataRaw, err := json.Marshal(metadata) + if err != nil { + return body, nil + } + reqMap["metadata"] = newMetadataRaw return json.Marshal(reqMap) } @@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 // 如果账号启用了会话ID伪装(session_id_masking_enabled), // 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变) +// +// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, +// 避免重新序列化导致 thinking 块等内容被修改。 func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { // 先执行常规的 RewriteUserID 逻辑 newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) @@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 解析重写后的 body,提取 user_id - var reqMap map[string]any + // 使用 RawMessage 保留其他字段的原始字节 + var reqMap map[string]json.RawMessage if err := json.Unmarshal(newBody, &reqMap); err != nil { return newBody, nil } - metadata, ok := reqMap["metadata"].(map[string]any) + // 解析 metadata 字段 + metadataRaw, ok := reqMap["metadata"] if !ok { return newBody, nil } + var metadata map[string]any + if err := json.Unmarshal(metadataRaw, &metadata); err != nil { + return newBody, nil + } + userID, ok := metadata["user_id"].(string) if !ok || userID == "" { return newBody, nil @@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ) metadata["user_id"] = newUserID - reqMap["metadata"] = metadata + + // 只重新序列化 metadata 字段 + newMetadataRaw, marshalErr := json.Marshal(metadata) + if marshalErr != nil { + return newBody, nil + } + reqMap["metadata"] = newMetadataRaw return json.Marshal(reqMap) } diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go index 49354a7f..ff4b5977 100644 --- a/backend/internal/service/model_rate_limit.go +++ b/backend/internal/service/model_rate_limit.go @@ -1,35 +1,82 @@ package service import ( + "context" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" ) const modelRateLimitsKey = "model_rate_limits" -const modelRateLimitScopeClaudeSonnet = "claude_sonnet" -func resolveModelRateLimitScope(requestedModel string) (string, bool) { - model := strings.ToLower(strings.TrimSpace(requestedModel)) - if model == "" { - return "", false - } - model = strings.TrimPrefix(model, "models/") - if strings.Contains(model, "sonnet") { - return modelRateLimitScopeClaudeSonnet, true - } - return "", false +// isRateLimitActiveForKey 检查指定 key 的限流是否生效 +func (a *Account) isRateLimitActiveForKey(key string) bool { + resetAt := a.modelRateLimitResetAt(key) + return resetAt != nil && time.Now().Before(*resetAt) } -func (a *Account) isModelRateLimited(requestedModel string) bool { - scope, ok := resolveModelRateLimitScope(requestedModel) - if !ok { - return false - } - resetAt := a.modelRateLimitResetAt(scope) +// getRateLimitRemainingForKey 获取指定 key 的限流剩余时间,0 表示未限流或已过期 +func (a *Account) getRateLimitRemainingForKey(key string) time.Duration { + resetAt := a.modelRateLimitResetAt(key) if resetAt == nil { + return 0 + } + remaining := time.Until(*resetAt) + if remaining > 0 { + return remaining + } + return 0 +} + +func (a *Account) isModelRateLimitedWithContext(ctx context.Context, requestedModel string) bool { + if a == nil { return false } - return time.Now().Before(*resetAt) + + modelKey := a.GetMappedModel(requestedModel) + if a.Platform == PlatformAntigravity { + modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel) + } + modelKey = strings.TrimSpace(modelKey) + if modelKey == "" { + return false + } + return a.isRateLimitActiveForKey(modelKey) +} + +// GetModelRateLimitRemainingTime 获取模型限流剩余时间 +// 返回 0 表示未限流或已过期 +func (a *Account) GetModelRateLimitRemainingTime(requestedModel string) time.Duration { + return a.GetModelRateLimitRemainingTimeWithContext(context.Background(), requestedModel) +} + +func (a *Account) GetModelRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration { + if a == nil { + return 0 + } + + modelKey := a.GetMappedModel(requestedModel) + if a.Platform == PlatformAntigravity { + modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel) + } + modelKey = strings.TrimSpace(modelKey) + if modelKey == "" { + return 0 + } + return a.getRateLimitRemainingForKey(modelKey) +} + +func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requestedModel string) string { + modelKey := mapAntigravityModel(account, requestedModel) + if modelKey == "" { + return "" + } + // thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking) + if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + modelKey = applyThinkingModelSuffix(modelKey, enabled) + } + return modelKey } func (a *Account) modelRateLimitResetAt(scope string) *time.Time { diff --git a/backend/internal/service/model_rate_limit_test.go b/backend/internal/service/model_rate_limit_test.go new file mode 100644 index 00000000..a51e6909 --- /dev/null +++ b/backend/internal/service/model_rate_limit_test.go @@ -0,0 +1,537 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +func TestIsModelRateLimited(t *testing.T) { + now := time.Now() + future := now.Add(10 * time.Minute).Format(time.RFC3339) + past := now.Add(-10 * time.Minute).Format(time.RFC3339) + + tests := []struct { + name string + account *Account + requestedModel string + expected bool + }{ + { + name: "official model ID hit - claude-sonnet-4-5", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: true, + }, + { + name: "official model ID hit via mapping - request claude-3-5-sonnet, mapped to claude-sonnet-4-5", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-sonnet-4-5", + }, + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "claude-3-5-sonnet", + expected: true, + }, + { + name: "no rate limit - expired", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": past, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: false, + }, + { + name: "no rate limit - no matching key", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-3-flash": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + expected: false, + }, + { + name: "no rate limit - unsupported model", + account: &Account{}, + requestedModel: "gpt-4", + expected: false, + }, + { + name: "no rate limit - empty model", + account: &Account{}, + requestedModel: "", + expected: false, + }, + { + name: "gemini model hit", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-3-pro-high": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "gemini-3-pro-high", + expected: true, + }, + { + name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-3-pro-high": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "gemini-3-pro-preview", + expected: true, + }, + { + name: "non-antigravity platform - gemini-3-pro-preview NOT mapped", + account: &Account{ + Platform: PlatformGemini, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gemini-3-pro-high": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "gemini-3-pro-preview", + expected: false, // gemini 平台不走 antigravity 映射 + }, + { + name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-opus-4-6-thinking": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "claude-opus-4-5-thinking", + expected: true, + }, + { + name: "no scope fallback - claude_sonnet should not match", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude_sonnet": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + }, + requestedModel: "claude-3-5-sonnet-20241022", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.isModelRateLimitedWithContext(context.Background(), tt.requestedModel) + if result != tt.expected { + t.Errorf("isModelRateLimited(%q) = %v, want %v", tt.requestedModel, result, tt.expected) + } + }) + } +} + +func TestIsModelRateLimited_Antigravity_ThinkingAffectsModelKey(t *testing.T) { + now := time.Now() + future := now.Add(10 * time.Minute).Format(time.RFC3339) + + account := &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5-thinking": map[string]any{ + "rate_limit_reset_at": future, + }, + }, + }, + } + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + if !account.isModelRateLimitedWithContext(ctx, "claude-sonnet-4-5") { + t.Errorf("expected model to be rate limited") + } +} + +func TestGetModelRateLimitRemainingTime(t *testing.T) { + now := time.Now() + future10m := now.Add(10 * time.Minute).Format(time.RFC3339) + future5m := now.Add(5 * time.Minute).Format(time.RFC3339) + past := now.Add(-10 * time.Minute).Format(time.RFC3339) + + tests := []struct { + name string + account *Account + requestedModel string + minExpected time.Duration + maxExpected time.Duration + }{ + { + name: "nil account", + account: nil, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "model rate limited - direct hit", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future10m, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 9 * time.Minute, + maxExpected: 11 * time.Minute, + }, + { + name: "model rate limited - via mapping", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-sonnet-4-5", + }, + }, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future5m, + }, + }, + }, + }, + requestedModel: "claude-3-5-sonnet", + minExpected: 4 * time.Minute, + maxExpected: 6 * time.Minute, + }, + { + name: "expired rate limit", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": past, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "no rate limit data", + account: &Account{}, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "no scope fallback", + account: &Account{ + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude_sonnet": map[string]any{ + "rate_limit_reset_at": future5m, + }, + }, + }, + }, + requestedModel: "claude-3-5-sonnet-20241022", + minExpected: 0, + maxExpected: 0, + }, + { + name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-opus-4-6-thinking": map[string]any{ + "rate_limit_reset_at": future5m, + }, + }, + }, + }, + requestedModel: "claude-opus-4-5-thinking", + minExpected: 4 * time.Minute, + maxExpected: 6 * time.Minute, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetModelRateLimitRemainingTimeWithContext(context.Background(), tt.requestedModel) + if result < tt.minExpected || result > tt.maxExpected { + t.Errorf("GetModelRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) + } + }) + } +} + +func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) { + now := time.Now() + future10m := now.Add(10 * time.Minute).Format(time.RFC3339) + past := now.Add(-10 * time.Minute).Format(time.RFC3339) + + tests := []struct { + name string + account *Account + requestedModel string + minExpected time.Duration + maxExpected time.Duration + }{ + { + name: "nil account", + account: nil, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "non-antigravity platform", + account: &Account{ + Platform: PlatformAnthropic, + Extra: map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future10m, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "claude scope rate limited", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future10m, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 9 * time.Minute, + maxExpected: 11 * time.Minute, + }, + { + name: "gemini_text scope rate limited", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "gemini_text": map[string]any{ + "rate_limit_reset_at": future10m, + }, + }, + }, + }, + requestedModel: "gemini-3-flash", + minExpected: 9 * time.Minute, + maxExpected: 11 * time.Minute, + }, + { + name: "expired scope rate limit", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": past, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "unsupported model", + account: &Account{ + Platform: PlatformAntigravity, + }, + requestedModel: "gpt-4", + minExpected: 0, + maxExpected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel) + if result < tt.minExpected || result > tt.maxExpected { + t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) + } + }) + } +} + +func TestGetRateLimitRemainingTime(t *testing.T) { + now := time.Now() + future15m := now.Add(15 * time.Minute).Format(time.RFC3339) + future5m := now.Add(5 * time.Minute).Format(time.RFC3339) + + tests := []struct { + name string + account *Account + requestedModel string + minExpected time.Duration + maxExpected time.Duration + }{ + { + name: "nil account", + account: nil, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + { + name: "model remaining > scope remaining - returns model", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future15m, // 15 分钟 + }, + }, + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future5m, // 5 分钟 + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 + maxExpected: 16 * time.Minute, + }, + { + name: "scope remaining > model remaining - returns scope", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future5m, // 5 分钟 + }, + }, + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future15m, // 15 分钟 + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 + maxExpected: 16 * time.Minute, + }, + { + name: "only model rate limited", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": future5m, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 4 * time.Minute, + maxExpected: 6 * time.Minute, + }, + { + name: "only scope rate limited", + account: &Account{ + Platform: PlatformAntigravity, + Extra: map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future5m, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 4 * time.Minute, + maxExpected: 6 * time.Minute, + }, + { + name: "neither rate limited", + account: &Account{ + Platform: PlatformAntigravity, + }, + requestedModel: "claude-sonnet-4-5", + minExpected: 0, + maxExpected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetRateLimitRemainingTimeWithContext(context.Background(), tt.requestedModel) + if result < tt.minExpected || result > tt.maxExpected { + t.Errorf("GetRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) + } + }) + } +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 48c72593..cea81693 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -21,6 +21,17 @@ const ( var codexCLIInstructions string var codexModelMap = map[string]string{ + "gpt-5.3": "gpt-5.3", + "gpt-5.3-none": "gpt-5.3", + "gpt-5.3-low": "gpt-5.3", + "gpt-5.3-medium": "gpt-5.3", + "gpt-5.3-high": "gpt-5.3", + "gpt-5.3-xhigh": "gpt-5.3", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-low": "gpt-5.3-codex", + "gpt-5.3-codex-medium": "gpt-5.3-codex", + "gpt-5.3-codex-high": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.1-codex": "gpt-5.1-codex", "gpt-5.1-codex-low": "gpt-5.1-codex", "gpt-5.1-codex-medium": "gpt-5.1-codex", @@ -72,7 +83,7 @@ type opencodeCacheMetadata struct { LastChecked int64 `json:"lastChecked"` } -func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { +func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 needsToolContinuation := NeedsToolContinuation(reqBody) @@ -118,22 +129,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.PromptCacheKey = strings.TrimSpace(v) } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) - existingInstructions, _ := reqBody["instructions"].(string) - existingInstructions = strings.TrimSpace(existingInstructions) - - if instructions != "" { - if existingInstructions != instructions { - reqBody["instructions"] = instructions - result.Modified = true - } - } else if existingInstructions == "" { - // 未获取到 opencode 指令时,回退使用 Codex CLI 指令。 - codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) - if codexInstructions != "" { - reqBody["instructions"] = codexInstructions - result.Modified = true - } + // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法 + if applyInstructions(reqBody, isCodexCLI) { + result.Modified = true } // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 @@ -169,6 +167,12 @@ func normalizeCodexModel(model string) string { if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { return "gpt-5.2" } + if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { + return "gpt-5.3-codex" + } + if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { + return "gpt-5.3" + } if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { return "gpt-5.1-codex-max" } @@ -276,40 +280,48 @@ func GetCodexCLIInstructions() string { return getCodexCLIInstructions() } -// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。 -func ReplaceWithCodexInstructions(reqBody map[string]any) bool { - codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) - if codexInstructions == "" { - return false +// applyInstructions 处理 instructions 字段 +// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) +// isCodexCLI=false: 优先使用 opencode 指令覆盖 +func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { + if isCodexCLI { + return applyCodexCLIInstructions(reqBody) + } + return applyOpenCodeInstructions(reqBody) +} + +// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions +// 仅在 instructions 为空时添加 opencode 指令 +func applyCodexCLIInstructions(reqBody map[string]any) bool { + if !isInstructionsEmpty(reqBody) { + return false // 已有有效 instructions,不修改 } - existingInstructions, _ := reqBody["instructions"].(string) - if strings.TrimSpace(existingInstructions) != codexInstructions { - reqBody["instructions"] = codexInstructions + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + if instructions != "" { + reqBody["instructions"] = instructions return true } return false } -// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。 -func IsInstructionError(errorMessage string) bool { - if errorMessage == "" { - return false - } +// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 +// 优先使用 opencode 指令覆盖 +func applyOpenCodeInstructions(reqBody map[string]any) bool { + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + existingInstructions, _ := reqBody["instructions"].(string) + existingInstructions = strings.TrimSpace(existingInstructions) - lowerMsg := strings.ToLower(errorMessage) - instructionKeywords := []string{ - "instruction", - "instructions", - "system prompt", - "system message", - "invalid prompt", - "prompt format", - } - - for _, keyword := range instructionKeywords { - if strings.Contains(lowerMsg, keyword) { + if instructions != "" { + if existingInstructions != instructions { + reqBody["instructions"] = instructions + return true + } + } else if existingInstructions == "" { + codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) + if codexInstructions != "" { + reqBody["instructions"] = codexInstructions return true } } @@ -317,6 +329,23 @@ func IsInstructionError(errorMessage string) bool { return false } +// isInstructionsEmpty 检查 instructions 字段是否为空 +// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串 +func isInstructionsEmpty(reqBody map[string]any) bool { + val, exists := reqBody["instructions"] + if !exists { + return true + } + if val == nil { + return true + } + str, ok := val.(string) + if !ok { + return true + } + return strings.TrimSpace(str) == "" +} + // filterCodexInput 按需过滤 item_reference 与 id。 // preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。 func filterCodexInput(input []any, preserveReferences bool) []any { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 4cd72ab6..cc0acafc 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -23,7 +23,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) // 未显式设置 store=true,默认为 false。 store, ok := reqBody["store"].(bool) @@ -59,7 +59,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -79,7 +79,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -97,7 +97,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs( }, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -148,7 +148,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction }, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) tools, ok := reqBody["tools"].([]any) require.True(t, ok) @@ -169,19 +169,88 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { "input": []any{}, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) input, ok := reqBody["input"].([]any) require.True(t, ok) require.Len(t, input, 0) } +func TestNormalizeCodexModel_Gpt53(t *testing.T) { + cases := map[string]string{ + "gpt-5.3": "gpt-5.3", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt 5.3 codex": "gpt-5.3-codex", + } + + for input, expected := range cases { + require.Equal(t, expected, normalizeCodexModel(input)) + } + +} + +func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { + // Codex CLI 场景:已有 instructions 时保持不变 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "instructions": "user custom instructions", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, true) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Equal(t, "user custom instructions", instructions) + // instructions 未变,但其他字段(如 store、stream)可能被修改 + require.True(t, result.Modified) +} + +func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) { + // Codex CLI 场景:无 instructions 时补充内置指令 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, true) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotEmpty(t, instructions) + require.True(t, result.Modified) +} + +func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) { + // 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header) + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 + require.True(t, result.Modified) +} + func setupCodexCache(t *testing.T) { t.Helper() // 使用临时 HOME 避免触发网络拉取 header。 + // Windows 使用 USERPROFILE,Unix 使用 HOME。 tempDir := t.TempDir() t.Setenv("HOME", tempDir) + t.Setenv("USERPROFILE", tempDir) cacheDir := filepath.Join(tempDir, ".opencode", "cache") require.NoError(t, os.MkdirAll(cacheDir, 0o755)) @@ -196,3 +265,59 @@ func setupCodexCache(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) } + +func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { + // Codex CLI 场景:无 instructions 时补充默认值 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + // 没有 instructions 字段 + } + + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotEmpty(t, instructions) + require.True(t, result.Modified) +} + +func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { + // 非 Codex CLI 场景:使用 opencode 指令覆盖 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "instructions": "old instructions", + } + + result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotEqual(t, "old instructions", instructions) + require.True(t, result.Modified) +} + +func TestIsInstructionsEmpty(t *testing.T) { + tests := []struct { + name string + reqBody map[string]any + expected bool + }{ + {"missing field", map[string]any{}, true}, + {"nil value", map[string]any{"instructions": nil}, true}, + {"empty string", map[string]any{"instructions": ""}, true}, + {"whitespace only", map[string]any{"instructions": " "}, true}, + {"non-string", map[string]any{"instructions": 123}, true}, + {"valid string", map[string]any{"instructions": "hello"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isInstructionsEmpty(tt.reqBody) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6d93e92d..fbe81cb4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -332,7 +332,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared - if shouldClearStickySession(account) { + if shouldClearStickySession(account, requestedModel) { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) return nil } @@ -498,7 +498,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { - clearSticky := shouldClearStickySession(account) + clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) } @@ -796,8 +796,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - if account.Type == AccountTypeOAuth && !isCodexCLI { - codexResult := applyCodexOAuthTransform(reqBody) + if account.Type == AccountTypeOAuth { + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true } @@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - // Remove prompt_cache_retention (not supported by upstream OpenAI API) - if _, has := reqBody["prompt_cache_retention"]; has { - delete(reqBody, "prompt_cache_retention") - bodyModified = true + // Remove unsupported fields (not supported by upstream OpenAI API) + for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + if _, has := reqBody[unsupportedField]; has { + delete(reqBody, unsupportedField) + bodyModified = true + } } } @@ -938,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1085,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformOpenAI, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + // Check custom error codes if !account.ShouldHandleErrorCode(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1129,7 +1155,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht Detail: upstreamDetail, }) if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // Return appropriate error response @@ -1681,13 +1707,14 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -1799,6 +1826,13 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } } + // Update API key quota if applicable (only for balance mode with quota set) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ae69a986..1c2c81ca 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -204,6 +204,22 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i return nil } +func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { + return 0, nil +} + +func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { + return nil, nil +} + +func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/ops_account_availability.go b/backend/internal/service/ops_account_availability.go index 9be06c15..a649e7b5 100644 --- a/backend/internal/service/ops_account_availability.go +++ b/backend/internal/service/ops_account_availability.go @@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi } isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched - scopeRateLimits := acc.GetAntigravityScopeRateLimits() if acc.Platform != "" { diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index c3b7b853..f6541d08 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -255,3 +255,142 @@ func (s *OpsService) GetConcurrencyStats( return platform, group, account, &collectedAt, nil } + +// listAllActiveUsersForOps returns all active users with their concurrency settings. +func (s *OpsService) listAllActiveUsersForOps(ctx context.Context) ([]User, error) { + if s == nil || s.userRepo == nil { + return []User{}, nil + } + + out := make([]User, 0, 128) + page := 1 + for { + users, pageInfo, err := s.userRepo.ListWithFilters(ctx, pagination.PaginationParams{ + Page: page, + PageSize: opsAccountsPageSize, + }, UserListFilters{ + Status: StatusActive, + }) + if err != nil { + return nil, err + } + if len(users) == 0 { + break + } + + out = append(out, users...) + if pageInfo != nil && int64(len(out)) >= pageInfo.Total { + break + } + if len(users) < opsAccountsPageSize { + break + } + + page++ + if page > 10_000 { + log.Printf("[Ops] listAllActiveUsersForOps: aborting after too many pages") + break + } + } + + return out, nil +} + +// getUsersLoadMapBestEffort returns user load info for the given users. +func (s *OpsService) getUsersLoadMapBestEffort(ctx context.Context, users []User) map[int64]*UserLoadInfo { + if s == nil || s.concurrencyService == nil { + return map[int64]*UserLoadInfo{} + } + if len(users) == 0 { + return map[int64]*UserLoadInfo{} + } + + // De-duplicate IDs (and keep the max concurrency to avoid under-reporting). + unique := make(map[int64]int, len(users)) + for _, u := range users { + if u.ID <= 0 { + continue + } + if prev, ok := unique[u.ID]; !ok || u.Concurrency > prev { + unique[u.ID] = u.Concurrency + } + } + + batch := make([]UserWithConcurrency, 0, len(unique)) + for id, maxConc := range unique { + batch = append(batch, UserWithConcurrency{ + ID: id, + MaxConcurrency: maxConc, + }) + } + + out := make(map[int64]*UserLoadInfo, len(batch)) + for i := 0; i < len(batch); i += opsConcurrencyBatchChunkSize { + end := i + opsConcurrencyBatchChunkSize + if end > len(batch) { + end = len(batch) + } + part, err := s.concurrencyService.GetUsersLoadBatch(ctx, batch[i:end]) + if err != nil { + // Best-effort: return zeros rather than failing the ops UI. + log.Printf("[Ops] GetUsersLoadBatch failed: %v", err) + continue + } + for k, v := range part { + out[k] = v + } + } + + return out +} + +// GetUserConcurrencyStats returns real-time concurrency usage for all active users. +func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*UserConcurrencyInfo, *time.Time, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, nil, err + } + + users, err := s.listAllActiveUsersForOps(ctx) + if err != nil { + return nil, nil, err + } + + collectedAt := time.Now() + loadMap := s.getUsersLoadMapBestEffort(ctx, users) + + result := make(map[int64]*UserConcurrencyInfo) + + for _, u := range users { + if u.ID <= 0 { + continue + } + + load := loadMap[u.ID] + currentInUse := int64(0) + waiting := int64(0) + if load != nil { + currentInUse = int64(load.CurrentConcurrency) + waiting = int64(load.WaitingCount) + } + + // Skip users with no concurrency activity + if currentInUse == 0 && waiting == 0 { + continue + } + + info := &UserConcurrencyInfo{ + UserID: u.ID, + UserEmail: u.Email, + Username: u.Username, + CurrentInUse: currentInUse, + MaxCapacity: int64(u.Concurrency), + WaitingInQueue: waiting, + } + if info.MaxCapacity > 0 { + info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100 + } + result[u.ID] = info + } + + return result, &collectedAt, nil +} diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index edf32cf2..30adaae0 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { return fmt.Errorf("query error counts: %w", err) } + accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd) + if err != nil { + return fmt.Errorf("query account switch counts: %w", err) + } + windowSeconds := windowEnd.Sub(windowStart).Seconds() if windowSeconds <= 0 { windowSeconds = 60 @@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { Upstream429Count: upstream429, Upstream529Count: upstream529, - TokenConsumed: tokenConsumed, - QPS: float64Ptr(roundTo1DP(qps)), - TPS: float64Ptr(roundTo1DP(tps)), + TokenConsumed: tokenConsumed, + AccountSwitchCount: accountSwitchCount, + QPS: float64Ptr(roundTo1DP(qps)), + TPS: float64Ptr(roundTo1DP(tps)), DurationP50Ms: duration.p50, DurationP90Ms: duration.p90, @@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2` return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil } +func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) { + q := ` +SELECT + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count +FROM ops_error_logs o +CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb) +) AS ev +WHERE o.created_at >= $1 AND o.created_at < $2 + AND o.is_count_tokens = FALSE` + + var count int64 + if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + type opsCollectedSystemStats struct { cpuUsagePercent *float64 memoryUsedMB *int64 diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 515b47bb..347b06b5 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct { Upstream429Count int64 Upstream529Count int64 - TokenConsumed int64 + TokenConsumed int64 + AccountSwitchCount int64 QPS *float64 TPS *float64 @@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct { DBConnIdle *int `json:"db_conn_idle"` DBConnWaiting *int `json:"db_conn_waiting"` - GoroutineCount *int `json:"goroutine_count"` - ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + GoroutineCount *int `json:"goroutine_count"` + ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + AccountSwitchCount *int64 `json:"account_switch_count"` } type OpsUpsertJobHeartbeatInput struct { diff --git a/backend/internal/service/ops_realtime_models.go b/backend/internal/service/ops_realtime_models.go index c7e5715b..33029f59 100644 --- a/backend/internal/service/ops_realtime_models.go +++ b/backend/internal/service/ops_realtime_models.go @@ -37,6 +37,17 @@ type AccountConcurrencyInfo struct { WaitingInQueue int64 `json:"waiting_in_queue"` } +// UserConcurrencyInfo represents real-time concurrency usage for a single user. +type UserConcurrencyInfo struct { + UserID int64 `json:"user_id"` + UserEmail string `json:"user_email"` + Username string `json:"username"` + CurrentInUse int64 `json:"current_in_use"` + MaxCapacity int64 `json:"max_capacity"` + LoadPercentage float64 `json:"load_percentage"` + WaitingInQueue int64 `json:"waiting_in_queue"` +} + // PlatformAvailability aggregates account availability by platform. type PlatformAvailability struct { Platform string `json:"platform"` diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 8d98e43f..fbc800f2 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq continue } + attemptCtx := ctx + if switches > 0 { + attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() - return s.executeWithAccount(ctx, reqType, errorLog, body, account) + return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account) }() if exec != nil { @@ -571,7 +576,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq action = "streamGenerateContent" } if account.Platform == PlatformAntigravity { - _, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body) + _, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body, false) } else { _, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body) } @@ -581,7 +586,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if s.antigravityGatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"} } - _, err = s.antigravityGatewayService.Forward(ctx, c, account, body) + _, err = s.antigravityGatewayService.Forward(ctx, c, account, body, false) case PlatformGemini: if s.geminiCompatService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"} diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index abb8ae12..9c121b8b 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -27,6 +27,7 @@ type OpsService struct { cfg *config.Config accountRepo AccountRepository + userRepo UserRepository // getAccountAvailability is a unit-test hook for overriding account availability lookup. getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) @@ -43,6 +44,7 @@ func NewOpsService( settingRepo SettingRepository, cfg *config.Config, accountRepo AccountRepository, + userRepo UserRepository, concurrencyService *ConcurrencyService, gatewayService *GatewayService, openAIGatewayService *OpenAIGatewayService, @@ -55,6 +57,7 @@ func NewOpsService( cfg: cfg, accountRepo: accountRepo, + userRepo: userRepo, concurrencyService: concurrencyService, gatewayService: gatewayService, @@ -424,6 +427,26 @@ func isSensitiveKey(key string) bool { return false } + // Token 计数 / 预算字段不是凭据,应保留用于排错。 + // 白名单保持尽量窄,避免误把真实敏感信息"反脱敏"。 + switch k { + case "max_tokens", + "max_output_tokens", + "max_input_tokens", + "max_completion_tokens", + "max_tokens_to_sample", + "budget_tokens", + "prompt_tokens", + "completion_tokens", + "input_tokens", + "output_tokens", + "total_tokens", + "token_count", + "cache_creation_input_tokens", + "cache_read_input_tokens": + return false + } + // Exact matches (common credential fields). switch k { case "authorization", @@ -566,7 +589,18 @@ func trimArrayField(root map[string]any, field string, maxBytes int) (map[string func shrinkToEssentials(root map[string]any) map[string]any { out := make(map[string]any) - for _, key := range []string{"model", "stream", "max_tokens", "temperature", "top_p", "top_k"} { + for _, key := range []string{ + "model", + "stream", + "max_tokens", + "max_output_tokens", + "max_input_tokens", + "max_completion_tokens", + "thinking", + "temperature", + "top_p", + "top_k", + } { if v, ok := root[key]; ok { out[key] = v } diff --git a/backend/internal/service/ops_service_redaction_test.go b/backend/internal/service/ops_service_redaction_test.go new file mode 100644 index 00000000..e0aeafa5 --- /dev/null +++ b/backend/internal/service/ops_service_redaction_test.go @@ -0,0 +1,99 @@ +package service + +import ( + "encoding/json" + "testing" +) + +func TestIsSensitiveKey_TokenBudgetKeysNotRedacted(t *testing.T) { + t.Parallel() + + for _, key := range []string{ + "max_tokens", + "max_output_tokens", + "max_input_tokens", + "max_completion_tokens", + "max_tokens_to_sample", + "budget_tokens", + "prompt_tokens", + "completion_tokens", + "input_tokens", + "output_tokens", + "total_tokens", + "token_count", + } { + if isSensitiveKey(key) { + t.Fatalf("expected key %q to NOT be treated as sensitive", key) + } + } + + for _, key := range []string{ + "authorization", + "Authorization", + "access_token", + "refresh_token", + "id_token", + "session_token", + "token", + "client_secret", + "private_key", + "signature", + } { + if !isSensitiveKey(key) { + t.Fatalf("expected key %q to be treated as sensitive", key) + } + } +} + +func TestSanitizeAndTrimRequestBody_PreservesTokenBudgetFields(t *testing.T) { + t.Parallel() + + raw := []byte(`{"model":"claude-3","max_tokens":123,"thinking":{"type":"enabled","budget_tokens":456},"access_token":"abc","messages":[{"role":"user","content":"hi"}]}`) + out, _, _ := sanitizeAndTrimRequestBody(raw, 10*1024) + if out == "" { + t.Fatalf("expected non-empty sanitized output") + } + + var decoded map[string]any + if err := json.Unmarshal([]byte(out), &decoded); err != nil { + t.Fatalf("unmarshal sanitized output: %v", err) + } + + if got, ok := decoded["max_tokens"].(float64); !ok || got != 123 { + t.Fatalf("expected max_tokens=123, got %#v", decoded["max_tokens"]) + } + + thinking, ok := decoded["thinking"].(map[string]any) + if !ok || thinking == nil { + t.Fatalf("expected thinking object to be preserved, got %#v", decoded["thinking"]) + } + if got, ok := thinking["budget_tokens"].(float64); !ok || got != 456 { + t.Fatalf("expected thinking.budget_tokens=456, got %#v", thinking["budget_tokens"]) + } + + if got := decoded["access_token"]; got != "[REDACTED]" { + t.Fatalf("expected access_token to be redacted, got %#v", got) + } +} + +func TestShrinkToEssentials_IncludesThinking(t *testing.T) { + t.Parallel() + + root := map[string]any{ + "model": "claude-3", + "max_tokens": 100, + "thinking": map[string]any{ + "type": "enabled", + "budget_tokens": 200, + }, + "messages": []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "last"}, + }, + } + + out := shrinkToEssentials(root) + if _, ok := out["thinking"]; !ok { + t.Fatalf("expected thinking to be included in essentials: %#v", out) + } +} diff --git a/backend/internal/service/ops_trend_models.go b/backend/internal/service/ops_trend_models.go index f6d07c14..97bbfebe 100644 --- a/backend/internal/service/ops_trend_models.go +++ b/backend/internal/service/ops_trend_models.go @@ -6,6 +6,7 @@ type OpsThroughputTrendPoint struct { BucketStart time.Time `json:"bucket_start"` RequestCount int64 `json:"request_count"` TokenConsumed int64 `json:"token_consumed"` + SwitchCount int64 `json:"switch_count"` QPS float64 `json:"qps"` TPS float64 `json:"tps"` } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 0ade72cd..d8db0d67 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -579,6 +579,7 @@ func (s *PricingService) extractBaseName(model string) string { func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // Claude模型系列匹配规则 familyPatterns := map[string][]string{ + "opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"}, "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, "opus-4": {"claude-opus-4", "claude-3-opus"}, "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, @@ -651,7 +652,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // 回退顺序: // 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) // 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) -// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 3. gpt-5.3-codex -> gpt-5.2-codex +// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { // 尝试的回退变体 variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) @@ -663,6 +665,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.3-codex") { + if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok { + log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex") + return pricing + } + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index a5d897f6..80045187 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -16,6 +16,7 @@ var ( type ProxyRepository interface { Create(ctx context.Context, proxy *Proxy) error GetByID(ctx context.Context, id int64) (*Proxy, error) + ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) Update(ctx context.Context, proxy *Proxy) error Delete(ctx context.Context, id int64) error diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 6b7ebb07..47286deb 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -387,14 +387,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) - if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) { - if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil { - slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err) - } else { - slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt) - } - return - } slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -407,14 +399,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head if err != nil { slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err) resetAt := time.Now().Add(5 * time.Minute) - if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) { - if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil { - slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err) - } else { - slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt) - } - return - } if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } @@ -423,15 +407,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head resetAt := time.Unix(ts, 0) - if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) { - if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil { - slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err) - return - } - slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt) - return - } - // 标记限流状态 if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -448,17 +423,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt) } -func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool { - if account == nil || account.Platform != PlatformAnthropic { - return false - } - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody))) - if msg == "" { - return false - } - return strings.Contains(msg, "sonnet") -} - // calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间 // 返回 nil 表示无法从响应头中确定重置时间 func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index adcafb3f..ad277ca0 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -49,6 +49,11 @@ type RedeemCodeRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) + // ListByUserPaginated returns paginated balance/concurrency history for a specific user. + // codeType filter is optional - pass empty string to return all types. + ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) + // SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user. + SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) } // GenerateCodesRequest 生成兑换码请求 diff --git a/backend/internal/service/refresh_token_cache.go b/backend/internal/service/refresh_token_cache.go new file mode 100644 index 00000000..91b3924f --- /dev/null +++ b/backend/internal/service/refresh_token_cache.go @@ -0,0 +1,73 @@ +package service + +import ( + "context" + "errors" + "time" +) + +// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache. +// This is used to abstract away the underlying cache implementation (e.g., redis.Nil). +var ErrRefreshTokenNotFound = errors.New("refresh token not found") + +// RefreshTokenData 存储在Redis中的Refresh Token数据 +type RefreshTokenData struct { + UserID int64 `json:"user_id"` + TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效 + FamilyID string `json:"family_id"` // Token家族ID,用于防重放攻击 + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// RefreshTokenCache 管理Refresh Token的Redis缓存 +// 用于JWT Token刷新机制,支持Token轮转和防重放攻击 +// +// Key 格式: +// - refresh_token:{token_hash} -> RefreshTokenData (JSON) +// - user_refresh_tokens:{user_id} -> Set +// - token_family:{family_id} -> Set +type RefreshTokenCache interface { + // StoreRefreshToken 存储Refresh Token + // tokenHash: Token的SHA256哈希值(不存储原始Token) + // data: Token关联的数据 + // ttl: Token过期时间 + StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error + + // GetRefreshToken 获取Refresh Token数据 + // 返回 (data, nil) 如果Token存在 + // 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在 + // 返回 (nil, err) 如果发生其他错误 + GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error) + + // DeleteRefreshToken 删除单个Refresh Token + // 用于Token轮转时使旧Token失效 + DeleteRefreshToken(ctx context.Context, tokenHash string) error + + // DeleteUserRefreshTokens 删除用户的所有Refresh Token + // 用于密码更改或用户主动登出所有设备 + DeleteUserRefreshTokens(ctx context.Context, userID int64) error + + // DeleteTokenFamily 删除整个Token家族 + // 用于检测到Token重放攻击时,撤销整个会话链 + DeleteTokenFamily(ctx context.Context, familyID string) error + + // AddToUserTokenSet 将Token添加到用户的Token集合 + // 用于跟踪用户的所有活跃Refresh Token + AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error + + // AddToFamilyTokenSet 将Token添加到家族Token集合 + // 用于跟踪同一登录会话的所有Token + AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error + + // GetUserTokenHashes 获取用户的所有Token哈希 + // 用于批量删除用户Token + GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) + + // GetFamilyTokenHashes 获取家族的所有Token哈希 + // 用于批量删除家族Token + GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) + + // IsTokenInFamily 检查Token是否属于指定家族 + // 用于验证Token家族关系 + IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) +} diff --git a/backend/internal/service/scheduler_layered_filter_test.go b/backend/internal/service/scheduler_layered_filter_test.go new file mode 100644 index 00000000..d012cf09 --- /dev/null +++ b/backend/internal/service/scheduler_layered_filter_test.go @@ -0,0 +1,264 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFilterByMinPriority(t *testing.T) { + t.Run("empty slice", func(t *testing.T) { + result := filterByMinPriority(nil) + require.Empty(t, result) + }) + + t.Run("single account", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}}, + } + result := filterByMinPriority(accounts) + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) + }) + + t.Run("multiple accounts same priority", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 3}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 3}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}}, + } + result := filterByMinPriority(accounts) + require.Len(t, result, 3) + }) + + t.Run("filters to min priority only", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 1}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 4, Priority: 1}, loadInfo: &AccountLoadInfo{}}, + } + result := filterByMinPriority(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(4), result[1].account.ID) + }) +} + +func TestFilterByMinLoadRate(t *testing.T) { + t.Run("empty slice", func(t *testing.T) { + result := filterByMinLoadRate(nil) + require.Empty(t, result) + }) + + t.Run("single account", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) + }) + + t.Run("multiple accounts same load rate", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 3) + }) + + t.Run("filters to min load rate only", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 80}}, + {account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + {account: &Account{ID: 4}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(4), result[1].account.ID) + }) + + t.Run("zero load rate", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + {account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(1), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) + }) +} + +func TestSelectByLRU(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + muchEarlier := now.Add(-2 * time.Hour) + + t.Run("empty slice", func(t *testing.T) { + result := selectByLRU(nil, false) + require.Nil(t, result) + }) + + t.Run("single account", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}}, + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) + }) + + t.Run("selects least recently used", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}}, + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID) + }) + + t.Run("nil LastUsedAt preferred over non-nil", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}}, + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID) + }) + + t.Run("multiple nil LastUsedAt random selection", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + } + // 多次调用应该随机选择,验证结果都在候选范围内 + validIDs := map[int64]bool{1: true, 2: true, 3: true} + for i := 0; i < 10; i++ { + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates") + } + }) + + t.Run("multiple same LastUsedAt random selection", func(t *testing.T) { + sameTime := now + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}}, + } + // 多次调用应该随机选择 + validIDs := map[int64]bool{1: true, 2: true} + for i := 0; i < 10; i++ { + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates") + } + }) + + t.Run("preferOAuth selects from OAuth accounts when multiple nil", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}}, + } + // preferOAuth 时,应该从 OAuth 类型中选择 + oauthIDs := map[int64]bool{2: true, 3: true} + for i := 0; i < 10; i++ { + result := selectByLRU(accounts, true) + require.NotNil(t, result) + require.True(t, oauthIDs[result.account.ID], "should select from OAuth accounts") + } + }) + + t.Run("preferOAuth falls back to all when no OAuth", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + } + // 没有 OAuth 时,从所有候选中选择 + validIDs := map[int64]bool{1: true, 2: true} + for i := 0; i < 10; i++ { + result := selectByLRU(accounts, true) + require.NotNil(t, result) + require.True(t, validIDs[result.account.ID]) + } + }) + + t.Run("preferOAuth only affects same LastUsedAt accounts", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, LastUsedAt: &earlier, Type: "session"}, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, LastUsedAt: &now, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}}, + } + result := selectByLRU(accounts, true) + require.NotNil(t, result) + // 有不同 LastUsedAt 时,按时间选择最早的,不受 preferOAuth 影响 + require.Equal(t, int64(1), result.account.ID) + }) +} + +func TestLayeredFilterIntegration(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + muchEarlier := now.Add(-2 * time.Hour) + + t.Run("full layered selection", func(t *testing.T) { + // 模拟真实场景:多个账号,不同优先级、负载率、最后使用时间 + accounts := []accountWithLoad{ + // 优先级 1,负载 50% + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + // 优先级 1,负载 20%(最低) + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + // 优先级 1,负载 20%(最低),更早使用 + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + // 优先级 2(较低优先) + {account: &Account{ID: 4, Priority: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + } + + // 1. 取优先级最小的集合 → ID: 1, 2, 3 + step1 := filterByMinPriority(accounts) + require.Len(t, step1, 3) + + // 2. 取负载率最低的集合 → ID: 2, 3 + step2 := filterByMinLoadRate(step1) + require.Len(t, step2, 2) + + // 3. LRU 选择 → ID: 3(muchEarlier 最早) + selected := selectByLRU(step2, false) + require.NotNil(t, selected) + require.Equal(t, int64(3), selected.account.ID) + }) + + t.Run("all same priority and load rate", func(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}}, + } + + step1 := filterByMinPriority(accounts) + require.Len(t, step1, 3) + + step2 := filterByMinLoadRate(step1) + require.Len(t, step2, 3) + + // LRU 选择最早的 + selected := selectByLRU(step2, false) + require.NotNil(t, selected) + require.Equal(t, int64(3), selected.account.ID) + }) +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index b3714ed1..52d455b8 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -151,6 +151,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int return s.accountRepo.GetByID(fallbackCtx, accountID) } +// UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效) +func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error { + if s.cache == nil || account == nil { + return nil + } + return s.cache.SetAccount(ctx, account) +} + func (s *SchedulerSnapshotService) runInitialRebuild() { if s.cache == nil { return diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index 4bd06b7b..c70f12fe 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -23,32 +23,90 @@ import ( // - 临时不可调度且未过期:清理 // - 临时不可调度已过期:不清理 // - 正常可调度状态:不清理 +// - 模型限流超过阈值:清理 +// - 模型限流未超过阈值:不清理 // // TestShouldClearStickySession tests the sticky session clearing logic. // Verifies correct behavior for various account states including: -// nil account, error/disabled status, unschedulable, temporary unschedulable. +// nil account, error/disabled status, unschedulable, temporary unschedulable, +// and model rate limiting scenarios. func TestShouldClearStickySession(t *testing.T) { now := time.Now() future := now.Add(1 * time.Hour) past := now.Add(-1 * time.Hour) + // 短限流时间(低于阈值,不应清除粘性会话) + shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339) + // 长限流时间(超过阈值,应清除粘性会话) + longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339) + tests := []struct { - name string - account *Account - want bool + name string + account *Account + requestedModel string + want bool }{ - {name: "nil account", account: nil, want: false}, - {name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true}, - {name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true}, - {name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true}, - {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true}, - {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false}, - {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false}, + {name: "nil account", account: nil, requestedModel: "", want: false}, + {name: "status error", account: &Account{Status: StatusError, Schedulable: true}, requestedModel: "", want: true}, + {name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, requestedModel: "", want: true}, + {name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, requestedModel: "", want: true}, + {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true}, + {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false}, + {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false}, + // 模型限流测试 + { + name: "model rate limited short duration", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude-sonnet-4": map[string]any{ + "rate_limit_reset_at": shortRateLimitReset, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4", + want: false, // 低于阈值,不清除 + }, + { + name: "model rate limited long duration", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude-sonnet-4": map[string]any{ + "rate_limit_reset_at": longRateLimitReset, + }, + }, + }, + }, + requestedModel: "claude-sonnet-4", + want: true, // 超过阈值,清除 + }, + { + name: "model rate limited different model", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude-sonnet-4": map[string]any{ + "rate_limit_reset_at": longRateLimitReset, + }, + }, + }, + }, + requestedModel: "claude-opus-4", // 请求不同模型 + want: false, // 不同模型不受影响 + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.want, shouldClearStickySession(tt.account)) + require.Equal(t, tt.want, shouldClearStickySession(tt.account, tt.requestedModel)) }) } } diff --git a/backend/internal/service/temp_unsched_test.go b/backend/internal/service/temp_unsched_test.go new file mode 100644 index 00000000..d132c2bc --- /dev/null +++ b/backend/internal/service/temp_unsched_test.go @@ -0,0 +1,378 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ============ 临时限流单元测试 ============ + +// TestMatchTempUnschedKeyword 测试关键词匹配函数 +func TestMatchTempUnschedKeyword(t *testing.T) { + tests := []struct { + name string + body string + keywords []string + want string + }{ + { + name: "match_first", + body: "server is overloaded", + keywords: []string{"overloaded", "capacity"}, + want: "overloaded", + }, + { + name: "match_second", + body: "no capacity available", + keywords: []string{"overloaded", "capacity"}, + want: "capacity", + }, + { + name: "no_match", + body: "internal error", + keywords: []string{"overloaded", "capacity"}, + want: "", + }, + { + name: "empty_body", + body: "", + keywords: []string{"overloaded"}, + want: "", + }, + { + name: "empty_keywords", + body: "server is overloaded", + keywords: []string{}, + want: "", + }, + { + name: "whitespace_keyword", + body: "server is overloaded", + keywords: []string{" ", "overloaded"}, + want: "overloaded", + }, + { + // matchTempUnschedKeyword 期望 body 已经是小写的 + // 所以要测试大小写不敏感匹配,需要传入小写的 body + name: "case_insensitive_body_lowered", + body: "server is overloaded", // body 已经是小写 + keywords: []string{"OVERLOADED"}, // keyword 会被转为小写比较 + want: "OVERLOADED", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchTempUnschedKeyword(tt.body, tt.keywords) + require.Equal(t, tt.want, got) + }) + } +} + +// TestAccountIsSchedulable_TempUnschedulable 测试临时限流账号不可调度 +func TestAccountIsSchedulable_TempUnschedulable(t *testing.T) { + future := time.Now().Add(10 * time.Minute) + past := time.Now().Add(-10 * time.Minute) + + tests := []struct { + name string + account *Account + want bool + }{ + { + name: "temp_unschedulable_active", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &future, + }, + want: false, + }, + { + name: "temp_unschedulable_expired", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &past, + }, + want: true, + }, + { + name: "no_temp_unschedulable", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: nil, + }, + want: true, + }, + { + name: "temp_unschedulable_with_rate_limit", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &future, + RateLimitResetAt: &past, // 过期的限流不影响 + }, + want: false, // 临时限流生效 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.account.IsSchedulable() + require.Equal(t, tt.want, got) + }) + } +} + +// TestAccount_IsTempUnschedulableEnabled 测试临时限流开关 +func TestAccount_IsTempUnschedulableEnabled(t *testing.T) { + tests := []struct { + name string + account *Account + want bool + }{ + { + name: "enabled", + account: &Account{ + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + }, + }, + want: true, + }, + { + name: "disabled", + account: &Account{ + Credentials: map[string]any{ + "temp_unschedulable_enabled": false, + }, + }, + want: false, + }, + { + name: "not_set", + account: &Account{ + Credentials: map[string]any{}, + }, + want: false, + }, + { + name: "nil_credentials", + account: &Account{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.account.IsTempUnschedulableEnabled() + require.Equal(t, tt.want, got) + }) + } +} + +// TestAccount_GetTempUnschedulableRules 测试获取临时限流规则 +func TestAccount_GetTempUnschedulableRules(t *testing.T) { + tests := []struct { + name string + account *Account + wantCount int + }{ + { + name: "has_rules", + account: &Account{ + Credentials: map[string]any{ + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(5), + }, + map[string]any{ + "error_code": float64(500), + "keywords": []any{"internal"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + wantCount: 2, + }, + { + name: "empty_rules", + account: &Account{ + Credentials: map[string]any{ + "temp_unschedulable_rules": []any{}, + }, + }, + wantCount: 0, + }, + { + name: "no_rules", + account: &Account{ + Credentials: map[string]any{}, + }, + wantCount: 0, + }, + { + name: "nil_credentials", + account: &Account{}, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rules := tt.account.GetTempUnschedulableRules() + require.Len(t, rules, tt.wantCount) + }) + } +} + +// TestTempUnschedulableRule_Parse 测试规则解析 +func TestTempUnschedulableRule_Parse(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded", "capacity"}, + "duration_minutes": float64(5), + }, + }, + }, + } + + rules := account.GetTempUnschedulableRules() + require.Len(t, rules, 1) + + rule := rules[0] + require.Equal(t, 503, rule.ErrorCode) + require.Equal(t, []string{"overloaded", "capacity"}, rule.Keywords) + require.Equal(t, 5, rule.DurationMinutes) +} + +// TestTruncateTempUnschedMessage 测试消息截断 +func TestTruncateTempUnschedMessage(t *testing.T) { + tests := []struct { + name string + body []byte + maxBytes int + want string + }{ + { + name: "short_message", + body: []byte("short"), + maxBytes: 100, + want: "short", + }, + { + // 截断后会 TrimSpace,所以末尾的空格会被移除 + name: "truncate_long_message", + body: []byte("this is a very long message that needs to be truncated"), + maxBytes: 20, + want: "this is a very long", // 截断后 TrimSpace + }, + { + name: "empty_body", + body: []byte{}, + maxBytes: 100, + want: "", + }, + { + name: "zero_max_bytes", + body: []byte("test"), + maxBytes: 0, + want: "", + }, + { + name: "whitespace_trimmed", + body: []byte(" test "), + maxBytes: 100, + want: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateTempUnschedMessage(tt.body, tt.maxBytes) + require.Equal(t, tt.want, got) + }) + } +} + +// TestTempUnschedState 测试临时限流状态结构 +func TestTempUnschedState(t *testing.T) { + now := time.Now() + until := now.Add(5 * time.Minute) + + state := &TempUnschedState{ + UntilUnix: until.Unix(), + TriggeredAtUnix: now.Unix(), + StatusCode: 503, + MatchedKeyword: "overloaded", + RuleIndex: 0, + ErrorMessage: "Server is overloaded", + } + + require.Equal(t, 503, state.StatusCode) + require.Equal(t, "overloaded", state.MatchedKeyword) + require.Equal(t, 0, state.RuleIndex) + + // 验证时间戳 + require.Equal(t, until.Unix(), state.UntilUnix) + require.Equal(t, now.Unix(), state.TriggeredAtUnix) +} + +// TestAccount_TempUnschedulableUntil 测试临时限流时间字段 +func TestAccount_TempUnschedulableUntil(t *testing.T) { + future := time.Now().Add(10 * time.Minute) + past := time.Now().Add(-10 * time.Minute) + + tests := []struct { + name string + account *Account + schedulable bool + }{ + { + name: "active_temp_unsched_not_schedulable", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &future, + }, + schedulable: false, + }, + { + name: "expired_temp_unsched_is_schedulable", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &past, + }, + schedulable: true, + }, + { + name: "nil_temp_unsched_is_schedulable", + account: &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: nil, + }, + schedulable: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.account.IsSchedulable() + require.Equal(t, tt.schedulable, got) + }) + } +} diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index aa0a5b87..5594e53f 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64) return stats, nil } +// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key. +func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID) + if err != nil { + return nil, fmt.Errorf("get api key dashboard stats: %w", err) + } + return stats, nil +} + // GetUserUsageTrendByUserID returns per-user usage trend. func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity) diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 0f589eb3..e56d83bf 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -21,6 +21,10 @@ type User struct { CreatedAt time.Time UpdatedAt time.Time + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 + // TOTP 双因素认证字段 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpEnabled bool // 是否启用 TOTP @@ -40,18 +44,20 @@ func (u *User) IsActive() bool { // CanBindGroup checks whether a user can bind to a given group. // For standard groups: -// - If AllowedGroups is non-empty, only allow binding to IDs in that list. -// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. +// - Public groups (non-exclusive): all users can bind +// - Exclusive groups: only users with the group in AllowedGroups can bind func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { - if len(u.AllowedGroups) > 0 { - for _, id := range u.AllowedGroups { - if id == groupID { - return true - } - } - return false + // 公开分组(非专属):所有用户都可以绑定 + if !isExclusive { + return true } - return !isExclusive + // 专属分组:需要在 AllowedGroups 中 + for _, id := range u.AllowedGroups { + if id == groupID { + return true + } + } + return false } func (u *User) SetPassword(password string) error { diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go new file mode 100644 index 00000000..9eb5f067 --- /dev/null +++ b/backend/internal/service/user_group_rate.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserGroupRateRepository 用户专属分组倍率仓储接口 +// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +type UserGroupRateRepository interface { + // GetByUserID 获取用户的所有专属分组倍率 + // 返回 map[groupID]rateMultiplier + GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) + + // GetByUserAndGroup 获取用户在特定分组的专属倍率 + // 如果未设置专属倍率,返回 nil + GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) + + // SyncUserGroupRates 同步用户的分组专属倍率 + // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error + + // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + DeleteByGroupID(ctx context.Context, groupID int64) error + + // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + DeleteByUserID(ctx context.Context, userID int64) error +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 99bf7fd0..1bfb392e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -39,7 +39,7 @@ type UserRepository interface { ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) - // TOTP 相关方法 + // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error EnableTotp(ctx context.Context, userID int64) error DisableTotp(ctx context.Context, userID int64) error diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 4b721bb6..05371022 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet( NewUserAttributeService, NewUsageCache, NewTotpService, + NewErrorPassthroughService, ) diff --git a/backend/migrations/042b_add_ops_system_metrics_switch_count.sql b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql new file mode 100644 index 00000000..6d9f48e5 --- /dev/null +++ b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql @@ -0,0 +1,3 @@ +-- ops_system_metrics 增加账号切换次数统计(按分钟窗口) +ALTER TABLE ops_system_metrics + ADD COLUMN IF NOT EXISTS account_switch_count BIGINT NOT NULL DEFAULT 0; diff --git a/backend/migrations/043b_add_group_invalid_request_fallback.sql b/backend/migrations/043b_add_group_invalid_request_fallback.sql new file mode 100644 index 00000000..1c792704 --- /dev/null +++ b/backend/migrations/043b_add_group_invalid_request_fallback.sql @@ -0,0 +1,13 @@ +-- 043_add_group_invalid_request_fallback.sql +-- 添加无效请求兜底分组配置 + +-- 添加 fallback_group_id_on_invalid_request 字段:无效请求兜底使用的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id_on_invalid_request BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id_on_invalid_request +ON groups(fallback_group_id_on_invalid_request) WHERE deleted_at IS NULL AND fallback_group_id_on_invalid_request IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.fallback_group_id_on_invalid_request IS '无效请求兜底使用的分组 ID'; diff --git a/backend/migrations/044b_add_group_mcp_xml_inject.sql b/backend/migrations/044b_add_group_mcp_xml_inject.sql new file mode 100644 index 00000000..7db71dd8 --- /dev/null +++ b/backend/migrations/044b_add_group_mcp_xml_inject.sql @@ -0,0 +1,2 @@ +-- Add mcp_xml_inject field to groups table (for antigravity platform) +ALTER TABLE groups ADD COLUMN mcp_xml_inject BOOLEAN NOT NULL DEFAULT true; diff --git a/backend/migrations/045_add_api_key_quota.sql b/backend/migrations/045_add_api_key_quota.sql new file mode 100644 index 00000000..b3c42d2c --- /dev/null +++ b/backend/migrations/045_add_api_key_quota.sql @@ -0,0 +1,20 @@ +-- Migration: Add quota fields to api_keys table +-- This migration adds independent quota and expiration support for API keys + +-- Add quota limit field (0 = unlimited) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add used quota amount field +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota_used DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add expiration time field (NULL = never expires) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ; + +-- Add indexes for efficient quota queries +CREATE INDEX IF NOT EXISTS idx_api_keys_quota_quota_used ON api_keys(quota, quota_used) WHERE deleted_at IS NULL; +CREATE INDEX IF NOT EXISTS idx_api_keys_expires_at ON api_keys(expires_at) WHERE deleted_at IS NULL; + +-- Comment on columns for documentation +COMMENT ON COLUMN api_keys.quota IS 'Quota limit in USD for this API key (0 = unlimited)'; +COMMENT ON COLUMN api_keys.quota_used IS 'Used quota amount in USD'; +COMMENT ON COLUMN api_keys.expires_at IS 'Expiration time for this API key (null = never expires)'; diff --git a/backend/migrations/046b_add_group_supported_model_scopes.sql b/backend/migrations/046b_add_group_supported_model_scopes.sql new file mode 100644 index 00000000..0b2b3968 --- /dev/null +++ b/backend/migrations/046b_add_group_supported_model_scopes.sql @@ -0,0 +1,6 @@ +-- 添加分组支持的模型系列字段 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS supported_model_scopes JSONB NOT NULL +DEFAULT '["claude", "gemini_text", "gemini_image"]'::jsonb; + +COMMENT ON COLUMN groups.supported_model_scopes IS '支持的模型系列:claude, gemini_text, gemini_image'; diff --git a/backend/migrations/047_add_user_group_rate_multipliers.sql b/backend/migrations/047_add_user_group_rate_multipliers.sql new file mode 100644 index 00000000..a37d5bcd --- /dev/null +++ b/backend/migrations/047_add_user_group_rate_multipliers.sql @@ -0,0 +1,19 @@ +-- 用户专属分组倍率表 +-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +CREATE TABLE IF NOT EXISTS user_group_rate_multipliers ( + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + rate_multiplier DECIMAL(10,4) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (user_id, group_id) +); + +-- 按 group_id 查询索引(删除分组时清理关联记录) +CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id + ON user_group_rate_multipliers(group_id); + +COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置'; +COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID'; +COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID'; +COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)'; diff --git a/backend/migrations/048_add_error_passthrough_rules.sql b/backend/migrations/048_add_error_passthrough_rules.sql new file mode 100644 index 00000000..bf2a9117 --- /dev/null +++ b/backend/migrations/048_add_error_passthrough_rules.sql @@ -0,0 +1,24 @@ +-- Error Passthrough Rules table +-- Allows administrators to configure how upstream errors are passed through to clients + +CREATE TABLE IF NOT EXISTS error_passthrough_rules ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT true, + priority INTEGER NOT NULL DEFAULT 0, + error_codes JSONB DEFAULT '[]', + keywords JSONB DEFAULT '[]', + match_mode VARCHAR(10) NOT NULL DEFAULT 'any', + platforms JSONB DEFAULT '[]', + passthrough_code BOOLEAN NOT NULL DEFAULT true, + response_code INTEGER, + passthrough_body BOOLEAN NOT NULL DEFAULT true, + custom_message TEXT, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Indexes for efficient queries +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_enabled ON error_passthrough_rules (enabled); +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_priority ON error_passthrough_rules (priority); diff --git a/backend/migrations/049_unify_antigravity_model_mapping.sql b/backend/migrations/049_unify_antigravity_model_mapping.sql new file mode 100644 index 00000000..a1e2bb99 --- /dev/null +++ b/backend/migrations/049_unify_antigravity_model_mapping.sql @@ -0,0 +1,36 @@ +-- Force set default Antigravity model_mapping. +-- +-- Notes: +-- - Applies to both Antigravity OAuth and Upstream accounts. +-- - Overwrites existing credentials.model_mapping. +-- - Removes legacy credentials.model_whitelist. + +UPDATE accounts +SET credentials = (COALESCE(credentials, '{}'::jsonb) - 'model_whitelist' - 'model_mapping') || '{ + "model_mapping": { + "claude-opus-4-6": "claude-opus-4-6", + "claude-opus-4-5-thinking": "claude-opus-4-5-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-5-thinking", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + } +}'::jsonb +WHERE platform = 'antigravity' + AND deleted_at IS NULL; + diff --git a/backend/migrations/050_map_opus46_to_opus45.sql b/backend/migrations/050_map_opus46_to_opus45.sql new file mode 100644 index 00000000..db8bf8fc --- /dev/null +++ b/backend/migrations/050_map_opus46_to_opus45.sql @@ -0,0 +1,17 @@ +-- Map claude-opus-4-6 to claude-opus-4-5-thinking +-- +-- Notes: +-- - Updates existing Antigravity accounts' model_mapping +-- - Changes claude-opus-4-6 target from claude-opus-4-6 to claude-opus-4-5-thinking +-- - This is needed because previous versions didn't have this mapping + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping,claude-opus-4-6}', + '"claude-opus-4-5-thinking"'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL + AND credentials->'model_mapping'->>'claude-opus-4-6' IS NOT NULL; diff --git a/backend/migrations/051_migrate_opus45_to_opus46_thinking.sql b/backend/migrations/051_migrate_opus45_to_opus46_thinking.sql new file mode 100644 index 00000000..6cabc176 --- /dev/null +++ b/backend/migrations/051_migrate_opus45_to_opus46_thinking.sql @@ -0,0 +1,41 @@ +-- Migrate all Opus 4.5 models to Opus 4.6-thinking +-- +-- Background: +-- Antigravity now supports claude-opus-4-6-thinking and no longer supports opus-4-5 +-- +-- Strategy: +-- Directly overwrite the entire model_mapping with updated mappings +-- This ensures consistency with DefaultAntigravityModelMapping in constants.go + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index ad2861df..c5aa8870 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -1605,7 +1605,7 @@ "cache_read_input_token_cost": 1.4e-07, "input_cost_per_token": 1.38e-06, "litellm_provider": "azure", - "max_input_tokens": 272000, + "max_input_tokens": 400000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "responses", @@ -16951,6 +16951,209 @@ "supports_tool_choice": false, "supports_vision": true }, + "gpt-5.3": { + "cache_read_input_token_cost": 1.75e-07, + "cache_read_input_token_cost_priority": 3.5e-07, + "input_cost_per_token": 1.75e-06, + "input_cost_per_token_priority": 3.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.4e-05, + "output_cost_per_token_priority": 2.8e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true + }, + "gpt-5.3-2025-12-11": { + "cache_read_input_token_cost": 1.75e-07, + "cache_read_input_token_cost_priority": 3.5e-07, + "input_cost_per_token": 1.75e-06, + "input_cost_per_token_priority": 3.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.4e-05, + "output_cost_per_token_priority": 2.8e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true + }, + "gpt-5.3-chat-latest": { + "cache_read_input_token_cost": 1.75e-07, + "cache_read_input_token_cost_priority": 3.5e-07, + "input_cost_per_token": 1.75e-06, + "input_cost_per_token_priority": 3.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.4e-05, + "output_cost_per_token_priority": 2.8e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "gpt-5.3-pro": { + "input_cost_per_token": 2.1e-05, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "output_cost_per_token": 1.68e-04, + "supported_endpoints": [ + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, + "gpt-5.3-pro-2025-12-11": { + "input_cost_per_token": 2.1e-05, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "output_cost_per_token": 1.68e-04, + "supported_endpoints": [ + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, + "gpt-5.3-codex": { + "cache_read_input_token_cost": 1.75e-07, + "cache_read_input_token_cost_priority": 3.5e-07, + "input_cost_per_token": 1.75e-06, + "input_cost_per_token_priority": 3.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "output_cost_per_token": 1.4e-05, + "output_cost_per_token_priority": 2.8e-05, + "supported_endpoints": [ + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": false, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.2": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, @@ -16988,6 +17191,39 @@ "supports_service_tier": true, "supports_vision": true }, + "gpt-5.2-codex": { + "cache_read_input_token_cost": 1.75e-07, + "cache_read_input_token_cost_priority": 3.5e-07, + "input_cost_per_token": 1.75e-06, + "input_cost_per_token_priority": 3.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "output_cost_per_token": 1.4e-05, + "output_cost_per_token_priority": 2.8e-05, + "supported_endpoints": [ + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": false, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.2-2025-12-11": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/backend/tools.go b/backend/tools.go deleted file mode 100644 index f06d2c78..00000000 --- a/backend/tools.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build tools -// +build tools - -package tools - -import ( - _ "entgo.io/ent/cmd/ent" - _ "github.com/google/wire/cmd/wire" -) diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 19f77221..00000000 --- a/config.yaml +++ /dev/null @@ -1,530 +0,0 @@ -# Sub2API Configuration File -# Sub2API 配置文件 -# -# Copy this file to /etc/sub2api/config.yaml and modify as needed -# 复制此文件到 /etc/sub2api/config.yaml 并根据需要修改 -# -# Documentation / 文档: https://github.com/Wei-Shaw/sub2api - -# ============================================================================= -# Server Configuration -# 服务器配置 -# ============================================================================= -server: - # Bind address (0.0.0.0 for all interfaces) - # 绑定地址(0.0.0.0 表示监听所有网络接口) - host: "0.0.0.0" - # Port to listen on - # 监听端口 - port: 8080 - # Mode: "debug" for development, "release" for production - # 运行模式:"debug" 用于开发,"release" 用于生产环境 - mode: "release" - # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. - # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 - trusted_proxies: [] - -# ============================================================================= -# Run Mode Configuration -# 运行模式配置 -# ============================================================================= -# Run mode: "standard" (default) or "simple" (for internal use) -# 运行模式:"standard"(默认)或 "simple"(内部使用) -# - standard: Full SaaS features with billing/balance checks -# - standard: 完整 SaaS 功能,包含计费和余额校验 -# - simple: Hides SaaS features and skips billing/balance checks -# - simple: 隐藏 SaaS 功能,跳过计费和余额校验 -run_mode: "standard" - -# ============================================================================= -# CORS Configuration -# 跨域资源共享 (CORS) 配置 -# ============================================================================= -cors: - # Allowed origins list. Leave empty to disable cross-origin requests. - # 允许的来源列表。留空则禁用跨域请求。 - allowed_origins: [] - # Allow credentials (cookies/authorization headers). Cannot be used with "*". - # 允许携带凭证(cookies/授权头)。不能与 "*" 通配符同时使用。 - allow_credentials: true - -# ============================================================================= -# Security Configuration -# 安全配置 -# ============================================================================= -security: - url_allowlist: - # Enable URL allowlist validation (disable to skip all URL checks) - # 启用 URL 白名单验证(禁用则跳过所有 URL 检查) - enabled: false - # Allowed upstream hosts for API proxying - # 允许代理的上游 API 主机列表 - upstream_hosts: - - "api.openai.com" - - "api.anthropic.com" - - "api.kimi.com" - - "open.bigmodel.cn" - - "api.minimaxi.com" - - "generativelanguage.googleapis.com" - - "cloudcode-pa.googleapis.com" - - "*.openai.azure.com" - # Allowed hosts for pricing data download - # 允许下载定价数据的主机列表 - pricing_hosts: - - "raw.githubusercontent.com" - # Allowed hosts for CRS sync (required when using CRS sync) - # 允许 CRS 同步的主机列表(使用 CRS 同步功能时必须配置) - crs_hosts: [] - # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks) - # 允许本地/私有 IP 地址用于上游/定价/CRS(仅在可信网络中使用) - allow_private_hosts: true - # Allow http:// URLs when allowlist is disabled (default: false, require https) - # 白名单禁用时是否允许 http:// URL(默认: false,要求 https) - allow_insecure_http: true - response_headers: - # Enable configurable response header filtering (disable to use default allowlist) - # 启用可配置的响应头过滤(禁用则使用默认白名单) - enabled: false - # Extra allowed response headers from upstream - # 额外允许的上游响应头 - additional_allowed: [] - # Force-remove response headers from upstream - # 强制移除的上游响应头 - force_remove: [] - csp: - # Enable Content-Security-Policy header - # 启用内容安全策略 (CSP) 响应头 - enabled: true - # Default CSP policy (override if you host assets on other domains) - # 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖) - policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" - proxy_probe: - # Allow skipping TLS verification for proxy probe (debug only) - # 允许代理探测时跳过 TLS 证书验证(仅用于调试) - insecure_skip_verify: false - -# ============================================================================= -# Gateway Configuration -# 网关配置 -# ============================================================================= -gateway: - # Timeout for waiting upstream response headers (seconds) - # 等待上游响应头超时时间(秒) - response_header_timeout: 600 - # Max request body size in bytes (default: 100MB) - # 请求体最大字节数(默认 100MB) - max_body_size: 104857600 - # Connection pool isolation strategy: - # 连接池隔离策略: - # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) - # - proxy: 按代理隔离,同一代理共享连接池(适合代理少、账户多) - # - account: Isolate by account, same account shares connection pool (suitable for few accounts, strict isolation) - # - account: 按账户隔离,同一账户共享连接池(适合账户少、需严格隔离) - # - account_proxy: Isolate by account+proxy combination (default, finest granularity) - # - account_proxy: 按账户+代理组合隔离(默认,最细粒度) - connection_pool_isolation: "account_proxy" - # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) - # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) - # Max idle connections across all hosts - # 所有主机的最大空闲连接数 - max_idle_conns: 240 - # Max idle connections per host - # 每个主机的最大空闲连接数 - max_idle_conns_per_host: 120 - # Max connections per host - # 每个主机的最大连接数 - max_conns_per_host: 240 - # Idle connection timeout (seconds) - # 空闲连接超时时间(秒) - idle_conn_timeout_seconds: 90 - # Upstream client cache settings - # 上游连接池客户端缓存配置 - # max_upstream_clients: Max cached clients, evicts least recently used when exceeded - # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 - max_upstream_clients: 5000 - # client_idle_ttl_seconds: Client idle reclaim threshold (seconds), reclaimed when idle and no active requests - # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 - client_idle_ttl_seconds: 900 - # Concurrency slot expiration time (minutes) - # 并发槽位过期时间(分钟) - concurrency_slot_ttl_minutes: 30 - # Stream data interval timeout (seconds), 0=disable - # 流数据间隔超时(秒),0=禁用 - stream_data_interval_timeout: 180 - # Stream keepalive interval (seconds), 0=disable - # 流式 keepalive 间隔(秒),0=禁用 - stream_keepalive_interval: 10 - # SSE max line size in bytes (default: 40MB) - # SSE 单行最大字节数(默认 40MB) - max_line_size: 41943040 - # Log upstream error response body summary (safe/truncated; does not log request content) - # 记录上游错误响应体摘要(安全/截断;不记录请求内容) - log_upstream_error_body: true - # Max bytes to log from upstream error body - # 记录上游错误响应体的最大字节数 - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta header for API-key accounts when needed (default: off) - # 需要时自动为 API-key 账户注入 anthropic-beta 头(默认:关闭) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default: off) - # 允许在特定 400 错误时进行故障转移(默认:关闭) - failover_on_400: false - -# ============================================================================= -# API Key Auth Cache Configuration -# API Key 认证缓存配置 -# ============================================================================= -api_key_auth_cache: - # L1 cache size (entries), in-process LRU/TTL cache - # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 - l1_size: 65535 - # L1 cache TTL (seconds) - # L1 缓存 TTL(秒) - l1_ttl_seconds: 15 - # L2 cache TTL (seconds), stored in Redis - # L2 缓存 TTL(秒),Redis 中存储 - l2_ttl_seconds: 300 - # Negative cache TTL (seconds) - # 负缓存 TTL(秒) - negative_ttl_seconds: 30 - # TTL jitter percent (0-100) - # TTL 抖动百分比(0-100) - jitter_percent: 10 - # Enable singleflight for cache misses - # 缓存未命中时启用 singleflight 合并回源 - singleflight: true - -# ============================================================================= -# Dashboard Cache Configuration -# 仪表盘缓存配置 -# ============================================================================= -dashboard_cache: - # Enable dashboard cache - # 启用仪表盘缓存 - enabled: true - # Redis key prefix for multi-environment isolation - # Redis key 前缀,用于多环境隔离 - key_prefix: "sub2api:" - # Fresh TTL (seconds); within this window cached stats are considered fresh - # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 - stats_fresh_ttl_seconds: 15 - # Cache TTL (seconds) stored in Redis - # Redis 缓存 TTL(秒) - stats_ttl_seconds: 30 - # Async refresh timeout (seconds) - # 异步刷新超时(秒) - stats_refresh_timeout_seconds: 30 - -# ============================================================================= -# Dashboard Aggregation Configuration -# 仪表盘预聚合配置(重启生效) -# ============================================================================= -dashboard_aggregation: - # Enable aggregation job - # 启用聚合作业 - enabled: true - # Refresh interval (seconds) - # 刷新间隔(秒) - interval_seconds: 60 - # Lookback window (seconds) for late-arriving data - # 回看窗口(秒),处理迟到数据 - lookback_seconds: 120 - # Allow manual backfill - # 允许手动回填 - backfill_enabled: false - # Backfill max range (days) - # 回填最大跨度(天) - backfill_max_days: 31 - # Recompute recent N days on startup - # 启动时重算最近 N 天 - recompute_days: 2 - # Retention windows (days) - # 保留窗口(天) - retention: - # Raw usage_logs retention - # 原始 usage_logs 保留天数 - usage_logs_days: 90 - # Hourly aggregation retention - # 小时聚合保留天数 - hourly_days: 180 - # Daily aggregation retention - # 日聚合保留天数 - daily_days: 730 - -# ============================================================================= -# Usage Cleanup Task Configuration -# 使用记录清理任务配置(重启生效) -# ============================================================================= -usage_cleanup: - # Enable cleanup task worker - # 启用清理任务执行器 - enabled: true - # Max date range (days) per task - # 单次任务最大时间跨度(天) - max_range_days: 31 - # Batch delete size - # 单批删除数量 - batch_size: 5000 - # Worker interval (seconds) - # 执行器轮询间隔(秒) - worker_interval_seconds: 10 - # Task execution timeout (seconds) - # 单次任务最大执行时长(秒) - task_timeout_seconds: 1800 - -# ============================================================================= -# Concurrency Wait Configuration -# 并发等待配置 -# ============================================================================= -concurrency: - # SSE ping interval during concurrency wait (seconds) - # 并发等待期间的 SSE ping 间隔(秒) - ping_interval: 10 - -# ============================================================================= -# Database Configuration (PostgreSQL) -# 数据库配置 (PostgreSQL) -# ============================================================================= -database: - # Database host address - # 数据库主机地址 - host: "localhost" - # Database port - # 数据库端口 - port: 5432 - # Database username - # 数据库用户名 - user: "postgres" - # Database password - # 数据库密码 - password: "your_secure_password_here" - # Database name - # 数据库名称 - dbname: "sub2api" - # SSL mode: disable, require, verify-ca, verify-full - # SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) - sslmode: "disable" - -# ============================================================================= -# Redis Configuration -# Redis 配置 -# ============================================================================= -redis: - # Redis host address - # Redis 主机地址 - host: "localhost" - # Redis port - # Redis 端口 - port: 6379 - # Redis password (leave empty if no password is set) - # Redis 密码(如果未设置密码则留空) - password: "" - # Database number (0-15) - # 数据库编号(0-15) - db: 0 - # Enable TLS/SSL connection - # 是否启用 TLS/SSL 连接 - enable_tls: false - -# ============================================================================= -# Ops Monitoring (Optional) -# 运维监控 (可选) -# ============================================================================= -ops: - # Hard switch: disable all ops background jobs and APIs when false - # 硬开关:为 false 时禁用所有 Ops 后台任务与接口 - enabled: true - - # Prefer pre-aggregated tables (ops_metrics_hourly/ops_metrics_daily) for long-window dashboard queries. - # 优先使用预聚合表(用于长时间窗口查询性能) - use_preaggregated_tables: false - - # Data cleanup configuration - # 数据清理配置(vNext 默认统一保留 30 天) - cleanup: - enabled: true - # Cron expression (minute hour dom month dow), e.g. "0 2 * * *" = daily at 2 AM - # Cron 表达式(分 时 日 月 周),例如 "0 2 * * *" = 每天凌晨 2 点 - schedule: "0 2 * * *" - error_log_retention_days: 30 - minute_metrics_retention_days: 30 - hourly_metrics_retention_days: 30 - - # Pre-aggregation configuration - # 预聚合任务配置 - aggregation: - enabled: true - - # OpsMetricsCollector Redis cache (reduces duplicate expensive window aggregation in multi-replica deployments) - # 指标采集 Redis 缓存(多副本部署时减少重复计算) - metrics_collector_cache: - enabled: true - ttl: 65s - -# ============================================================================= -# JWT Configuration -# JWT 配置 -# ============================================================================= -jwt: - # IMPORTANT: Change this to a random string in production! - # 重要:生产环境中请更改为随机字符串! - # Generate with / 生成命令: openssl rand -hex 32 - secret: "change-this-to-a-secure-random-string" - # Token expiration time in hours (max 24) - # 令牌过期时间(小时,最大 24) - expire_hour: 24 - -# ============================================================================= -# Default Settings -# 默认设置 -# ============================================================================= -default: - # Initial admin account (created on first run) - # 初始管理员账户(首次运行时创建) - admin_email: "admin@example.com" - admin_password: "admin123" - - # Default settings for new users - # 新用户默认设置 - # Max concurrent requests per user - # 每用户最大并发请求数 - user_concurrency: 5 - # Initial balance for new users - # 新用户初始余额 - user_balance: 0 - - # API key settings - # API 密钥设置 - # Prefix for generated API keys - # 生成的 API 密钥前缀 - api_key_prefix: "sk-" - - # Rate multiplier (affects billing calculation) - # 费率倍数(影响计费计算) - rate_multiplier: 1.0 - -# ============================================================================= -# Rate Limiting -# 速率限制 -# ============================================================================= -rate_limit: - # Cooldown time (in minutes) when upstream returns 529 (overloaded) - # 上游返回 529(过载)时的冷却时间(分钟) - overload_cooldown_minutes: 10 - -# ============================================================================= -# Pricing Data Source (Optional) -# 定价数据源(可选) -# ============================================================================= -pricing: - # URL to fetch model pricing data (default: LiteLLM) - # 获取模型定价数据的 URL(默认:LiteLLM) - remote_url: "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" - # Hash verification URL (optional) - # 哈希校验 URL(可选) - hash_url: "" - # Local data directory for caching - # 本地数据缓存目录 - data_dir: "./data" - # Fallback pricing file - # 备用定价文件 - fallback_file: "./resources/model-pricing/model_prices_and_context_window.json" - # Update interval in hours - # 更新间隔(小时) - update_interval_hours: 24 - # Hash check interval in minutes - # 哈希检查间隔(分钟) - hash_check_interval_minutes: 10 - -# ============================================================================= -# Billing Configuration -# 计费配置 -# ============================================================================= -billing: - circuit_breaker: - # Enable circuit breaker for billing service - # 启用计费服务熔断器 - enabled: true - # Number of failures before opening circuit - # 触发熔断的失败次数阈值 - failure_threshold: 5 - # Time to wait before attempting reset (seconds) - # 熔断后重试等待时间(秒) - reset_timeout_seconds: 30 - # Number of requests to allow in half-open state - # 半开状态允许通过的请求数 - half_open_requests: 3 - -# ============================================================================= -# Turnstile Configuration -# Turnstile 人机验证配置 -# ============================================================================= -turnstile: - # Require Turnstile in release mode (when enabled, login/register will fail if not configured) - # 在 release 模式下要求 Turnstile 验证(启用后,若未配置则登录/注册会失败) - required: false - -# ============================================================================= -# Gemini OAuth (Required for Gemini accounts) -# Gemini OAuth 配置(Gemini 账户必需) -# ============================================================================= -# Sub2API supports TWO Gemini OAuth modes: -# Sub2API 支持两种 Gemini OAuth 模式: -# -# 1. Code Assist OAuth (requires GCP project_id) -# 1. Code Assist OAuth(需要 GCP project_id) -# - Uses: cloudcode-pa.googleapis.com (Code Assist API) -# - 使用:cloudcode-pa.googleapis.com(Code Assist API) -# -# 2. AI Studio OAuth (no project_id needed) -# 2. AI Studio OAuth(不需要 project_id) -# - Uses: generativelanguage.googleapis.com (AI Studio API) -# - 使用:generativelanguage.googleapis.com(AI Studio API) -# -# Default: Uses Gemini CLI's public OAuth credentials (same as Google's official CLI tool) -# 默认:使用 Gemini CLI 的公开 OAuth 凭证(与 Google 官方 CLI 工具相同) -gemini: - oauth: - # Gemini CLI public OAuth credentials (works for both Code Assist and AI Studio) - # Gemini CLI 公开 OAuth 凭证(适用于 Code Assist 和 AI Studio) - client_id: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - client_secret: "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" - # Optional scopes (space-separated). Leave empty to auto-select based on oauth_type. - # 可选的权限范围(空格分隔)。留空则根据 oauth_type 自动选择。 - scopes: "" - quota: - # Optional: local quota simulation for Gemini Code Assist (local billing). - # 可选:Gemini Code Assist 本地配额模拟(本地计费)。 - # These values are used for UI progress + precheck scheduling, not official Google quotas. - # 这些值用于 UI 进度显示和预检调度,并非 Google 官方配额。 - tiers: - LEGACY: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 50 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 1500 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 30 - PRO: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 1500 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 4000 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 - ULTRA: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 2000 - # Flash model requests per day (0 = unlimited) - # Flash 模型每日请求数(0 = 无限制) - flash_rpd: 0 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 diff --git a/deploy/.env.example b/deploy/.env.example index 25096c3d..c5e850ae 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,6 +20,31 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# Global max request body size in bytes (default: 100MB) +# 全局最大请求体大小(字节,默认 100MB) +# Applies to all requests, especially important for h2c first request memory protection +# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 +SERVER_MAX_REQUEST_BODY_SIZE=104857600 + +# Enable HTTP/2 Cleartext (h2c) for client connections +# 启用 HTTP/2 Cleartext (h2c) 客户端连接 +SERVER_H2C_ENABLED=true +# H2C max concurrent streams (default: 50) +# H2C 最大并发流数量(默认 50) +SERVER_H2C_MAX_CONCURRENT_STREAMS=50 +# H2C idle timeout in seconds (default: 75) +# H2C 空闲超时时间(秒,默认 75) +SERVER_H2C_IDLE_TIMEOUT=75 +# H2C max read frame size in bytes (default: 1048576 = 1MB) +# H2C 最大帧大小(字节,默认 1048576 = 1MB) +SERVER_H2C_MAX_READ_FRAME_SIZE=1048576 +# H2C max upload buffer per connection in bytes (default: 2097152 = 2MB) +# H2C 每个连接的最大上传缓冲区(字节,默认 2097152 = 2MB) +SERVER_H2C_MAX_UPLOAD_BUFFER_PER_CONNECTION=2097152 +# H2C max upload buffer per stream in bytes (default: 524288 = 512KB) +# H2C 每个流的最大上传缓冲区(字节,默认 524288 = 512KB) +SERVER_H2C_MAX_UPLOAD_BUFFER_PER_STREAM=524288 + # 运行模式: standard (默认) 或 simple (内部自用) # standard: 完整 SaaS 功能,包含计费/余额校验;simple: 隐藏 SaaS 功能并跳过计费/余额校验 RUN_MODE=standard diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 6f5e9744..d9f5f2ab 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -23,6 +23,32 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] + # Global max request body size in bytes (default: 100MB) + # 全局最大请求体大小(字节,默认 100MB) + # Applies to all requests, especially important for h2c first request memory protection + # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 + max_request_body_size: 104857600 + # HTTP/2 Cleartext (h2c) configuration + # HTTP/2 Cleartext (h2c) 配置 + h2c: + # Enable HTTP/2 Cleartext for client connections + # 启用 HTTP/2 Cleartext 客户端连接 + enabled: true + # Max concurrent streams per connection + # 每个连接的最大并发流数量 + max_concurrent_streams: 50 + # Idle timeout for connections (seconds) + # 连接空闲超时时间(秒) + idle_timeout: 75 + # Max frame size in bytes (default: 1MB) + # 最大帧大小(字节,默认 1MB) + max_read_frame_size: 1048576 + # Max upload buffer per connection in bytes (default: 2MB) + # 每个连接的最大上传缓冲区(字节,默认 2MB) + max_upload_buffer_per_connection: 2097152 + # Max upload buffer per stream in bytes (default: 512KB) + # 每个流的最大上传缓冲区(字节,默认 512KB) + max_upload_buffer_per_stream: 524288 # ============================================================================= # Run Mode Configuration diff --git a/docs/rename_local_migrations_20260202.sql b/docs/rename_local_migrations_20260202.sql new file mode 100644 index 00000000..911ed17d --- /dev/null +++ b/docs/rename_local_migrations_20260202.sql @@ -0,0 +1,34 @@ +-- 修正 schema_migrations 中“本地改名”的迁移文件名 +-- 适用场景:你已执行过旧文件名的迁移,合并后仅改了自己这边的文件名 + +BEGIN; + +UPDATE schema_migrations +SET filename = '042b_add_ops_system_metrics_switch_count.sql' +WHERE filename = '042_add_ops_system_metrics_switch_count.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '042b_add_ops_system_metrics_switch_count.sql' + ); + +UPDATE schema_migrations +SET filename = '043b_add_group_invalid_request_fallback.sql' +WHERE filename = '043_add_group_invalid_request_fallback.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '043b_add_group_invalid_request_fallback.sql' + ); + +UPDATE schema_migrations +SET filename = '044b_add_group_mcp_xml_inject.sql' +WHERE filename = '044_add_group_mcp_xml_inject.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' + ); + +UPDATE schema_migrations +SET filename = '046b_add_group_supported_model_scopes.sql' +WHERE filename = '046_add_group_supported_model_scopes.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql' + ); + +COMMIT; diff --git a/frontend/src/__tests__/integration/data-import.spec.ts b/frontend/src/__tests__/integration/data-import.spec.ts new file mode 100644 index 00000000..1fe870ab --- /dev/null +++ b/frontend/src/__tests__/integration/data-import.spec.ts @@ -0,0 +1,70 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount } from '@vue/test-utils' +import ImportDataModal from '@/components/admin/account/ImportDataModal.vue' + +const showError = vi.fn() +const showSuccess = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError, + showSuccess + }) +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + accounts: { + importData: vi.fn() + } + } +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) +})) + +describe('ImportDataModal', () => { + beforeEach(() => { + showError.mockReset() + showSuccess.mockReset() + }) + + it('未选择文件时提示错误', async () => { + const wrapper = mount(ImportDataModal, { + props: { show: true }, + global: { + stubs: { + BaseDialog: { template: '
' } + } + } + }) + + await wrapper.find('form').trigger('submit') + expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportSelectFile') + }) + + it('无效 JSON 时提示解析失败', async () => { + const wrapper = mount(ImportDataModal, { + props: { show: true }, + global: { + stubs: { + BaseDialog: { template: '
' } + } + } + }) + + const input = wrapper.find('input[type="file"]') + const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) + Object.defineProperty(input.element, 'files', { + value: [file] + }) + + await input.trigger('change') + await wrapper.find('form').trigger('submit') + + expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed') + }) +}) diff --git a/frontend/src/__tests__/integration/proxy-data-import.spec.ts b/frontend/src/__tests__/integration/proxy-data-import.spec.ts new file mode 100644 index 00000000..f0433898 --- /dev/null +++ b/frontend/src/__tests__/integration/proxy-data-import.spec.ts @@ -0,0 +1,70 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount } from '@vue/test-utils' +import ImportDataModal from '@/components/admin/proxy/ImportDataModal.vue' + +const showError = vi.fn() +const showSuccess = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError, + showSuccess + }) +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + proxies: { + importData: vi.fn() + } + } +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) +})) + +describe('Proxy ImportDataModal', () => { + beforeEach(() => { + showError.mockReset() + showSuccess.mockReset() + }) + + it('未选择文件时提示错误', async () => { + const wrapper = mount(ImportDataModal, { + props: { show: true }, + global: { + stubs: { + BaseDialog: { template: '
' } + } + } + }) + + await wrapper.find('form').trigger('submit') + expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportSelectFile') + }) + + it('无效 JSON 时提示解析失败', async () => { + const wrapper = mount(ImportDataModal, { + props: { show: true }, + global: { + stubs: { + BaseDialog: { template: '
' } + } + } + }) + + const input = wrapper.find('input[type="file"]') + const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) + Object.defineProperty(input.element, 'files', { + value: [file] + }) + + await input.trigger('change') + await wrapper.find('form').trigger('submit') + + expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed') + }) +}) diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 54d0ad94..6df93498 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -13,7 +13,9 @@ import type { WindowStats, ClaudeModel, AccountUsageStatsResponse, - TempUnschedulableStatus + TempUnschedulableStatus, + AdminDataPayload, + AdminDataImportResult } from '@/types' /** @@ -347,6 +349,55 @@ export async function syncFromCrs(params: { return data } +export async function exportData(options?: { + ids?: number[] + filters?: { + platform?: string + type?: string + status?: string + search?: string + } + includeProxies?: boolean +}): Promise { + const params: Record = {} + if (options?.ids && options.ids.length > 0) { + params.ids = options.ids.join(',') + } else if (options?.filters) { + const { platform, type, status, search } = options.filters + if (platform) params.platform = platform + if (type) params.type = type + if (status) params.status = status + if (search) params.search = search + } + if (options?.includeProxies === false) { + params.include_proxies = 'false' + } + const { data } = await apiClient.get('/admin/accounts/data', { params }) + return data +} + +export async function importData(payload: { + data: AdminDataPayload + skip_default_group_bind?: boolean +}): Promise { + const { data } = await apiClient.post('/admin/accounts/data', { + data: payload.data, + skip_default_group_bind: payload.skip_default_group_bind + }) + return data +} + +/** + * Get Antigravity default model mapping from backend + * @returns Default model mapping (from -> to) + */ +export async function getAntigravityDefaultModelMapping(): Promise> { + const { data } = await apiClient.get>( + '/admin/accounts/antigravity/default-model-mapping' + ) + return data +} + export const accountsAPI = { list, getById, @@ -370,7 +421,10 @@ export const accountsAPI = { batchCreate, batchUpdateCredentials, bulkUpdate, - syncFromCrs + syncFromCrs, + exportData, + importData, + getAntigravityDefaultModelMapping } export default accountsAPI diff --git a/frontend/src/api/admin/errorPassthrough.ts b/frontend/src/api/admin/errorPassthrough.ts new file mode 100644 index 00000000..4c545ad5 --- /dev/null +++ b/frontend/src/api/admin/errorPassthrough.ts @@ -0,0 +1,134 @@ +/** + * Admin Error Passthrough Rules API endpoints + * Handles error passthrough rule management for administrators + */ + +import { apiClient } from '../client' + +/** + * Error passthrough rule interface + */ +export interface ErrorPassthroughRule { + id: number + name: string + enabled: boolean + priority: number + error_codes: number[] + keywords: string[] + match_mode: 'any' | 'all' + platforms: string[] + passthrough_code: boolean + response_code: number | null + passthrough_body: boolean + custom_message: string | null + description: string | null + created_at: string + updated_at: string +} + +/** + * Create rule request + */ +export interface CreateRuleRequest { + name: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * Update rule request + */ +export interface UpdateRuleRequest { + name?: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * List all error passthrough rules + * @returns List of all rules sorted by priority + */ +export async function list(): Promise { + const { data } = await apiClient.get('/admin/error-passthrough-rules') + return data +} + +/** + * Get rule by ID + * @param id - Rule ID + * @returns Rule details + */ +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Create new rule + * @param ruleData - Rule data + * @returns Created rule + */ +export async function create(ruleData: CreateRuleRequest): Promise { + const { data } = await apiClient.post('/admin/error-passthrough-rules', ruleData) + return data +} + +/** + * Update rule + * @param id - Rule ID + * @param updates - Fields to update + * @returns Updated rule + */ +export async function update(id: number, updates: UpdateRuleRequest): Promise { + const { data } = await apiClient.put(`/admin/error-passthrough-rules/${id}`, updates) + return data +} + +/** + * Delete rule + * @param id - Rule ID + * @returns Success confirmation + */ +export async function deleteRule(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Toggle rule enabled status + * @param id - Rule ID + * @param enabled - New enabled status + * @returns Updated rule + */ +export async function toggleEnabled(id: number, enabled: boolean): Promise { + return update(id, { enabled }) +} + +export const errorPassthroughAPI = { + list, + getById, + create, + update, + delete: deleteRule, + toggleEnabled +} + +export default errorPassthroughAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index a88b02c6..ffb9b179 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -19,6 +19,7 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' +import errorPassthroughAPI from './errorPassthrough' /** * Unified admin API object for convenient access @@ -39,7 +40,8 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI + ops: opsAPI, + errorPassthrough: errorPassthroughAPI } export { @@ -58,7 +60,12 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI + opsAPI, + errorPassthroughAPI } export default adminAPI + +// Re-export types used by components +export type { BalanceHistoryItem } from './users' +export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index bf2c246c..5b96feda 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -136,6 +136,7 @@ export interface OpsThroughputTrendPoint { bucket_start: string request_count: number token_consumed: number + switch_count?: number qps: number tps: number } @@ -284,6 +285,7 @@ export interface OpsSystemMetricsSnapshot { goroutine_count?: number | null concurrency_queue_depth?: number | null + account_switch_count?: number | null } export interface OpsJobHeartbeat { @@ -335,6 +337,22 @@ export interface OpsConcurrencyStatsResponse { timestamp?: string } +export interface UserConcurrencyInfo { + user_id: number + user_email: string + username: string + current_in_use: number + max_capacity: number + load_percentage: number + waiting_in_queue: number +} + +export interface OpsUserConcurrencyStatsResponse { + enabled: boolean + user: Record + timestamp?: string +} + export async function getConcurrencyStats(platform?: string, groupId?: number | null): Promise { const params: Record = {} if (platform) { @@ -348,6 +366,11 @@ export async function getConcurrencyStats(platform?: string, groupId?: number | return data } +export async function getUserConcurrencyStats(): Promise { + const { data } = await apiClient.get('/admin/ops/user-concurrency') + return data +} + export interface PlatformAvailability { platform: string total_accounts: number @@ -1169,6 +1192,7 @@ export const opsAPI = { getErrorTrend, getErrorDistribution, getConcurrencyStats, + getUserConcurrencyStats, getAccountAvailabilityStats, getRealtimeTrafficSummary, subscribeQPS, diff --git a/frontend/src/api/admin/proxies.ts b/frontend/src/api/admin/proxies.ts index 1af2ea39..b6aaf595 100644 --- a/frontend/src/api/admin/proxies.ts +++ b/frontend/src/api/admin/proxies.ts @@ -9,7 +9,9 @@ import type { ProxyAccountSummary, CreateProxyRequest, UpdateProxyRequest, - PaginatedResponse + PaginatedResponse, + AdminDataPayload, + AdminDataImportResult } from '@/types' /** @@ -208,6 +210,34 @@ export async function batchDelete(ids: number[]): Promise<{ return data } +export async function exportData(options?: { + ids?: number[] + filters?: { + protocol?: string + status?: 'active' | 'inactive' + search?: string + } +}): Promise { + const params: Record = {} + if (options?.ids && options.ids.length > 0) { + params.ids = options.ids.join(',') + } else if (options?.filters) { + const { protocol, status, search } = options.filters + if (protocol) params.protocol = protocol + if (status) params.status = status + if (search) params.search = search + } + const { data } = await apiClient.get('/admin/proxies/data', { params }) + return data +} + +export async function importData(payload: { + data: AdminDataPayload +}): Promise { + const { data } = await apiClient.post('/admin/proxies/data', payload) + return data +} + export const proxiesAPI = { list, getAll, @@ -221,7 +251,9 @@ export const proxiesAPI = { getStats, getProxyAccounts, batchCreate, - batchDelete + batchDelete, + exportData, + importData } export default proxiesAPI diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 734e3ac7..287aef96 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -174,6 +174,53 @@ export async function getUserUsageStats( return data } +/** + * Balance history item returned from the API + */ +export interface BalanceHistoryItem { + id: number + code: string + type: string + value: number + status: string + used_by: number | null + used_at: string | null + created_at: string + group_id: number | null + validity_days: number + notes: string + user?: { id: number; email: string } | null + group?: { id: number; name: string } | null +} + +// Balance history response extends pagination with total_recharged summary +export interface BalanceHistoryResponse extends PaginatedResponse { + total_recharged: number +} + +/** + * Get user's balance/concurrency change history + * @param id - User ID + * @param page - Page number + * @param pageSize - Items per page + * @param type - Optional type filter (balance, admin_balance, concurrency, admin_concurrency, subscription) + * @returns Paginated balance history with total_recharged + */ +export async function getUserBalanceHistory( + id: number, + page: number = 1, + pageSize: number = 20, + type?: string +): Promise { + const params: Record = { page, page_size: pageSize } + if (type) params.type = type + const { data } = await apiClient.get( + `/admin/users/${id}/balance-history`, + { params } + ) + return data +} + export const usersAPI = { list, getById, @@ -184,7 +231,8 @@ export const usersAPI = { updateConcurrency, toggleStatus, getUserApiKeys, - getUserUsageStats + getUserUsageStats, + getUserBalanceHistory } export default usersAPI diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 40c9c5a4..e196e234 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -35,6 +35,22 @@ export function setAuthToken(token: string): void { localStorage.setItem('auth_token', token) } +/** + * Store refresh token in localStorage + */ +export function setRefreshToken(token: string): void { + localStorage.setItem('refresh_token', token) +} + +/** + * Store token expiration timestamp in localStorage + * Converts expires_in (seconds) to absolute timestamp (milliseconds) + */ +export function setTokenExpiresAt(expiresIn: number): void { + const expiresAt = Date.now() + expiresIn * 1000 + localStorage.setItem('token_expires_at', String(expiresAt)) +} + /** * Get authentication token from localStorage */ @@ -42,12 +58,29 @@ export function getAuthToken(): string | null { return localStorage.getItem('auth_token') } +/** + * Get refresh token from localStorage + */ +export function getRefreshToken(): string | null { + return localStorage.getItem('refresh_token') +} + +/** + * Get token expiration timestamp from localStorage + */ +export function getTokenExpiresAt(): number | null { + const value = localStorage.getItem('token_expires_at') + return value ? parseInt(value, 10) : null +} + /** * Clear authentication token from localStorage */ export function clearAuthToken(): void { localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') } /** @@ -61,6 +94,12 @@ export async function login(credentials: LoginRequest): Promise { // Only store token if 2FA is not required if (!isTotp2FARequired(data)) { setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) } @@ -77,6 +116,12 @@ export async function login2FA(request: TotpLogin2FARequest): Promise // Store token and user data setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) return data @@ -108,11 +159,62 @@ export async function getCurrentUser() { /** * User logout * Clears authentication token and user data from localStorage + * Optionally revokes the refresh token on the server */ -export function logout(): void { +export async function logout(): Promise { + const refreshToken = getRefreshToken() + + // Try to revoke the refresh token on the server + if (refreshToken) { + try { + await apiClient.post('/auth/logout', { refresh_token: refreshToken }) + } catch { + // Ignore errors - we still want to clear local state + } + } + clearAuthToken() - // Optionally redirect to login page - // window.location.href = '/login'; +} + +/** + * Refresh token response + */ +export interface RefreshTokenResponse { + access_token: string + refresh_token: string + expires_in: number + token_type: string +} + +/** + * Refresh the access token using the refresh token + * @returns New token pair + */ +export async function refreshToken(): Promise { + const currentRefreshToken = getRefreshToken() + if (!currentRefreshToken) { + throw new Error('No refresh token available') + } + + const { data } = await apiClient.post('/auth/refresh', { + refresh_token: currentRefreshToken + }) + + // Update tokens in localStorage + setAuthToken(data.access_token) + setRefreshToken(data.refresh_token) + setTokenExpiresAt(data.expires_in) + + return data +} + +/** + * Revoke all sessions for the current user + * @returns Response with message + */ +export async function revokeAllSessions(): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>('/auth/revoke-all-sessions') + return data } /** @@ -242,14 +344,20 @@ export const authAPI = { logout, isAuthenticated, setAuthToken, + setRefreshToken, + setTokenExpiresAt, getAuthToken, + getRefreshToken, + getTokenExpiresAt, clearAuthToken, getPublicSettings, sendVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, - resetPassword + resetPassword, + refreshToken, + revokeAllSessions } export default authAPI diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 3827498b..22db5a44 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,9 +1,9 @@ /** * Axios HTTP Client Configuration - * Base client with interceptors for authentication and error handling + * Base client with interceptors for authentication, token refresh, and error handling */ -import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios' +import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig, AxiosResponse } from 'axios' import type { ApiResponse } from '@/types' import { getLocale } from '@/i18n' @@ -19,6 +19,28 @@ export const apiClient: AxiosInstance = axios.create({ } }) +// ==================== Token Refresh State ==================== + +// Track if a token refresh is in progress to prevent multiple simultaneous refresh requests +let isRefreshing = false +// Queue of requests waiting for token refresh +let refreshSubscribers: Array<(token: string) => void> = [] + +/** + * Subscribe to token refresh completion + */ +function subscribeTokenRefresh(callback: (token: string) => void): void { + refreshSubscribers.push(callback) +} + +/** + * Notify all subscribers that token has been refreshed + */ +function onTokenRefreshed(token: string): void { + refreshSubscribers.forEach((callback) => callback(token)) + refreshSubscribers = [] +} + // ==================== Request Interceptor ==================== // Get user's timezone @@ -61,7 +83,7 @@ apiClient.interceptors.request.use( // ==================== Response Interceptor ==================== apiClient.interceptors.response.use( - (response) => { + (response: AxiosResponse) => { // Unwrap standard API response format { code, message, data } const apiResponse = response.data as ApiResponse if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) { @@ -79,13 +101,15 @@ apiClient.interceptors.response.use( } return response }, - (error: AxiosError>) => { + async (error: AxiosError>) => { // Request cancellation: keep the original axios cancellation error so callers can ignore it. // Otherwise we'd misclassify it as a generic "network error". if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) { return Promise.reject(error) } + const originalRequest = error.config as InternalAxiosRequestConfig & { _retry?: boolean } + // Handle common errors if (error.response) { const { status, data } = error.response @@ -120,23 +144,116 @@ apiClient.interceptors.response.use( }) } - // 401: Unauthorized - clear token and redirect to login - if (status === 401) { - const hasToken = !!localStorage.getItem('auth_token') - const url = error.config?.url || '' + // 401: Try to refresh the token if we have a refresh token + // This handles TOKEN_EXPIRED, INVALID_TOKEN, TOKEN_REVOKED, etc. + if (status === 401 && !originalRequest._retry) { + const refreshToken = localStorage.getItem('refresh_token') const isAuthEndpoint = url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh') + + // If we have a refresh token and this is not an auth endpoint, try to refresh + if (refreshToken && !isAuthEndpoint) { + if (isRefreshing) { + // Wait for the ongoing refresh to complete + return new Promise((resolve, reject) => { + subscribeTokenRefresh((newToken: string) => { + if (newToken) { + // Mark as retried to prevent infinite loop if retry also returns 401 + originalRequest._retry = true + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${newToken}` + } + resolve(apiClient(originalRequest)) + } else { + // Refresh failed, reject with original error + reject({ + status, + code: apiData.code, + message: apiData.message || apiData.detail || error.message + }) + } + }) + }) + } + + originalRequest._retry = true + isRefreshing = true + + try { + // Call refresh endpoint directly to avoid circular dependency + const refreshResponse = await axios.post( + `${API_BASE_URL}/auth/refresh`, + { refresh_token: refreshToken }, + { headers: { 'Content-Type': 'application/json' } } + ) + + const refreshData = refreshResponse.data as ApiResponse<{ + access_token: string + refresh_token: string + expires_in: number + }> + + if (refreshData.code === 0 && refreshData.data) { + const { access_token, refresh_token: newRefreshToken, expires_in } = refreshData.data + + // Update tokens in localStorage (convert expires_in to timestamp) + localStorage.setItem('auth_token', access_token) + localStorage.setItem('refresh_token', newRefreshToken) + localStorage.setItem('token_expires_at', String(Date.now() + expires_in * 1000)) + + // Notify subscribers with new token + onTokenRefreshed(access_token) + + // Retry the original request with new token + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${access_token}` + } + + isRefreshing = false + return apiClient(originalRequest) + } + + // Refresh response was not successful, fall through to clear auth + throw new Error('Token refresh failed') + } catch (refreshError) { + // Refresh failed - notify subscribers with empty token + onTokenRefreshed('') + isRefreshing = false + + // Clear tokens and redirect to login + localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') + localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') + sessionStorage.setItem('auth_expired', '1') + + if (!window.location.pathname.includes('/login')) { + window.location.href = '/login' + } + + return Promise.reject({ + status: 401, + code: 'TOKEN_REFRESH_FAILED', + message: 'Session expired. Please log in again.' + }) + } + } + + // No refresh token or is auth endpoint - clear auth and redirect + const hasToken = !!localStorage.getItem('auth_token') const headers = error.config?.headers as Record | undefined const authHeader = headers?.Authorization ?? headers?.authorization const sentAuth = typeof authHeader === 'string' ? authHeader.trim() !== '' : Array.isArray(authHeader) - ? authHeader.length > 0 - : !!authHeader + ? authHeader.length > 0 + : !!authHeader localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') if ((hasToken || sentAuth) && !isAuthEndpoint) { sessionStorage.setItem('auth_expired', '1') } diff --git a/frontend/src/api/groups.ts b/frontend/src/api/groups.ts index 0f366d51..0963a7a6 100644 --- a/frontend/src/api/groups.ts +++ b/frontend/src/api/groups.ts @@ -18,8 +18,18 @@ export async function getAvailable(): Promise { return data } +/** + * Get current user's custom group rate multipliers + * @returns Map of group_id to custom rate_multiplier + */ +export async function getUserGroupRates(): Promise> { + const { data } = await apiClient.get | null>('/groups/rates') + return data || {} +} + export const userGroupsAPI = { - getAvailable + getAvailable, + getUserGroupRates } export default userGroupsAPI diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index cdae1359..c5943789 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -44,6 +44,8 @@ export async function getById(id: number): Promise { * @param customKey - Optional custom key value * @param ipWhitelist - Optional IP whitelist * @param ipBlacklist - Optional IP blacklist + * @param quota - Optional quota limit in USD (0 = unlimited) + * @param expiresInDays - Optional days until expiry (undefined = never expires) * @returns Created API key */ export async function create( @@ -51,7 +53,9 @@ export async function create( groupId?: number | null, customKey?: string, ipWhitelist?: string[], - ipBlacklist?: string[] + ipBlacklist?: string[], + quota?: number, + expiresInDays?: number ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -66,6 +70,12 @@ export async function create( if (ipBlacklist && ipBlacklist.length > 0) { payload.ip_blacklist = ipBlacklist } + if (quota !== undefined && quota > 0) { + payload.quota = quota + } + if (expiresInDays !== undefined && expiresInDays > 0) { + payload.expires_in_days = expiresInDays + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 8e525fa3..3474da44 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -90,6 +90,26 @@ class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700" > {{ t('admin.accounts.status.scopeRateLimitedUntil', { scope: formatScopeName(item.scope), time: formatTime(item.reset_at) }) }} +
+
+
+ + + +