密码重置请求
+您已请求重置密码。请点击下方按钮设置新密码:
+ 重置密码
+此链接将在 30 分钟后失效。
+如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。
+如果按钮无法点击,请复制以下链接到浏览器中打开:
+%s
+diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 3ea8860a..e5624f86 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -19,7 +19,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
- go version | grep -q 'go1.25.5'
+ go version | grep -q 'go1.25.6'
- name: Unit tests
working-directory: backend
run: make test-unit
@@ -38,7 +38,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
- go version | grep -q 'go1.25.5'
+ go version | grep -q 'go1.25.6'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
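The version-pin bumps above rely on `grep -q` as a CI gate: it exits 0 only when the pattern matches, and a non-zero exit fails the `run` step. A minimal sketch of that behavior (the `check_go_version` helper is illustrative, not part of the workflow):

```shell
# grep -q prints nothing; its exit status alone decides pass/fail,
# which is what makes the one-liner work as a CI assertion.
check_go_version() {
  # $1 is the output of `go version`, e.g. "go version go1.25.6 linux/amd64"
  echo "$1" | grep -q 'go1.25.6'
}

check_go_version "go version go1.25.6 linux/amd64" && echo "ok"
check_go_version "go version go1.25.5 linux/amd64" || echo "mismatch"
```

Note that the unescaped dots make the pattern slightly loose (it would also match `go1.25.60`); escaping them as `go1\.25\.6` would be stricter.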
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 73ca35d9..f45c1a0b 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
- go version | grep -q 'go1.25.5'
+ go version | grep -q 'go1.25.6'
# Docker setup for GoReleaser
- name: Set up QEMU
@@ -222,8 +222,9 @@ jobs:
REPO="${{ github.repository }}"
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
- # 获取 tag message 内容
+ # 获取 tag message 内容并转义 Markdown 特殊字符
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
+ TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
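A minimal sketch of the escaping rule the new `sed` line applies, assuming Telegram's legacy Markdown parse mode (where `_`, `*`, `` ` `` and `[` are formatting characters but `]` needs no escape); `escape_md` is a hypothetical helper, and `printf` is used instead of `echo` to avoid backslash-mangling in some shells:

```shell
# Backslash-escape Markdown specials so an arbitrary tag message
# cannot break Telegram's formatting.
escape_md() {
  printf '%s' "$1" | sed 's/\([_*`[]\)/\\\1/g'
}

escape_md 'fix_v1 *note*'   # → fix\_v1 \*note\*
```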
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index 160a0df9..dfb8e37e 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -22,7 +22,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
- go version | grep -q 'go1.25.5'
+ go version | grep -q 'go1.25.6'
- name: Run govulncheck
working-directory: backend
run: |
diff --git a/Dockerfile b/Dockerfile
index b3320300..3d4b5094 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
-ARG GOLANG_IMAGE=golang:1.25.5-alpine
+ARG GOLANG_IMAGE=golang:1.25.6-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
diff --git a/README.md b/README.md
index fa965e6f..14656332 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
-Try Sub2API online: **https://v2.pincc.ai/**
+Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
@@ -128,7 +128,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
-### Method 2: Docker Compose
+### Method 2: Docker Compose (Recommended)
Deploy with Docker Compose, including PostgreSQL and Redis containers.
@@ -137,87 +137,157 @@ Deploy with Docker Compose, including PostgreSQL and Redis containers.
- Docker 20.10+
- Docker Compose v2+
-#### Installation Steps
+#### Quick Start (One-Click Deployment)
+
+Use the automated deployment script for easy setup:
+
+```bash
+# Create deployment directory
+mkdir -p sub2api-deploy && cd sub2api-deploy
+
+# Download and run deployment preparation script
+curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
+
+# Start services
+docker-compose -f docker-compose.local.yml up -d
+
+# View logs
+docker-compose -f docker-compose.local.yml logs -f sub2api
+```
+
+**What the script does:**
+- Downloads `docker-compose.local.yml` and `.env.example`
+- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
+- Creates `.env` file with auto-generated secrets
+- Creates data directories (uses local directories for easy backup/migration)
+- Displays generated credentials for your reference
+
+#### Manual Deployment
+
+If you prefer manual setup:
```bash
# 1. Clone the repository
git clone https://github.com/Wei-Shaw/sub2api.git
-cd sub2api
+cd sub2api/deploy
-# 2. Enter the deploy directory
-cd deploy
-
-# 3. Copy environment configuration
+# 2. Copy environment configuration
cp .env.example .env
-# 4. Edit configuration (set your passwords)
+# 3. Edit configuration (generate secure passwords)
nano .env
```
**Required configuration in `.env`:**
```bash
-# PostgreSQL password (REQUIRED - change this!)
+# PostgreSQL password (REQUIRED)
POSTGRES_PASSWORD=your_secure_password_here
+# JWT Secret (RECOMMENDED - keeps users logged in after restart)
+JWT_SECRET=your_jwt_secret_here
+
+# TOTP Encryption Key (RECOMMENDED - preserves 2FA after restart)
+TOTP_ENCRYPTION_KEY=your_totp_key_here
+
# Optional: Admin account
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# Optional: Custom port
SERVER_PORT=8080
+```
-# Optional: Security configuration
-# Enable URL allowlist validation (false to skip allowlist checks, only basic format validation)
-SECURITY_URL_ALLOWLIST_ENABLED=false
+**Generate secure secrets:**
+```bash
+# Generate JWT_SECRET
+openssl rand -hex 32
-# Allow insecure HTTP URLs when allowlist is disabled (default: false, requires https)
-# ⚠️ WARNING: Enabling this allows HTTP (plaintext) URLs which can expose API keys
-# Only recommended for:
-# - Development/testing environments
-# - Internal networks with trusted endpoints
-# - When using local test servers (http://localhost)
-# PRODUCTION: Keep this false or use HTTPS URLs only
-SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
+# Generate TOTP_ENCRYPTION_KEY
+openssl rand -hex 32
-# Allow private IP addresses for upstream/pricing/CRS (for internal deployments)
-SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
+# Generate POSTGRES_PASSWORD
+openssl rand -hex 32
```
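Each `openssl rand -hex 32` call above draws 32 random bytes and prints them as 64 hex characters. If `openssl` is not installed, a sketch of an equivalent on any Unix-like system with `/dev/urandom` (the `gen_secret` name is illustrative):

```shell
# 32 random bytes rendered as 64 lowercase hex characters,
# same shape as the output of `openssl rand -hex 32`.
gen_secret() {
  head -c 32 /dev/urandom | od -An -tx1 | tr -d ' \n'
}

secret=$(gen_secret)
echo "${#secret}"   # 64
```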
```bash
+# 4. Create data directories (for local version)
+mkdir -p data postgres_data redis_data
+
# 5. Start all services
+# Option A: Local directory version (recommended - easy migration)
+docker-compose -f docker-compose.local.yml up -d
+
+# Option B: Named volumes version (simple setup)
docker-compose up -d
# 6. Check status
-docker-compose ps
+docker-compose -f docker-compose.local.yml ps
# 7. View logs
-docker-compose logs -f sub2api
+docker-compose -f docker-compose.local.yml logs -f sub2api
```
+#### Deployment Versions
+
+| Version | Data Storage | Migration | Best For |
+|---------|-------------|-----------|----------|
+| **docker-compose.local.yml** | Local directories | ✅ Easy (tar entire directory) | Production, frequent backups |
+| **docker-compose.yml** | Named volumes | ⚠️ Requires docker commands | Simple setup |
+
+**Recommendation:** Use `docker-compose.local.yml` (the file deployed by the script) for easier data management.
+
#### Access
Open `http://YOUR_SERVER_IP:8080` in your browser.
+If the admin password was auto-generated, find it in the logs:
+```bash
+docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
+```
+
#### Upgrade
```bash
# Pull latest image and recreate container
-docker-compose pull
-docker-compose up -d
+docker-compose -f docker-compose.local.yml pull
+docker-compose -f docker-compose.local.yml up -d
+```
+
+#### Easy Migration (Local Directory Version)
+
+When using `docker-compose.local.yml`, migrate to a new server easily:
+
+```bash
+# On source server
+docker-compose -f docker-compose.local.yml down
+cd ..
+tar czf sub2api-complete.tar.gz sub2api-deploy/
+
+# Transfer to new server
+scp sub2api-complete.tar.gz user@new-server:/path/
+
+# On new server
+tar xzf sub2api-complete.tar.gz
+cd sub2api-deploy/
+docker-compose -f docker-compose.local.yml up -d
```
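Before transferring the archive, it is worth confirming it actually contains the environment file and data directories; a sketch under the directory layout assumed above (`check_archive` is a hypothetical helper, adjust the paths if your deploy directory differs):

```shell
# List the tarball and verify the critical paths made it in
# before copying it to the new server.
check_archive() {
  tar tzf "$1" | grep -q 'sub2api-deploy/.env' &&
  tar tzf "$1" | grep -q 'sub2api-deploy/postgres_data/'
}

# Usage: check_archive sub2api-complete.tar.gz && echo "archive looks complete"
```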
#### Useful Commands
```bash
# Stop all services
-docker-compose down
+docker-compose -f docker-compose.local.yml down
# Restart
-docker-compose restart
+docker-compose -f docker-compose.local.yml restart
# View all logs
-docker-compose logs -f
+docker-compose -f docker-compose.local.yml logs -f
+
+# Remove all data (caution!)
+docker-compose -f docker-compose.local.yml down
+rm -rf data/ postgres_data/ redis_data/
```
---
diff --git a/README_CN.md b/README_CN.md
index 707f0201..a16adc72 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -135,7 +135,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
-### 方式二:Docker Compose
+### 方式二:Docker Compose(推荐)
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
@@ -144,87 +144,157 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
- Docker 20.10+
- Docker Compose v2+
-#### 安装步骤
+#### 快速开始(一键部署)
+
+使用自动化部署脚本快速搭建:
+
+```bash
+# 创建部署目录
+mkdir -p sub2api-deploy && cd sub2api-deploy
+
+# 下载并运行部署准备脚本
+curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
+
+# 启动服务
+docker-compose -f docker-compose.local.yml up -d
+
+# 查看日志
+docker-compose -f docker-compose.local.yml logs -f sub2api
+```
+
+**脚本功能:**
+- 下载 `docker-compose.local.yml` 和 `.env.example`
+- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
+- 创建 `.env` 文件并填充自动生成的密钥
+- 创建数据目录(使用本地目录,便于备份和迁移)
+- 显示生成的凭证供你记录
+
+#### 手动部署
+
+如果你希望手动配置:
```bash
# 1. 克隆仓库
git clone https://github.com/Wei-Shaw/sub2api.git
-cd sub2api
+cd sub2api/deploy
-# 2. 进入 deploy 目录
-cd deploy
-
-# 3. 复制环境配置文件
+# 2. 复制环境配置文件
cp .env.example .env
-# 4. 编辑配置(设置密码等)
+# 3. 编辑配置(生成安全密码)
nano .env
```
**`.env` 必须配置项:**
```bash
-# PostgreSQL 密码(必须修改!)
+# PostgreSQL 密码(必需)
POSTGRES_PASSWORD=your_secure_password_here
+# JWT 密钥(推荐 - 重启后保持用户登录状态)
+JWT_SECRET=your_jwt_secret_here
+
+# TOTP 加密密钥(推荐 - 重启后保留双因素认证)
+TOTP_ENCRYPTION_KEY=your_totp_key_here
+
# 可选:管理员账号
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# 可选:自定义端口
SERVER_PORT=8080
+```
-# 可选:安全配置
-# 启用 URL 白名单验证(false 则跳过白名单检查,仅做基本格式校验)
-SECURITY_URL_ALLOWLIST_ENABLED=false
+**生成安全密钥:**
+```bash
+# 生成 JWT_SECRET
+openssl rand -hex 32
-# 关闭白名单时,是否允许 http:// URL(默认 false,只允许 https://)
-# ⚠️ 警告:允许 HTTP 会暴露 API 密钥(明文传输)
-# 仅建议在以下场景使用:
-# - 开发/测试环境
-# - 内部可信网络
-# - 本地测试服务器(http://localhost)
-# 生产环境:保持 false 或仅使用 HTTPS URL
-SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
+# 生成 TOTP_ENCRYPTION_KEY
+openssl rand -hex 32
-# 是否允许私有 IP 地址用于上游/定价/CRS(内网部署时使用)
-SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
+# 生成 POSTGRES_PASSWORD
+openssl rand -hex 32
```
```bash
+# 4. 创建数据目录(本地版)
+mkdir -p data postgres_data redis_data
+
# 5. 启动所有服务
+# 选项 A:本地目录版(推荐 - 易于迁移)
+docker-compose -f docker-compose.local.yml up -d
+
+# 选项 B:命名卷版(简单设置)
docker-compose up -d
# 6. 查看状态
-docker-compose ps
+docker-compose -f docker-compose.local.yml ps
# 7. 查看日志
-docker-compose logs -f sub2api
+docker-compose -f docker-compose.local.yml logs -f sub2api
```
+#### 部署版本对比
+
+| 版本 | 数据存储 | 迁移便利性 | 适用场景 |
+|------|---------|-----------|---------|
+| **docker-compose.local.yml** | 本地目录 | ✅ 简单(打包整个目录) | 生产环境、频繁备份 |
+| **docker-compose.yml** | 命名卷 | ⚠️ 需要 docker 命令 | 简单设置 |
+
+**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
+
#### 访问
在浏览器中打开 `http://你的服务器IP:8080`
+如果管理员密码是自动生成的,在日志中查找:
+```bash
+docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
+```
+
#### 升级
```bash
# 拉取最新镜像并重建容器
-docker-compose pull
-docker-compose up -d
+docker-compose -f docker-compose.local.yml pull
+docker-compose -f docker-compose.local.yml up -d
+```
+
+#### 轻松迁移(本地目录版)
+
+使用 `docker-compose.local.yml` 时,可以轻松迁移到新服务器:
+
+```bash
+# 源服务器
+docker-compose -f docker-compose.local.yml down
+cd ..
+tar czf sub2api-complete.tar.gz sub2api-deploy/
+
+# 传输到新服务器
+scp sub2api-complete.tar.gz user@new-server:/path/
+
+# 新服务器
+tar xzf sub2api-complete.tar.gz
+cd sub2api-deploy/
+docker-compose -f docker-compose.local.yml up -d
```
#### 常用命令
```bash
# 停止所有服务
-docker-compose down
+docker-compose -f docker-compose.local.yml down
# 重启
-docker-compose restart
+docker-compose -f docker-compose.local.yml restart
# 查看所有日志
-docker-compose logs -f
+docker-compose -f docker-compose.local.yml logs -f
+
+# 删除所有数据(谨慎!)
+docker-compose -f docker-compose.local.yml down
+rm -rf data/ postgres_data/ redis_data/
```
---
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go
index c461198b..139a3a39 100644
--- a/backend/cmd/jwtgen/main.go
+++ b/backend/cmd/jwtgen/main.go
@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
- authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 79e0dd8a..a2d633db 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.46
+0.1.61
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 1e9e440e..c55ea844 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -71,6 +71,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -145,6 +146,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
+ {"SubscriptionExpiryService", func() error {
+ subscriptionExpiry.Stop()
+ return nil
+ }},
{"PricingService", func() error {
pricing.Stop()
return nil
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index dd0eb0d9..6422ea20 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
userRepository := repository.NewUserRepository(client, db)
+ redeemCodeRepository := repository.NewRedeemCodeRepository(client)
settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig)
@@ -61,20 +62,29 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
- authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
+ authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
- authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
+ subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
+ redeemCache := repository.NewRedeemCache(redisClient)
+ redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
+ secretEncryptor, err := repository.NewAESEncryptor(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ totpCache := repository.NewTotpCache(redisClient)
+ totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
+ authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
- redeemCodeRepository := repository.NewRedeemCodeRepository(client)
- subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
- redeemCache := repository.NewRedeemCache(redisClient)
- redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
+ announcementRepository := repository.NewAnnouncementRepository(client)
+ announcementReadRepository := repository.NewAnnouncementReadRepository(client)
+ announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
+ announcementHandler := handler.NewAnnouncementHandler(announcementService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -123,6 +133,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
+ adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -162,15 +173,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
- gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
+ gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler)
+ totpHandler := handler.NewTotpHandler(totpService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -182,9 +194,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+ subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -218,6 +231,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -291,6 +305,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
+ {"SubscriptionExpiryService", func() error {
+ subscriptionExpiry.Stop()
+ return nil
+ }},
{"PricingService", func() error {
pricing.Stop()
return nil
diff --git a/backend/ent/announcement.go b/backend/ent/announcement.go
new file mode 100644
index 00000000..93d7a375
--- /dev/null
+++ b/backend/ent/announcement.go
@@ -0,0 +1,249 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+)
+
+// Announcement is the model entity for the Announcement schema.
+type Announcement struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // 公告标题
+ Title string `json:"title,omitempty"`
+ // 公告内容(支持 Markdown)
+ Content string `json:"content,omitempty"`
+ // 状态: draft, active, archived
+ Status string `json:"status,omitempty"`
+ // 展示条件(JSON 规则)
+ Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
+ // 开始展示时间(为空表示立即生效)
+ StartsAt *time.Time `json:"starts_at,omitempty"`
+ // 结束展示时间(为空表示永久生效)
+ EndsAt *time.Time `json:"ends_at,omitempty"`
+ // 创建人用户ID(管理员)
+ CreatedBy *int64 `json:"created_by,omitempty"`
+ // 更新人用户ID(管理员)
+ UpdatedBy *int64 `json:"updated_by,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AnnouncementQuery when eager-loading is set.
+ Edges AnnouncementEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AnnouncementEdges holds the relations/edges for other nodes in the graph.
+type AnnouncementEdges struct {
+ // Reads holds the value of the reads edge.
+ Reads []*AnnouncementRead `json:"reads,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// ReadsOrErr returns the Reads value or an error if the edge
+// was not loaded in eager-loading.
+func (e AnnouncementEdges) ReadsOrErr() ([]*AnnouncementRead, error) {
+ if e.loadedTypes[0] {
+ return e.Reads, nil
+ }
+ return nil, &NotLoadedError{edge: "reads"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*Announcement) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case announcement.FieldTargeting:
+ values[i] = new([]byte)
+ case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
+ values[i] = new(sql.NullInt64)
+ case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
+ values[i] = new(sql.NullString)
+ case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the Announcement fields.
+func (_m *Announcement) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case announcement.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case announcement.FieldTitle:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field title", values[i])
+ } else if value.Valid {
+ _m.Title = value.String
+ }
+ case announcement.FieldContent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field content", values[i])
+ } else if value.Valid {
+ _m.Content = value.String
+ }
+ case announcement.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = value.String
+ }
+ case announcement.FieldTargeting:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field targeting", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Targeting); err != nil {
+ return fmt.Errorf("unmarshal field targeting: %w", err)
+ }
+ }
+ case announcement.FieldStartsAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field starts_at", values[i])
+ } else if value.Valid {
+ _m.StartsAt = new(time.Time)
+ *_m.StartsAt = value.Time
+ }
+ case announcement.FieldEndsAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field ends_at", values[i])
+ } else if value.Valid {
+ _m.EndsAt = new(time.Time)
+ *_m.EndsAt = value.Time
+ }
+ case announcement.FieldCreatedBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field created_by", values[i])
+ } else if value.Valid {
+ _m.CreatedBy = new(int64)
+ *_m.CreatedBy = value.Int64
+ }
+ case announcement.FieldUpdatedBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_by", values[i])
+ } else if value.Valid {
+ _m.UpdatedBy = new(int64)
+ *_m.UpdatedBy = value.Int64
+ }
+ case announcement.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case announcement.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the Announcement.
+// This includes values selected through modifiers, order, etc.
+func (_m *Announcement) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryReads queries the "reads" edge of the Announcement entity.
+func (_m *Announcement) QueryReads() *AnnouncementReadQuery {
+ return NewAnnouncementClient(_m.config).QueryReads(_m)
+}
+
+// Update returns a builder for updating this Announcement.
+// Note that you need to call Announcement.Unwrap() before calling this method if this Announcement
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *Announcement) Update() *AnnouncementUpdateOne {
+ return NewAnnouncementClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the Announcement entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *Announcement) Unwrap() *Announcement {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: Announcement is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *Announcement) String() string {
+ var builder strings.Builder
+ builder.WriteString("Announcement(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("title=")
+ builder.WriteString(_m.Title)
+ builder.WriteString(", ")
+ builder.WriteString("content=")
+ builder.WriteString(_m.Content)
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(_m.Status)
+ builder.WriteString(", ")
+ builder.WriteString("targeting=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
+ builder.WriteString(", ")
+ if v := _m.StartsAt; v != nil {
+ builder.WriteString("starts_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EndsAt; v != nil {
+ builder.WriteString("ends_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.CreatedBy; v != nil {
+ builder.WriteString("created_by=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.UpdatedBy; v != nil {
+ builder.WriteString("updated_by=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// Announcements is a parsable slice of Announcement.
+type Announcements []*Announcement
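The generated `String` method above guards every optional (pointer) field with a nil check before formatting it. A minimal, self-contained sketch of that same pattern follows; the `entity` type and its fields are illustrative stand-ins, not the generated types:

```go
package main

import (
	"fmt"
	"strings"
	"time"
)

// entity mimics the shape of a generated ent entity: required values
// plus an optional field represented as a pointer.
type entity struct {
	ID     int64
	Title  string
	EndsAt *time.Time // nil when the column is NULL
}

// String builds the debug representation with strings.Builder,
// skipping the optional field when it is nil, like the generated code does.
func (e entity) String() string {
	var b strings.Builder
	b.WriteString("entity(")
	b.WriteString(fmt.Sprintf("id=%v, ", e.ID))
	b.WriteString("title=")
	b.WriteString(e.Title)
	if v := e.EndsAt; v != nil {
		b.WriteString(", ends_at=")
		b.WriteString(v.Format(time.ANSIC))
	}
	b.WriteByte(')')
	return b.String()
}

func main() {
	// ends_at is omitted from the output because the pointer is nil.
	fmt.Println(entity{ID: 1, Title: "hello"})
}
```

Using `strings.Builder` instead of repeated string concatenation avoids an allocation per append, which matters when `String` is called in hot logging paths.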
diff --git a/backend/ent/announcement/announcement.go b/backend/ent/announcement/announcement.go
new file mode 100644
index 00000000..4f34ee05
--- /dev/null
+++ b/backend/ent/announcement/announcement.go
@@ -0,0 +1,164 @@
+// Code generated by ent, DO NOT EDIT.
+
+package announcement
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the announcement type in the database.
+ Label = "announcement"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldTitle holds the string denoting the title field in the database.
+ FieldTitle = "title"
+ // FieldContent holds the string denoting the content field in the database.
+ FieldContent = "content"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldTargeting holds the string denoting the targeting field in the database.
+ FieldTargeting = "targeting"
+ // FieldStartsAt holds the string denoting the starts_at field in the database.
+ FieldStartsAt = "starts_at"
+ // FieldEndsAt holds the string denoting the ends_at field in the database.
+ FieldEndsAt = "ends_at"
+ // FieldCreatedBy holds the string denoting the created_by field in the database.
+ FieldCreatedBy = "created_by"
+ // FieldUpdatedBy holds the string denoting the updated_by field in the database.
+ FieldUpdatedBy = "updated_by"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // EdgeReads holds the string denoting the reads edge name in mutations.
+ EdgeReads = "reads"
+ // Table holds the table name of the announcement in the database.
+ Table = "announcements"
+ // ReadsTable is the table that holds the reads relation/edge.
+ ReadsTable = "announcement_reads"
+ // ReadsInverseTable is the table name for the AnnouncementRead entity.
+ // It exists in this package in order to avoid circular dependency with the "announcementread" package.
+ ReadsInverseTable = "announcement_reads"
+ // ReadsColumn is the table column denoting the reads relation/edge.
+ ReadsColumn = "announcement_id"
+)
+
+// Columns holds all SQL columns for announcement fields.
+var Columns = []string{
+ FieldID,
+ FieldTitle,
+ FieldContent,
+ FieldStatus,
+ FieldTargeting,
+ FieldStartsAt,
+ FieldEndsAt,
+ FieldCreatedBy,
+ FieldUpdatedBy,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // TitleValidator is a validator for the "title" field. It is called by the builders before save.
+ TitleValidator func(string) error
+ // ContentValidator is a validator for the "content" field. It is called by the builders before save.
+ ContentValidator func(string) error
+ // DefaultStatus holds the default value on creation for the "status" field.
+ DefaultStatus string
+ // StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ StatusValidator func(string) error
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the Announcement queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByTitle orders the results by the title field.
+func ByTitle(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTitle, opts...).ToFunc()
+}
+
+// ByContent orders the results by the content field.
+func ByContent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldContent, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByStartsAt orders the results by the starts_at field.
+func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
+}
+
+// ByEndsAt orders the results by the ends_at field.
+func ByEndsAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEndsAt, opts...).ToFunc()
+}
+
+// ByCreatedBy orders the results by the created_by field.
+func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
+}
+
+// ByUpdatedBy orders the results by the updated_by field.
+func ByUpdatedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedBy, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByReadsCount orders the results by reads count.
+func ByReadsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newReadsStep(), opts...)
+ }
+}
+
+// ByReads orders the results by reads terms.
+func ByReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newReadsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ReadsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
+ )
+}
diff --git a/backend/ent/announcement/where.go b/backend/ent/announcement/where.go
new file mode 100644
index 00000000..d3cad2a5
--- /dev/null
+++ b/backend/ent/announcement/where.go
@@ -0,0 +1,624 @@
+// Code generated by ent, DO NOT EDIT.
+
+package announcement
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldID, id))
+}
+
+// Title applies equality check predicate on the "title" field. It's identical to TitleEQ.
+func Title(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
+}
+
+// Content applies equality check predicate on the "content" field. It's identical to ContentEQ.
+func Content(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldContent, v))
+}
+
+// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
+func Status(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
+}
+
+// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
+func StartsAt(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
+}
+
+// EndsAt applies equality check predicate on the "ends_at" field. It's identical to EndsAtEQ.
+func EndsAt(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
+}
+
+// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
+func CreatedBy(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// UpdatedBy applies equality check predicate on the "updated_by" field. It's identical to UpdatedByEQ.
+func UpdatedBy(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// TitleEQ applies the EQ predicate on the "title" field.
+func TitleEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
+}
+
+// TitleNEQ applies the NEQ predicate on the "title" field.
+func TitleNEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldTitle, v))
+}
+
+// TitleIn applies the In predicate on the "title" field.
+func TitleIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldTitle, vs...))
+}
+
+// TitleNotIn applies the NotIn predicate on the "title" field.
+func TitleNotIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldTitle, vs...))
+}
+
+// TitleGT applies the GT predicate on the "title" field.
+func TitleGT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldTitle, v))
+}
+
+// TitleGTE applies the GTE predicate on the "title" field.
+func TitleGTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldTitle, v))
+}
+
+// TitleLT applies the LT predicate on the "title" field.
+func TitleLT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldTitle, v))
+}
+
+// TitleLTE applies the LTE predicate on the "title" field.
+func TitleLTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldTitle, v))
+}
+
+// TitleContains applies the Contains predicate on the "title" field.
+func TitleContains(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContains(FieldTitle, v))
+}
+
+// TitleHasPrefix applies the HasPrefix predicate on the "title" field.
+func TitleHasPrefix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasPrefix(FieldTitle, v))
+}
+
+// TitleHasSuffix applies the HasSuffix predicate on the "title" field.
+func TitleHasSuffix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasSuffix(FieldTitle, v))
+}
+
+// TitleEqualFold applies the EqualFold predicate on the "title" field.
+func TitleEqualFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEqualFold(FieldTitle, v))
+}
+
+// TitleContainsFold applies the ContainsFold predicate on the "title" field.
+func TitleContainsFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContainsFold(FieldTitle, v))
+}
+
+// ContentEQ applies the EQ predicate on the "content" field.
+func ContentEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldContent, v))
+}
+
+// ContentNEQ applies the NEQ predicate on the "content" field.
+func ContentNEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldContent, v))
+}
+
+// ContentIn applies the In predicate on the "content" field.
+func ContentIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldContent, vs...))
+}
+
+// ContentNotIn applies the NotIn predicate on the "content" field.
+func ContentNotIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldContent, vs...))
+}
+
+// ContentGT applies the GT predicate on the "content" field.
+func ContentGT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldContent, v))
+}
+
+// ContentGTE applies the GTE predicate on the "content" field.
+func ContentGTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldContent, v))
+}
+
+// ContentLT applies the LT predicate on the "content" field.
+func ContentLT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldContent, v))
+}
+
+// ContentLTE applies the LTE predicate on the "content" field.
+func ContentLTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldContent, v))
+}
+
+// ContentContains applies the Contains predicate on the "content" field.
+func ContentContains(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContains(FieldContent, v))
+}
+
+// ContentHasPrefix applies the HasPrefix predicate on the "content" field.
+func ContentHasPrefix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasPrefix(FieldContent, v))
+}
+
+// ContentHasSuffix applies the HasSuffix predicate on the "content" field.
+func ContentHasSuffix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasSuffix(FieldContent, v))
+}
+
+// ContentEqualFold applies the EqualFold predicate on the "content" field.
+func ContentEqualFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEqualFold(FieldContent, v))
+}
+
+// ContentContainsFold applies the ContainsFold predicate on the "content" field.
+func ContentContainsFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContainsFold(FieldContent, v))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// StatusGT applies the GT predicate on the "status" field.
+func StatusGT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldStatus, v))
+}
+
+// StatusGTE applies the GTE predicate on the "status" field.
+func StatusGTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldStatus, v))
+}
+
+// StatusLT applies the LT predicate on the "status" field.
+func StatusLT(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldStatus, v))
+}
+
+// StatusLTE applies the LTE predicate on the "status" field.
+func StatusLTE(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldStatus, v))
+}
+
+// StatusContains applies the Contains predicate on the "status" field.
+func StatusContains(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContains(FieldStatus, v))
+}
+
+// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
+func StatusHasPrefix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasPrefix(FieldStatus, v))
+}
+
+// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
+func StatusHasSuffix(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldHasSuffix(FieldStatus, v))
+}
+
+// StatusEqualFold applies the EqualFold predicate on the "status" field.
+func StatusEqualFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEqualFold(FieldStatus, v))
+}
+
+// StatusContainsFold applies the ContainsFold predicate on the "status" field.
+func StatusContainsFold(v string) predicate.Announcement {
+ return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
+}
+
+// TargetingIsNil applies the IsNil predicate on the "targeting" field.
+func TargetingIsNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
+}
+
+// TargetingNotNil applies the NotNil predicate on the "targeting" field.
+func TargetingNotNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotNull(FieldTargeting))
+}
+
+// StartsAtEQ applies the EQ predicate on the "starts_at" field.
+func StartsAtEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
+}
+
+// StartsAtNEQ applies the NEQ predicate on the "starts_at" field.
+func StartsAtNEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldStartsAt, v))
+}
+
+// StartsAtIn applies the In predicate on the "starts_at" field.
+func StartsAtIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldStartsAt, vs...))
+}
+
+// StartsAtNotIn applies the NotIn predicate on the "starts_at" field.
+func StartsAtNotIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldStartsAt, vs...))
+}
+
+// StartsAtGT applies the GT predicate on the "starts_at" field.
+func StartsAtGT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldStartsAt, v))
+}
+
+// StartsAtGTE applies the GTE predicate on the "starts_at" field.
+func StartsAtGTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldStartsAt, v))
+}
+
+// StartsAtLT applies the LT predicate on the "starts_at" field.
+func StartsAtLT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldStartsAt, v))
+}
+
+// StartsAtLTE applies the LTE predicate on the "starts_at" field.
+func StartsAtLTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldStartsAt, v))
+}
+
+// StartsAtIsNil applies the IsNil predicate on the "starts_at" field.
+func StartsAtIsNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldIsNull(FieldStartsAt))
+}
+
+// StartsAtNotNil applies the NotNil predicate on the "starts_at" field.
+func StartsAtNotNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotNull(FieldStartsAt))
+}
+
+// EndsAtEQ applies the EQ predicate on the "ends_at" field.
+func EndsAtEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
+}
+
+// EndsAtNEQ applies the NEQ predicate on the "ends_at" field.
+func EndsAtNEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldEndsAt, v))
+}
+
+// EndsAtIn applies the In predicate on the "ends_at" field.
+func EndsAtIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldEndsAt, vs...))
+}
+
+// EndsAtNotIn applies the NotIn predicate on the "ends_at" field.
+func EndsAtNotIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldEndsAt, vs...))
+}
+
+// EndsAtGT applies the GT predicate on the "ends_at" field.
+func EndsAtGT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldEndsAt, v))
+}
+
+// EndsAtGTE applies the GTE predicate on the "ends_at" field.
+func EndsAtGTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldEndsAt, v))
+}
+
+// EndsAtLT applies the LT predicate on the "ends_at" field.
+func EndsAtLT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldEndsAt, v))
+}
+
+// EndsAtLTE applies the LTE predicate on the "ends_at" field.
+func EndsAtLTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldEndsAt, v))
+}
+
+// EndsAtIsNil applies the IsNil predicate on the "ends_at" field.
+func EndsAtIsNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldIsNull(FieldEndsAt))
+}
+
+// EndsAtNotNil applies the NotNil predicate on the "ends_at" field.
+func EndsAtNotNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotNull(FieldEndsAt))
+}
+
+// CreatedByEQ applies the EQ predicate on the "created_by" field.
+func CreatedByEQ(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
+func CreatedByNEQ(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldCreatedBy, v))
+}
+
+// CreatedByIn applies the In predicate on the "created_by" field.
+func CreatedByIn(vs ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
+func CreatedByNotIn(vs ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByGT applies the GT predicate on the "created_by" field.
+func CreatedByGT(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldCreatedBy, v))
+}
+
+// CreatedByGTE applies the GTE predicate on the "created_by" field.
+func CreatedByGTE(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldCreatedBy, v))
+}
+
+// CreatedByLT applies the LT predicate on the "created_by" field.
+func CreatedByLT(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldCreatedBy, v))
+}
+
+// CreatedByLTE applies the LTE predicate on the "created_by" field.
+func CreatedByLTE(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldCreatedBy, v))
+}
+
+// CreatedByIsNil applies the IsNil predicate on the "created_by" field.
+func CreatedByIsNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldIsNull(FieldCreatedBy))
+}
+
+// CreatedByNotNil applies the NotNil predicate on the "created_by" field.
+func CreatedByNotNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotNull(FieldCreatedBy))
+}
+
+// UpdatedByEQ applies the EQ predicate on the "updated_by" field.
+func UpdatedByEQ(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
+}
+
+// UpdatedByNEQ applies the NEQ predicate on the "updated_by" field.
+func UpdatedByNEQ(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldUpdatedBy, v))
+}
+
+// UpdatedByIn applies the In predicate on the "updated_by" field.
+func UpdatedByIn(vs ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldUpdatedBy, vs...))
+}
+
+// UpdatedByNotIn applies the NotIn predicate on the "updated_by" field.
+func UpdatedByNotIn(vs ...int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldUpdatedBy, vs...))
+}
+
+// UpdatedByGT applies the GT predicate on the "updated_by" field.
+func UpdatedByGT(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldUpdatedBy, v))
+}
+
+// UpdatedByGTE applies the GTE predicate on the "updated_by" field.
+func UpdatedByGTE(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldUpdatedBy, v))
+}
+
+// UpdatedByLT applies the LT predicate on the "updated_by" field.
+func UpdatedByLT(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldUpdatedBy, v))
+}
+
+// UpdatedByLTE applies the LTE predicate on the "updated_by" field.
+func UpdatedByLTE(v int64) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldUpdatedBy, v))
+}
+
+// UpdatedByIsNil applies the IsNil predicate on the "updated_by" field.
+func UpdatedByIsNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldIsNull(FieldUpdatedBy))
+}
+
+// UpdatedByNotNil applies the NotNil predicate on the "updated_by" field.
+func UpdatedByNotNil() predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotNull(FieldUpdatedBy))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.Announcement {
+ return predicate.Announcement(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// HasReads applies the HasEdge predicate on the "reads" edge.
+func HasReads() predicate.Announcement {
+ return predicate.Announcement(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasReadsWith applies the HasEdge predicate on the "reads" edge with the given conditions (other predicates).
+func HasReadsWith(preds ...predicate.AnnouncementRead) predicate.Announcement {
+ return predicate.Announcement(func(s *sql.Selector) {
+ step := newReadsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.Announcement) predicate.Announcement {
+ return predicate.Announcement(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.Announcement) predicate.Announcement {
+ return predicate.Announcement(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.Announcement) predicate.Announcement {
+ return predicate.Announcement(sql.NotPredicates(p))
+}
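Every predicate in where.go is a function over a `*sql.Selector`, and `And`/`Or`/`Not` compose those functions rather than building an AST. A self-contained sketch of that composition shape; the predicates here evaluate rows in memory instead of emitting SQL, purely to illustrate the pattern, and the `row` type is hypothetical:

```go
package main

import (
	"fmt"
	"strings"
)

// row is a stand-in for a database row; real ent predicates mutate a
// SQL selector instead of evaluating values directly.
type row struct {
	Status string
	Title  string
}

// Predicate mirrors predicate.Announcement: a composable condition.
type Predicate func(row) bool

func StatusEQ(v string) Predicate {
	return func(r row) bool { return r.Status == v }
}

func TitleContains(v string) Predicate {
	return func(r row) bool { return strings.Contains(r.Title, v) }
}

// And short-circuits on the first failing predicate, mirroring AND semantics.
func And(ps ...Predicate) Predicate {
	return func(r row) bool {
		for _, p := range ps {
			if !p(r) {
				return false
			}
		}
		return true
	}
}

// Not inverts a predicate, mirroring the generated Not helper.
func Not(p Predicate) Predicate {
	return func(r row) bool { return !p(r) }
}

func main() {
	p := And(StatusEQ("published"), Not(TitleContains("draft")))
	fmt.Println(p(row{Status: "published", Title: "Release notes"}))
}
```

The payoff of this encoding is that caller-supplied predicates (as in `HasReadsWith(preds ...)`) can be applied inside a nested selector simply by invoking them, with no type switching.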
diff --git a/backend/ent/announcement_create.go b/backend/ent/announcement_create.go
new file mode 100644
index 00000000..151d4c11
--- /dev/null
+++ b/backend/ent/announcement_create.go
@@ -0,0 +1,1159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+)
+
+// AnnouncementCreate is the builder for creating an Announcement entity.
+type AnnouncementCreate struct {
+ config
+ mutation *AnnouncementMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetTitle sets the "title" field.
+func (_c *AnnouncementCreate) SetTitle(v string) *AnnouncementCreate {
+ _c.mutation.SetTitle(v)
+ return _c
+}
+
+// SetContent sets the "content" field.
+func (_c *AnnouncementCreate) SetContent(v string) *AnnouncementCreate {
+ _c.mutation.SetContent(v)
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *AnnouncementCreate) SetStatus(v string) *AnnouncementCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
+ if v != nil {
+ _c.SetStatus(*v)
+ }
+ return _c
+}
+
+// SetTargeting sets the "targeting" field.
+func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
+ _c.mutation.SetTargeting(v)
+ return _c
+}
+
+// SetNillableTargeting sets the "targeting" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementCreate {
+ if v != nil {
+ _c.SetTargeting(*v)
+ }
+ return _c
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (_c *AnnouncementCreate) SetStartsAt(v time.Time) *AnnouncementCreate {
+ _c.mutation.SetStartsAt(v)
+ return _c
+}
+
+// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableStartsAt(v *time.Time) *AnnouncementCreate {
+ if v != nil {
+ _c.SetStartsAt(*v)
+ }
+ return _c
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (_c *AnnouncementCreate) SetEndsAt(v time.Time) *AnnouncementCreate {
+ _c.mutation.SetEndsAt(v)
+ return _c
+}
+
+// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableEndsAt(v *time.Time) *AnnouncementCreate {
+ if v != nil {
+ _c.SetEndsAt(*v)
+ }
+ return _c
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_c *AnnouncementCreate) SetCreatedBy(v int64) *AnnouncementCreate {
+ _c.mutation.SetCreatedBy(v)
+ return _c
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableCreatedBy(v *int64) *AnnouncementCreate {
+ if v != nil {
+ _c.SetCreatedBy(*v)
+ }
+ return _c
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (_c *AnnouncementCreate) SetUpdatedBy(v int64) *AnnouncementCreate {
+ _c.mutation.SetUpdatedBy(v)
+ return _c
+}
+
+// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableUpdatedBy(v *int64) *AnnouncementCreate {
+ if v != nil {
+ _c.SetUpdatedBy(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AnnouncementCreate) SetCreatedAt(v time.Time) *AnnouncementCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableCreatedAt(v *time.Time) *AnnouncementCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AnnouncementCreate) SetUpdatedAt(v time.Time) *AnnouncementCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AnnouncementCreate) SetNillableUpdatedAt(v *time.Time) *AnnouncementCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
+func (_c *AnnouncementCreate) AddReadIDs(ids ...int64) *AnnouncementCreate {
+ _c.mutation.AddReadIDs(ids...)
+ return _c
+}
+
+// AddReads adds the "reads" edges to the AnnouncementRead entity.
+func (_c *AnnouncementCreate) AddReads(v ...*AnnouncementRead) *AnnouncementCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddReadIDs(ids...)
+}
+
+// Mutation returns the AnnouncementMutation object of the builder.
+func (_c *AnnouncementCreate) Mutation() *AnnouncementMutation {
+ return _c.mutation
+}
+
+// Save creates the Announcement in the database.
+func (_c *AnnouncementCreate) Save(ctx context.Context) (*Announcement, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AnnouncementCreate) SaveX(ctx context.Context) *Announcement {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AnnouncementCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AnnouncementCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AnnouncementCreate) defaults() {
+ if _, ok := _c.mutation.Status(); !ok {
+ v := announcement.DefaultStatus
+ _c.mutation.SetStatus(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := announcement.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := announcement.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AnnouncementCreate) check() error {
+ if _, ok := _c.mutation.Title(); !ok {
+ return &ValidationError{Name: "title", err: errors.New(`ent: missing required field "Announcement.title"`)}
+ }
+ if v, ok := _c.mutation.Title(); ok {
+ if err := announcement.TitleValidator(v); err != nil {
+ return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Content(); !ok {
+ return &ValidationError{Name: "content", err: errors.New(`ent: missing required field "Announcement.content"`)}
+ }
+ if v, ok := _c.mutation.Content(); ok {
+ if err := announcement.ContentValidator(v); err != nil {
+ return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Announcement.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := announcement.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Announcement.updated_at"`)}
+ }
+ return nil
+}
+
+func (_c *AnnouncementCreate) sqlSave(ctx context.Context) (*Announcement, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec) {
+ var (
+ _node = &Announcement{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(announcement.Table, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.Title(); ok {
+ _spec.SetField(announcement.FieldTitle, field.TypeString, value)
+ _node.Title = value
+ }
+ if value, ok := _c.mutation.Content(); ok {
+ _spec.SetField(announcement.FieldContent, field.TypeString, value)
+ _node.Content = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(announcement.FieldStatus, field.TypeString, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.Targeting(); ok {
+ _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
+ _node.Targeting = value
+ }
+ if value, ok := _c.mutation.StartsAt(); ok {
+ _spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
+ _node.StartsAt = &value
+ }
+ if value, ok := _c.mutation.EndsAt(); ok {
+ _spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
+ _node.EndsAt = &value
+ }
+ if value, ok := _c.mutation.CreatedBy(); ok {
+ _spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
+ _node.CreatedBy = &value
+ }
+ if value, ok := _c.mutation.UpdatedBy(); ok {
+ _spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
+ _node.UpdatedBy = &value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(announcement.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if nodes := _c.mutation.ReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.Announcement.Create().
+// SetTitle(v).
+// OnConflict(
+//	        // Update the row with the new values
+//	        // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AnnouncementUpsert) {
+//	        u.SetTitle(v + v)
+// }).
+// Exec(ctx)
+func (_c *AnnouncementCreate) OnConflict(opts ...sql.ConflictOption) *AnnouncementUpsertOne {
+ _c.conflict = opts
+ return &AnnouncementUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.Announcement.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AnnouncementCreate) OnConflictColumns(columns ...string) *AnnouncementUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AnnouncementUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AnnouncementUpsertOne is the builder for "upsert"-ing
+ // one Announcement node.
+ AnnouncementUpsertOne struct {
+ create *AnnouncementCreate
+ }
+
+ // AnnouncementUpsert is the "OnConflict" setter.
+ AnnouncementUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetTitle sets the "title" field.
+func (u *AnnouncementUpsert) SetTitle(v string) *AnnouncementUpsert {
+ u.Set(announcement.FieldTitle, v)
+ return u
+}
+
+// UpdateTitle sets the "title" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateTitle() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldTitle)
+ return u
+}
+
+// SetContent sets the "content" field.
+func (u *AnnouncementUpsert) SetContent(v string) *AnnouncementUpsert {
+ u.Set(announcement.FieldContent, v)
+ return u
+}
+
+// UpdateContent sets the "content" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateContent() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldContent)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *AnnouncementUpsert) SetStatus(v string) *AnnouncementUpsert {
+ u.Set(announcement.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldStatus)
+ return u
+}
+
+// SetTargeting sets the "targeting" field.
+func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
+ u.Set(announcement.FieldTargeting, v)
+ return u
+}
+
+// UpdateTargeting sets the "targeting" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateTargeting() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldTargeting)
+ return u
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (u *AnnouncementUpsert) ClearTargeting() *AnnouncementUpsert {
+ u.SetNull(announcement.FieldTargeting)
+ return u
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (u *AnnouncementUpsert) SetStartsAt(v time.Time) *AnnouncementUpsert {
+ u.Set(announcement.FieldStartsAt, v)
+ return u
+}
+
+// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateStartsAt() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldStartsAt)
+ return u
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (u *AnnouncementUpsert) ClearStartsAt() *AnnouncementUpsert {
+ u.SetNull(announcement.FieldStartsAt)
+ return u
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (u *AnnouncementUpsert) SetEndsAt(v time.Time) *AnnouncementUpsert {
+ u.Set(announcement.FieldEndsAt, v)
+ return u
+}
+
+// UpdateEndsAt sets the "ends_at" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateEndsAt() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldEndsAt)
+ return u
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (u *AnnouncementUpsert) ClearEndsAt() *AnnouncementUpsert {
+ u.SetNull(announcement.FieldEndsAt)
+ return u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *AnnouncementUpsert) SetCreatedBy(v int64) *AnnouncementUpsert {
+ u.Set(announcement.FieldCreatedBy, v)
+ return u
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateCreatedBy() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldCreatedBy)
+ return u
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *AnnouncementUpsert) AddCreatedBy(v int64) *AnnouncementUpsert {
+ u.Add(announcement.FieldCreatedBy, v)
+ return u
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (u *AnnouncementUpsert) ClearCreatedBy() *AnnouncementUpsert {
+ u.SetNull(announcement.FieldCreatedBy)
+ return u
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (u *AnnouncementUpsert) SetUpdatedBy(v int64) *AnnouncementUpsert {
+ u.Set(announcement.FieldUpdatedBy, v)
+ return u
+}
+
+// UpdateUpdatedBy sets the "updated_by" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateUpdatedBy() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldUpdatedBy)
+ return u
+}
+
+// AddUpdatedBy adds v to the "updated_by" field.
+func (u *AnnouncementUpsert) AddUpdatedBy(v int64) *AnnouncementUpsert {
+ u.Add(announcement.FieldUpdatedBy, v)
+ return u
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (u *AnnouncementUpsert) ClearUpdatedBy() *AnnouncementUpsert {
+ u.SetNull(announcement.FieldUpdatedBy)
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AnnouncementUpsert) SetUpdatedAt(v time.Time) *AnnouncementUpsert {
+ u.Set(announcement.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AnnouncementUpsert) UpdateUpdatedAt() *AnnouncementUpsert {
+ u.SetExcluded(announcement.FieldUpdatedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.Announcement.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AnnouncementUpsertOne) UpdateNewValues() *AnnouncementUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(announcement.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.Announcement.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AnnouncementUpsertOne) Ignore() *AnnouncementUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AnnouncementUpsertOne) DoNothing() *AnnouncementUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AnnouncementCreate.OnConflict
+// documentation for more info.
+func (u *AnnouncementUpsertOne) Update(set func(*AnnouncementUpsert)) *AnnouncementUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AnnouncementUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetTitle sets the "title" field.
+func (u *AnnouncementUpsertOne) SetTitle(v string) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetTitle(v)
+ })
+}
+
+// UpdateTitle sets the "title" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateTitle() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateTitle()
+ })
+}
+
+// SetContent sets the "content" field.
+func (u *AnnouncementUpsertOne) SetContent(v string) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetContent(v)
+ })
+}
+
+// UpdateContent sets the "content" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateContent() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateContent()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *AnnouncementUpsertOne) SetStatus(v string) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetTargeting sets the "targeting" field.
+func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetTargeting(v)
+ })
+}
+
+// UpdateTargeting sets the "targeting" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateTargeting() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateTargeting()
+ })
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (u *AnnouncementUpsertOne) ClearTargeting() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearTargeting()
+ })
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (u *AnnouncementUpsertOne) SetStartsAt(v time.Time) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetStartsAt(v)
+ })
+}
+
+// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateStartsAt() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateStartsAt()
+ })
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (u *AnnouncementUpsertOne) ClearStartsAt() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearStartsAt()
+ })
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (u *AnnouncementUpsertOne) SetEndsAt(v time.Time) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetEndsAt(v)
+ })
+}
+
+// UpdateEndsAt sets the "ends_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateEndsAt() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateEndsAt()
+ })
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (u *AnnouncementUpsertOne) ClearEndsAt() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearEndsAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *AnnouncementUpsertOne) SetCreatedBy(v int64) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *AnnouncementUpsertOne) AddCreatedBy(v int64) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateCreatedBy() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (u *AnnouncementUpsertOne) ClearCreatedBy() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearCreatedBy()
+ })
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (u *AnnouncementUpsertOne) SetUpdatedBy(v int64) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetUpdatedBy(v)
+ })
+}
+
+// AddUpdatedBy adds v to the "updated_by" field.
+func (u *AnnouncementUpsertOne) AddUpdatedBy(v int64) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.AddUpdatedBy(v)
+ })
+}
+
+// UpdateUpdatedBy sets the "updated_by" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateUpdatedBy() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateUpdatedBy()
+ })
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (u *AnnouncementUpsertOne) ClearUpdatedBy() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearUpdatedBy()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AnnouncementUpsertOne) SetUpdatedAt(v time.Time) *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertOne) UpdateUpdatedAt() *AnnouncementUpsertOne {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *AnnouncementUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AnnouncementCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AnnouncementUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *AnnouncementUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AnnouncementUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AnnouncementCreateBulk is the builder for creating many Announcement entities in bulk.
+type AnnouncementCreateBulk struct {
+ config
+ err error
+ builders []*AnnouncementCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the Announcement entities in the database.
+func (_c *AnnouncementCreateBulk) Save(ctx context.Context) ([]*Announcement, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*Announcement, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AnnouncementMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AnnouncementCreateBulk) SaveX(ctx context.Context) []*Announcement {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AnnouncementCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AnnouncementCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.Announcement.CreateBulk(builders...).
+// OnConflict(
+//	        // Update the row with the new values
+//	        // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AnnouncementUpsert) {
+//	        u.SetTitle(v + v)
+// }).
+// Exec(ctx)
+func (_c *AnnouncementCreateBulk) OnConflict(opts ...sql.ConflictOption) *AnnouncementUpsertBulk {
+ _c.conflict = opts
+ return &AnnouncementUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+//	client.Announcement.CreateBulk(builders...).
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AnnouncementCreateBulk) OnConflictColumns(columns ...string) *AnnouncementUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AnnouncementUpsertBulk{
+ create: _c,
+ }
+}
+
+// AnnouncementUpsertBulk is the builder for "upsert"-ing
+// a bulk of Announcement nodes.
+type AnnouncementUpsertBulk struct {
+ create *AnnouncementCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+//	client.Announcement.CreateBulk(builders...).
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AnnouncementUpsertBulk) UpdateNewValues() *AnnouncementUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(announcement.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+//	client.Announcement.CreateBulk(builders...).
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AnnouncementUpsertBulk) Ignore() *AnnouncementUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AnnouncementUpsertBulk) DoNothing() *AnnouncementUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AnnouncementCreateBulk.OnConflict
+// documentation for more info.
+func (u *AnnouncementUpsertBulk) Update(set func(*AnnouncementUpsert)) *AnnouncementUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AnnouncementUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetTitle sets the "title" field.
+func (u *AnnouncementUpsertBulk) SetTitle(v string) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetTitle(v)
+ })
+}
+
+// UpdateTitle sets the "title" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateTitle() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateTitle()
+ })
+}
+
+// SetContent sets the "content" field.
+func (u *AnnouncementUpsertBulk) SetContent(v string) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetContent(v)
+ })
+}
+
+// UpdateContent sets the "content" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateContent() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateContent()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *AnnouncementUpsertBulk) SetStatus(v string) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetTargeting sets the "targeting" field.
+func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetTargeting(v)
+ })
+}
+
+// UpdateTargeting sets the "targeting" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateTargeting() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateTargeting()
+ })
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (u *AnnouncementUpsertBulk) ClearTargeting() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearTargeting()
+ })
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (u *AnnouncementUpsertBulk) SetStartsAt(v time.Time) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetStartsAt(v)
+ })
+}
+
+// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateStartsAt() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateStartsAt()
+ })
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (u *AnnouncementUpsertBulk) ClearStartsAt() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearStartsAt()
+ })
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (u *AnnouncementUpsertBulk) SetEndsAt(v time.Time) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetEndsAt(v)
+ })
+}
+
+// UpdateEndsAt sets the "ends_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateEndsAt() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateEndsAt()
+ })
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (u *AnnouncementUpsertBulk) ClearEndsAt() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearEndsAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *AnnouncementUpsertBulk) SetCreatedBy(v int64) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *AnnouncementUpsertBulk) AddCreatedBy(v int64) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateCreatedBy() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (u *AnnouncementUpsertBulk) ClearCreatedBy() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearCreatedBy()
+ })
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (u *AnnouncementUpsertBulk) SetUpdatedBy(v int64) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetUpdatedBy(v)
+ })
+}
+
+// AddUpdatedBy adds v to the "updated_by" field.
+func (u *AnnouncementUpsertBulk) AddUpdatedBy(v int64) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.AddUpdatedBy(v)
+ })
+}
+
+// UpdateUpdatedBy sets the "updated_by" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateUpdatedBy() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateUpdatedBy()
+ })
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (u *AnnouncementUpsertBulk) ClearUpdatedBy() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.ClearUpdatedBy()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AnnouncementUpsertBulk) SetUpdatedAt(v time.Time) *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AnnouncementUpsertBulk) UpdateUpdatedAt() *AnnouncementUpsertBulk {
+ return u.Update(func(s *AnnouncementUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *AnnouncementUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AnnouncementCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AnnouncementCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AnnouncementUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
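Every typed setter on `AnnouncementUpsertBulk` above is a one-line wrapper that funnels through `Update`, queuing a closure that is later replayed against the upsert to build the `ON CONFLICT DO UPDATE` clause. A minimal self-contained sketch of that delegation pattern, using hypothetical `upsert`/`upsertBulk` types (not the generated ones):

```go
package main

import "fmt"

// upsert is a stand-in for AnnouncementUpsert: it accumulates the
// column assignments that end up in the DO UPDATE clause.
type upsert struct {
	set map[string]string
}

func (u *upsert) SetTitle(v string) { u.set["title"] = v }

// upsertBulk is a stand-in for AnnouncementUpsertBulk. Each typed
// setter wraps Update, which queues a closure instead of mutating
// anything immediately.
type upsertBulk struct {
	updates []func(*upsert)
}

func (b *upsertBulk) Update(fn func(*upsert)) *upsertBulk {
	b.updates = append(b.updates, fn)
	return b
}

func (b *upsertBulk) SetTitle(v string) *upsertBulk {
	return b.Update(func(u *upsert) { u.SetTitle(v) })
}

// build replays the queued closures, in order, against a fresh upsert.
func (b *upsertBulk) build() *upsert {
	u := &upsert{set: map[string]string{}}
	for _, fn := range b.updates {
		fn(u)
	}
	return u
}

func main() {
	b := (&upsertBulk{}).SetTitle("maintenance window")
	fmt.Println(b.build().set["title"])
}
```

Deferring the work into closures is what lets the generated bulk builder expose the same fluent setter surface as the single-row builder while staying a thin shell over it.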
diff --git a/backend/ent/announcement_delete.go b/backend/ent/announcement_delete.go
new file mode 100644
index 00000000..d185e9f7
--- /dev/null
+++ b/backend/ent/announcement_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AnnouncementDelete is the builder for deleting an Announcement entity.
+type AnnouncementDelete struct {
+ config
+ hooks []Hook
+ mutation *AnnouncementMutation
+}
+
+// Where appends a list of predicates to the AnnouncementDelete builder.
+func (_d *AnnouncementDelete) Where(ps ...predicate.Announcement) *AnnouncementDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AnnouncementDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AnnouncementDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AnnouncementDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(announcement.Table, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AnnouncementDeleteOne is the builder for deleting a single Announcement entity.
+type AnnouncementDeleteOne struct {
+ _d *AnnouncementDelete
+}
+
+// Where appends a list of predicates to the AnnouncementDelete builder.
+func (_d *AnnouncementDeleteOne) Where(ps ...predicate.Announcement) *AnnouncementDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AnnouncementDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{announcement.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AnnouncementDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/announcement_query.go b/backend/ent/announcement_query.go
new file mode 100644
index 00000000..a27d50fa
--- /dev/null
+++ b/backend/ent/announcement_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AnnouncementQuery is the builder for querying Announcement entities.
+type AnnouncementQuery struct {
+ config
+ ctx *QueryContext
+ order []announcement.OrderOption
+ inters []Interceptor
+ predicates []predicate.Announcement
+ withReads *AnnouncementReadQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AnnouncementQuery builder.
+func (_q *AnnouncementQuery) Where(ps ...predicate.Announcement) *AnnouncementQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AnnouncementQuery) Limit(limit int) *AnnouncementQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AnnouncementQuery) Offset(offset int) *AnnouncementQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AnnouncementQuery) Unique(unique bool) *AnnouncementQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AnnouncementQuery) Order(o ...announcement.OrderOption) *AnnouncementQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryReads chains the current query on the "reads" edge.
+func (_q *AnnouncementQuery) QueryReads() *AnnouncementReadQuery {
+ query := (&AnnouncementReadClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcement.Table, announcement.FieldID, selector),
+ sqlgraph.To(announcementread.Table, announcementread.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first Announcement entity from the query.
+// Returns a *NotFoundError when no Announcement was found.
+func (_q *AnnouncementQuery) First(ctx context.Context) (*Announcement, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{announcement.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AnnouncementQuery) FirstX(ctx context.Context) *Announcement {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first Announcement ID from the query.
+// Returns a *NotFoundError when no Announcement ID was found.
+func (_q *AnnouncementQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{announcement.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AnnouncementQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single Announcement entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one Announcement entity is found.
+// Returns a *NotFoundError when no Announcement entities are found.
+func (_q *AnnouncementQuery) Only(ctx context.Context) (*Announcement, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{announcement.Label}
+ default:
+ return nil, &NotSingularError{announcement.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AnnouncementQuery) OnlyX(ctx context.Context) *Announcement {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only Announcement ID in the query.
+// Returns a *NotSingularError when more than one Announcement ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AnnouncementQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{announcement.Label}
+ default:
+ err = &NotSingularError{announcement.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AnnouncementQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of Announcements.
+func (_q *AnnouncementQuery) All(ctx context.Context) ([]*Announcement, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*Announcement, *AnnouncementQuery]()
+ return withInterceptors[[]*Announcement](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AnnouncementQuery) AllX(ctx context.Context) []*Announcement {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of Announcement IDs.
+func (_q *AnnouncementQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(announcement.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AnnouncementQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AnnouncementQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AnnouncementQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AnnouncementQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AnnouncementQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AnnouncementQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AnnouncementQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AnnouncementQuery) Clone() *AnnouncementQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AnnouncementQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]announcement.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.Announcement{}, _q.predicates...),
+ withReads: _q.withReads.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithReads tells the query-builder to eager-load the nodes that are connected to
+// the "reads" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AnnouncementQuery) WithReads(opts ...func(*AnnouncementReadQuery)) *AnnouncementQuery {
+ query := (&AnnouncementReadClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withReads = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// Title string `json:"title,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.Announcement.Query().
+// GroupBy(announcement.FieldTitle).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AnnouncementQuery) GroupBy(field string, fields ...string) *AnnouncementGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AnnouncementGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = announcement.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection of one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// Title string `json:"title,omitempty"`
+// }
+//
+// client.Announcement.Query().
+// Select(announcement.FieldTitle).
+// Scan(ctx, &v)
+func (_q *AnnouncementQuery) Select(fields ...string) *AnnouncementSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AnnouncementSelect{AnnouncementQuery: _q}
+ sbuild.label = announcement.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns an AnnouncementSelect configured with the given aggregations.
+func (_q *AnnouncementQuery) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AnnouncementQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !announcement.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AnnouncementQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Announcement, error) {
+ var (
+ nodes = []*Announcement{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withReads != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*Announcement).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &Announcement{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withReads; query != nil {
+ if err := _q.loadReads(ctx, query, nodes,
+ func(n *Announcement) { n.Edges.Reads = []*AnnouncementRead{} },
+ func(n *Announcement, e *AnnouncementRead) { n.Edges.Reads = append(n.Edges.Reads, e) }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AnnouncementQuery) loadReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*Announcement, init func(*Announcement), assign func(*Announcement, *AnnouncementRead)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*Announcement)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(announcementread.FieldAnnouncementID)
+ }
+ query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(announcement.ReadsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.AnnouncementID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "announcement_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AnnouncementQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AnnouncementQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
+ for i := range fields {
+ if fields[i] != announcement.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AnnouncementQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(announcement.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = announcement.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+		// a LIMIT is mandatory for an OFFSET clause. We start
+		// with a default value and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevents them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AnnouncementQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AnnouncementQuery) ForShare(opts ...sql.LockOption) *AnnouncementQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AnnouncementGroupBy is the group-by builder for Announcement entities.
+type AnnouncementGroupBy struct {
+ selector
+ build *AnnouncementQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AnnouncementGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AnnouncementGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AnnouncementQuery, *AnnouncementGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AnnouncementGroupBy) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AnnouncementSelect is the builder for selecting fields of Announcement entities.
+type AnnouncementSelect struct {
+ *AnnouncementQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AnnouncementSelect) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AnnouncementSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AnnouncementQuery, *AnnouncementSelect](ctx, _s.AnnouncementQuery, _s, _s.inters, v)
+}
+
+func (_s *AnnouncementSelect) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/announcement_update.go b/backend/ent/announcement_update.go
new file mode 100644
index 00000000..702d0817
--- /dev/null
+++ b/backend/ent/announcement_update.go
@@ -0,0 +1,824 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+)
+
+// AnnouncementUpdate is the builder for updating Announcement entities.
+type AnnouncementUpdate struct {
+ config
+ hooks []Hook
+ mutation *AnnouncementMutation
+}
+
+// Where appends a list of predicates to the AnnouncementUpdate builder.
+func (_u *AnnouncementUpdate) Where(ps ...predicate.Announcement) *AnnouncementUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetTitle sets the "title" field.
+func (_u *AnnouncementUpdate) SetTitle(v string) *AnnouncementUpdate {
+ _u.mutation.SetTitle(v)
+ return _u
+}
+
+// SetNillableTitle sets the "title" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableTitle(v *string) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetTitle(*v)
+ }
+ return _u
+}
+
+// SetContent sets the "content" field.
+func (_u *AnnouncementUpdate) SetContent(v string) *AnnouncementUpdate {
+ _u.mutation.SetContent(v)
+ return _u
+}
+
+// SetNillableContent sets the "content" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableContent(v *string) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetContent(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *AnnouncementUpdate) SetStatus(v string) *AnnouncementUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetTargeting sets the "targeting" field.
+func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
+ _u.mutation.SetTargeting(v)
+ return _u
+}
+
+// SetNillableTargeting sets the "targeting" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetTargeting(*v)
+ }
+ return _u
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (_u *AnnouncementUpdate) ClearTargeting() *AnnouncementUpdate {
+ _u.mutation.ClearTargeting()
+ return _u
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (_u *AnnouncementUpdate) SetStartsAt(v time.Time) *AnnouncementUpdate {
+ _u.mutation.SetStartsAt(v)
+ return _u
+}
+
+// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableStartsAt(v *time.Time) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetStartsAt(*v)
+ }
+ return _u
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (_u *AnnouncementUpdate) ClearStartsAt() *AnnouncementUpdate {
+ _u.mutation.ClearStartsAt()
+ return _u
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (_u *AnnouncementUpdate) SetEndsAt(v time.Time) *AnnouncementUpdate {
+ _u.mutation.SetEndsAt(v)
+ return _u
+}
+
+// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableEndsAt(v *time.Time) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetEndsAt(*v)
+ }
+ return _u
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (_u *AnnouncementUpdate) ClearEndsAt() *AnnouncementUpdate {
+ _u.mutation.ClearEndsAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *AnnouncementUpdate) SetCreatedBy(v int64) *AnnouncementUpdate {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableCreatedBy(v *int64) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *AnnouncementUpdate) AddCreatedBy(v int64) *AnnouncementUpdate {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (_u *AnnouncementUpdate) ClearCreatedBy() *AnnouncementUpdate {
+ _u.mutation.ClearCreatedBy()
+ return _u
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (_u *AnnouncementUpdate) SetUpdatedBy(v int64) *AnnouncementUpdate {
+ _u.mutation.ResetUpdatedBy()
+ _u.mutation.SetUpdatedBy(v)
+ return _u
+}
+
+// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
+func (_u *AnnouncementUpdate) SetNillableUpdatedBy(v *int64) *AnnouncementUpdate {
+ if v != nil {
+ _u.SetUpdatedBy(*v)
+ }
+ return _u
+}
+
+// AddUpdatedBy adds value to the "updated_by" field.
+func (_u *AnnouncementUpdate) AddUpdatedBy(v int64) *AnnouncementUpdate {
+ _u.mutation.AddUpdatedBy(v)
+ return _u
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (_u *AnnouncementUpdate) ClearUpdatedBy() *AnnouncementUpdate {
+ _u.mutation.ClearUpdatedBy()
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AnnouncementUpdate) SetUpdatedAt(v time.Time) *AnnouncementUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
+func (_u *AnnouncementUpdate) AddReadIDs(ids ...int64) *AnnouncementUpdate {
+ _u.mutation.AddReadIDs(ids...)
+ return _u
+}
+
+// AddReads adds the "reads" edges to the AnnouncementRead entity.
+func (_u *AnnouncementUpdate) AddReads(v ...*AnnouncementRead) *AnnouncementUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddReadIDs(ids...)
+}
+
+// Mutation returns the AnnouncementMutation object of the builder.
+func (_u *AnnouncementUpdate) Mutation() *AnnouncementMutation {
+ return _u.mutation
+}
+
+// ClearReads clears all "reads" edges to the AnnouncementRead entity.
+func (_u *AnnouncementUpdate) ClearReads() *AnnouncementUpdate {
+ _u.mutation.ClearReads()
+ return _u
+}
+
+// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
+func (_u *AnnouncementUpdate) RemoveReadIDs(ids ...int64) *AnnouncementUpdate {
+ _u.mutation.RemoveReadIDs(ids...)
+ return _u
+}
+
+// RemoveReads removes "reads" edges to AnnouncementRead entities.
+func (_u *AnnouncementUpdate) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveReadIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AnnouncementUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AnnouncementUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AnnouncementUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AnnouncementUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AnnouncementUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := announcement.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AnnouncementUpdate) check() error {
+ if v, ok := _u.mutation.Title(); ok {
+ if err := announcement.TitleValidator(v); err != nil {
+ return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Content(); ok {
+ if err := announcement.ContentValidator(v); err != nil {
+ return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := announcement.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Title(); ok {
+ _spec.SetField(announcement.FieldTitle, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Content(); ok {
+ _spec.SetField(announcement.FieldContent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(announcement.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Targeting(); ok {
+ _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
+ }
+ if _u.mutation.TargetingCleared() {
+ _spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.StartsAt(); ok {
+ _spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
+ }
+ if _u.mutation.StartsAtCleared() {
+ _spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EndsAt(); ok {
+ _spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
+ }
+ if _u.mutation.EndsAtCleared() {
+ _spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if _u.mutation.CreatedByCleared() {
+ _spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.UpdatedBy(); ok {
+ _spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedUpdatedBy(); ok {
+ _spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
+ }
+ if _u.mutation.UpdatedByCleared() {
+ _spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{announcement.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AnnouncementUpdateOne is the builder for updating a single Announcement entity.
+type AnnouncementUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AnnouncementMutation
+}
+
+// SetTitle sets the "title" field.
+func (_u *AnnouncementUpdateOne) SetTitle(v string) *AnnouncementUpdateOne {
+ _u.mutation.SetTitle(v)
+ return _u
+}
+
+// SetNillableTitle sets the "title" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableTitle(v *string) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetTitle(*v)
+ }
+ return _u
+}
+
+// SetContent sets the "content" field.
+func (_u *AnnouncementUpdateOne) SetContent(v string) *AnnouncementUpdateOne {
+ _u.mutation.SetContent(v)
+ return _u
+}
+
+// SetNillableContent sets the "content" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableContent(v *string) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetContent(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *AnnouncementUpdateOne) SetStatus(v string) *AnnouncementUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetTargeting sets the "targeting" field.
+func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
+ _u.mutation.SetTargeting(v)
+ return _u
+}
+
+// SetNillableTargeting sets the "targeting" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetTargeting(*v)
+ }
+ return _u
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (_u *AnnouncementUpdateOne) ClearTargeting() *AnnouncementUpdateOne {
+ _u.mutation.ClearTargeting()
+ return _u
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (_u *AnnouncementUpdateOne) SetStartsAt(v time.Time) *AnnouncementUpdateOne {
+ _u.mutation.SetStartsAt(v)
+ return _u
+}
+
+// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableStartsAt(v *time.Time) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetStartsAt(*v)
+ }
+ return _u
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (_u *AnnouncementUpdateOne) ClearStartsAt() *AnnouncementUpdateOne {
+ _u.mutation.ClearStartsAt()
+ return _u
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (_u *AnnouncementUpdateOne) SetEndsAt(v time.Time) *AnnouncementUpdateOne {
+ _u.mutation.SetEndsAt(v)
+ return _u
+}
+
+// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableEndsAt(v *time.Time) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetEndsAt(*v)
+ }
+ return _u
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (_u *AnnouncementUpdateOne) ClearEndsAt() *AnnouncementUpdateOne {
+ _u.mutation.ClearEndsAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *AnnouncementUpdateOne) SetCreatedBy(v int64) *AnnouncementUpdateOne {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableCreatedBy(v *int64) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *AnnouncementUpdateOne) AddCreatedBy(v int64) *AnnouncementUpdateOne {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (_u *AnnouncementUpdateOne) ClearCreatedBy() *AnnouncementUpdateOne {
+ _u.mutation.ClearCreatedBy()
+ return _u
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (_u *AnnouncementUpdateOne) SetUpdatedBy(v int64) *AnnouncementUpdateOne {
+ _u.mutation.ResetUpdatedBy()
+ _u.mutation.SetUpdatedBy(v)
+ return _u
+}
+
+// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
+func (_u *AnnouncementUpdateOne) SetNillableUpdatedBy(v *int64) *AnnouncementUpdateOne {
+ if v != nil {
+ _u.SetUpdatedBy(*v)
+ }
+ return _u
+}
+
+// AddUpdatedBy adds value to the "updated_by" field.
+func (_u *AnnouncementUpdateOne) AddUpdatedBy(v int64) *AnnouncementUpdateOne {
+ _u.mutation.AddUpdatedBy(v)
+ return _u
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (_u *AnnouncementUpdateOne) ClearUpdatedBy() *AnnouncementUpdateOne {
+ _u.mutation.ClearUpdatedBy()
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AnnouncementUpdateOne) SetUpdatedAt(v time.Time) *AnnouncementUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
+func (_u *AnnouncementUpdateOne) AddReadIDs(ids ...int64) *AnnouncementUpdateOne {
+ _u.mutation.AddReadIDs(ids...)
+ return _u
+}
+
+// AddReads adds the "reads" edges to the AnnouncementRead entity.
+func (_u *AnnouncementUpdateOne) AddReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddReadIDs(ids...)
+}
+
+// Mutation returns the AnnouncementMutation object of the builder.
+func (_u *AnnouncementUpdateOne) Mutation() *AnnouncementMutation {
+ return _u.mutation
+}
+
+// ClearReads clears all "reads" edges to the AnnouncementRead entity.
+func (_u *AnnouncementUpdateOne) ClearReads() *AnnouncementUpdateOne {
+ _u.mutation.ClearReads()
+ return _u
+}
+
+// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
+func (_u *AnnouncementUpdateOne) RemoveReadIDs(ids ...int64) *AnnouncementUpdateOne {
+ _u.mutation.RemoveReadIDs(ids...)
+ return _u
+}
+
+// RemoveReads removes "reads" edges to AnnouncementRead entities.
+func (_u *AnnouncementUpdateOne) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveReadIDs(ids...)
+}
+
+// Where appends a list of predicates to the AnnouncementUpdateOne builder.
+func (_u *AnnouncementUpdateOne) Where(ps ...predicate.Announcement) *AnnouncementUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AnnouncementUpdateOne) Select(field string, fields ...string) *AnnouncementUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated Announcement entity.
+func (_u *AnnouncementUpdateOne) Save(ctx context.Context) (*Announcement, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AnnouncementUpdateOne) SaveX(ctx context.Context) *Announcement {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AnnouncementUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AnnouncementUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AnnouncementUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := announcement.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AnnouncementUpdateOne) check() error {
+ if v, ok := _u.mutation.Title(); ok {
+ if err := announcement.TitleValidator(v); err != nil {
+ return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Content(); ok {
+ if err := announcement.ContentValidator(v); err != nil {
+ return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := announcement.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announcement, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Announcement.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
+ for _, f := range fields {
+ if !announcement.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != announcement.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Title(); ok {
+ _spec.SetField(announcement.FieldTitle, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Content(); ok {
+ _spec.SetField(announcement.FieldContent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(announcement.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Targeting(); ok {
+ _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
+ }
+ if _u.mutation.TargetingCleared() {
+ _spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.StartsAt(); ok {
+ _spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
+ }
+ if _u.mutation.StartsAtCleared() {
+ _spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EndsAt(); ok {
+ _spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
+ }
+ if _u.mutation.EndsAtCleared() {
+ _spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if _u.mutation.CreatedByCleared() {
+ _spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.UpdatedBy(); ok {
+ _spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedUpdatedBy(); ok {
+ _spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
+ }
+ if _u.mutation.UpdatedByCleared() {
+ _spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: announcement.ReadsTable,
+ Columns: []string{announcement.ReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &Announcement{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{announcement.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/announcementread.go b/backend/ent/announcementread.go
new file mode 100644
index 00000000..7bba04f2
--- /dev/null
+++ b/backend/ent/announcementread.go
@@ -0,0 +1,185 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AnnouncementRead is the model entity for the AnnouncementRead schema.
+type AnnouncementRead struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // AnnouncementID holds the value of the "announcement_id" field.
+ AnnouncementID int64 `json:"announcement_id,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // Time when the user first read the announcement
+ ReadAt time.Time `json:"read_at,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AnnouncementReadQuery when eager-loading is set.
+ Edges AnnouncementReadEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AnnouncementReadEdges holds the relations/edges for other nodes in the graph.
+type AnnouncementReadEdges struct {
+ // Announcement holds the value of the announcement edge.
+ Announcement *Announcement `json:"announcement,omitempty"`
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// AnnouncementOrErr returns the Announcement value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AnnouncementReadEdges) AnnouncementOrErr() (*Announcement, error) {
+ if e.Announcement != nil {
+ return e.Announcement, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: announcement.Label}
+ }
+ return nil, &NotLoadedError{edge: "announcement"}
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AnnouncementReadEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AnnouncementRead) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case announcementread.FieldID, announcementread.FieldAnnouncementID, announcementread.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case announcementread.FieldReadAt, announcementread.FieldCreatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AnnouncementRead fields.
+func (_m *AnnouncementRead) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case announcementread.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case announcementread.FieldAnnouncementID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field announcement_id", values[i])
+ } else if value.Valid {
+ _m.AnnouncementID = value.Int64
+ }
+ case announcementread.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case announcementread.FieldReadAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field read_at", values[i])
+ } else if value.Valid {
+ _m.ReadAt = value.Time
+ }
+ case announcementread.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AnnouncementRead.
+// This includes values selected through modifiers, order, etc.
+func (_m *AnnouncementRead) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryAnnouncement queries the "announcement" edge of the AnnouncementRead entity.
+func (_m *AnnouncementRead) QueryAnnouncement() *AnnouncementQuery {
+ return NewAnnouncementReadClient(_m.config).QueryAnnouncement(_m)
+}
+
+// QueryUser queries the "user" edge of the AnnouncementRead entity.
+func (_m *AnnouncementRead) QueryUser() *UserQuery {
+ return NewAnnouncementReadClient(_m.config).QueryUser(_m)
+}
+
+// Update returns a builder for updating this AnnouncementRead.
+// Note that you need to call AnnouncementRead.Unwrap() before calling this method if this AnnouncementRead
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AnnouncementRead) Update() *AnnouncementReadUpdateOne {
+ return NewAnnouncementReadClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AnnouncementRead entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AnnouncementRead) Unwrap() *AnnouncementRead {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AnnouncementRead is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AnnouncementRead) String() string {
+ var builder strings.Builder
+ builder.WriteString("AnnouncementRead(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("announcement_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AnnouncementID))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("read_at=")
+ builder.WriteString(_m.ReadAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AnnouncementReads is a parsable slice of AnnouncementRead.
+type AnnouncementReads []*AnnouncementRead
diff --git a/backend/ent/announcementread/announcementread.go b/backend/ent/announcementread/announcementread.go
new file mode 100644
index 00000000..cf5fe458
--- /dev/null
+++ b/backend/ent/announcementread/announcementread.go
@@ -0,0 +1,127 @@
+// Code generated by ent, DO NOT EDIT.
+
+package announcementread
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the announcementread type in the database.
+ Label = "announcement_read"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldAnnouncementID holds the string denoting the announcement_id field in the database.
+ FieldAnnouncementID = "announcement_id"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldReadAt holds the string denoting the read_at field in the database.
+ FieldReadAt = "read_at"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // EdgeAnnouncement holds the string denoting the announcement edge name in mutations.
+ EdgeAnnouncement = "announcement"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // Table holds the table name of the announcementread in the database.
+ Table = "announcement_reads"
+ // AnnouncementTable is the table that holds the announcement relation/edge.
+ AnnouncementTable = "announcement_reads"
+ // AnnouncementInverseTable is the table name for the Announcement entity.
+ // It exists in this package in order to avoid circular dependency with the "announcement" package.
+ AnnouncementInverseTable = "announcements"
+ // AnnouncementColumn is the table column denoting the announcement relation/edge.
+ AnnouncementColumn = "announcement_id"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "announcement_reads"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+)
+
+// Columns holds all SQL columns for announcementread fields.
+var Columns = []string{
+ FieldID,
+ FieldAnnouncementID,
+ FieldUserID,
+ FieldReadAt,
+ FieldCreatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultReadAt holds the default value on creation for the "read_at" field.
+ DefaultReadAt func() time.Time
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the AnnouncementRead queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByAnnouncementID orders the results by the announcement_id field.
+func ByAnnouncementID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAnnouncementID, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByReadAt orders the results by the read_at field.
+func ByReadAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldReadAt, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByAnnouncementField orders the results by announcement field.
+func ByAnnouncementField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAnnouncementStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newAnnouncementStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AnnouncementInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
+ )
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
diff --git a/backend/ent/announcementread/where.go b/backend/ent/announcementread/where.go
new file mode 100644
index 00000000..1a4305e8
--- /dev/null
+++ b/backend/ent/announcementread/where.go
@@ -0,0 +1,257 @@
+// Code generated by ent, DO NOT EDIT.
+
+package announcementread
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLTE(FieldID, id))
+}
+
+// AnnouncementID applies equality check predicate on the "announcement_id" field. It's identical to AnnouncementIDEQ.
+func AnnouncementID(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
+}
+
+// ReadAt applies equality check predicate on the "read_at" field. It's identical to ReadAtEQ.
+func ReadAt(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// AnnouncementIDEQ applies the EQ predicate on the "announcement_id" field.
+func AnnouncementIDEQ(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
+}
+
+// AnnouncementIDNEQ applies the NEQ predicate on the "announcement_id" field.
+func AnnouncementIDNEQ(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNEQ(FieldAnnouncementID, v))
+}
+
+// AnnouncementIDIn applies the In predicate on the "announcement_id" field.
+func AnnouncementIDIn(vs ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldIn(FieldAnnouncementID, vs...))
+}
+
+// AnnouncementIDNotIn applies the NotIn predicate on the "announcement_id" field.
+func AnnouncementIDNotIn(vs ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNotIn(FieldAnnouncementID, vs...))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ReadAtEQ applies the EQ predicate on the "read_at" field.
+func ReadAtEQ(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
+}
+
+// ReadAtNEQ applies the NEQ predicate on the "read_at" field.
+func ReadAtNEQ(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNEQ(FieldReadAt, v))
+}
+
+// ReadAtIn applies the In predicate on the "read_at" field.
+func ReadAtIn(vs ...time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldIn(FieldReadAt, vs...))
+}
+
+// ReadAtNotIn applies the NotIn predicate on the "read_at" field.
+func ReadAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNotIn(FieldReadAt, vs...))
+}
+
+// ReadAtGT applies the GT predicate on the "read_at" field.
+func ReadAtGT(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGT(FieldReadAt, v))
+}
+
+// ReadAtGTE applies the GTE predicate on the "read_at" field.
+func ReadAtGTE(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGTE(FieldReadAt, v))
+}
+
+// ReadAtLT applies the LT predicate on the "read_at" field.
+func ReadAtLT(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLT(FieldReadAt, v))
+}
+
+// ReadAtLTE applies the LTE predicate on the "read_at" field.
+func ReadAtLTE(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLTE(FieldReadAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// HasAnnouncement applies the HasEdge predicate on the "announcement" edge.
+func HasAnnouncement() predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAnnouncementWith applies the HasEdge predicate on the "announcement" edge with the given conditions (other predicates).
+func HasAnnouncementWith(preds ...predicate.Announcement) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(func(s *sql.Selector) {
+ step := newAnnouncementStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with the given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AnnouncementRead) predicate.AnnouncementRead {
+ return predicate.AnnouncementRead(sql.NotPredicates(p))
+}
diff --git a/backend/ent/announcementread_create.go b/backend/ent/announcementread_create.go
new file mode 100644
index 00000000..c8c211ff
--- /dev/null
+++ b/backend/ent/announcementread_create.go
@@ -0,0 +1,660 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AnnouncementReadCreate is the builder for creating an AnnouncementRead entity.
+type AnnouncementReadCreate struct {
+ config
+ mutation *AnnouncementReadMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (_c *AnnouncementReadCreate) SetAnnouncementID(v int64) *AnnouncementReadCreate {
+ _c.mutation.SetAnnouncementID(v)
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AnnouncementReadCreate) SetUserID(v int64) *AnnouncementReadCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetReadAt sets the "read_at" field.
+func (_c *AnnouncementReadCreate) SetReadAt(v time.Time) *AnnouncementReadCreate {
+ _c.mutation.SetReadAt(v)
+ return _c
+}
+
+// SetNillableReadAt sets the "read_at" field if the given value is not nil.
+func (_c *AnnouncementReadCreate) SetNillableReadAt(v *time.Time) *AnnouncementReadCreate {
+ if v != nil {
+ _c.SetReadAt(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AnnouncementReadCreate) SetCreatedAt(v time.Time) *AnnouncementReadCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AnnouncementReadCreate) SetNillableCreatedAt(v *time.Time) *AnnouncementReadCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetAnnouncement sets the "announcement" edge to the Announcement entity.
+func (_c *AnnouncementReadCreate) SetAnnouncement(v *Announcement) *AnnouncementReadCreate {
+ return _c.SetAnnouncementID(v.ID)
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AnnouncementReadCreate) SetUser(v *User) *AnnouncementReadCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// Mutation returns the AnnouncementReadMutation object of the builder.
+func (_c *AnnouncementReadCreate) Mutation() *AnnouncementReadMutation {
+ return _c.mutation
+}
+
+// Save creates the AnnouncementRead in the database.
+func (_c *AnnouncementReadCreate) Save(ctx context.Context) (*AnnouncementRead, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AnnouncementReadCreate) SaveX(ctx context.Context) *AnnouncementRead {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AnnouncementReadCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AnnouncementReadCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AnnouncementReadCreate) defaults() {
+ if _, ok := _c.mutation.ReadAt(); !ok {
+ v := announcementread.DefaultReadAt()
+ _c.mutation.SetReadAt(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := announcementread.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AnnouncementReadCreate) check() error {
+ if _, ok := _c.mutation.AnnouncementID(); !ok {
+ return &ValidationError{Name: "announcement_id", err: errors.New(`ent: missing required field "AnnouncementRead.announcement_id"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AnnouncementRead.user_id"`)}
+ }
+ if _, ok := _c.mutation.ReadAt(); !ok {
+ return &ValidationError{Name: "read_at", err: errors.New(`ent: missing required field "AnnouncementRead.read_at"`)}
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AnnouncementRead.created_at"`)}
+ }
+ if len(_c.mutation.AnnouncementIDs()) == 0 {
+ return &ValidationError{Name: "announcement", err: errors.New(`ent: missing required edge "AnnouncementRead.announcement"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AnnouncementRead.user"`)}
+ }
+ return nil
+}
+
+func (_c *AnnouncementReadCreate) sqlSave(ctx context.Context) (*AnnouncementRead, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AnnouncementReadCreate) createSpec() (*AnnouncementRead, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AnnouncementRead{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.ReadAt(); ok {
+ _spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
+ _node.ReadAt = value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(announcementread.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if nodes := _c.mutation.AnnouncementIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.AnnouncementTable,
+ Columns: []string{announcementread.AnnouncementColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.AnnouncementID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.UserTable,
+ Columns: []string{announcementread.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AnnouncementRead.Create().
+// SetAnnouncementID(v).
+// OnConflict(
+//	    // Update the row with the new values
+//	    // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AnnouncementReadUpsert) {
+// SetAnnouncementID(v+v).
+// }).
+// Exec(ctx)
+func (_c *AnnouncementReadCreate) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertOne {
+ _c.conflict = opts
+ return &AnnouncementReadUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AnnouncementReadCreate) OnConflictColumns(columns ...string) *AnnouncementReadUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AnnouncementReadUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AnnouncementReadUpsertOne is the builder for "upsert"-ing
+ // one AnnouncementRead node.
+ AnnouncementReadUpsertOne struct {
+ create *AnnouncementReadCreate
+ }
+
+ // AnnouncementReadUpsert is the "OnConflict" setter.
+ AnnouncementReadUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (u *AnnouncementReadUpsert) SetAnnouncementID(v int64) *AnnouncementReadUpsert {
+ u.Set(announcementread.FieldAnnouncementID, v)
+ return u
+}
+
+// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsert) UpdateAnnouncementID() *AnnouncementReadUpsert {
+ u.SetExcluded(announcementread.FieldAnnouncementID)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AnnouncementReadUpsert) SetUserID(v int64) *AnnouncementReadUpsert {
+ u.Set(announcementread.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsert) UpdateUserID() *AnnouncementReadUpsert {
+ u.SetExcluded(announcementread.FieldUserID)
+ return u
+}
+
+// SetReadAt sets the "read_at" field.
+func (u *AnnouncementReadUpsert) SetReadAt(v time.Time) *AnnouncementReadUpsert {
+ u.Set(announcementread.FieldReadAt, v)
+ return u
+}
+
+// UpdateReadAt sets the "read_at" field to the value that was provided on create.
+func (u *AnnouncementReadUpsert) UpdateReadAt() *AnnouncementReadUpsert {
+ u.SetExcluded(announcementread.FieldReadAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AnnouncementReadUpsertOne) UpdateNewValues() *AnnouncementReadUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(announcementread.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AnnouncementReadUpsertOne) Ignore() *AnnouncementReadUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AnnouncementReadUpsertOne) DoNothing() *AnnouncementReadUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreate.OnConflict
+// documentation for more info.
+func (u *AnnouncementReadUpsertOne) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AnnouncementReadUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (u *AnnouncementReadUpsertOne) SetAnnouncementID(v int64) *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetAnnouncementID(v)
+ })
+}
+
+// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertOne) UpdateAnnouncementID() *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateAnnouncementID()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AnnouncementReadUpsertOne) SetUserID(v int64) *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertOne) UpdateUserID() *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetReadAt sets the "read_at" field.
+func (u *AnnouncementReadUpsertOne) SetReadAt(v time.Time) *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetReadAt(v)
+ })
+}
+
+// UpdateReadAt sets the "read_at" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertOne) UpdateReadAt() *AnnouncementReadUpsertOne {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateReadAt()
+ })
+}
+
+// Exec executes the query.
+func (u *AnnouncementReadUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AnnouncementReadCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AnnouncementReadUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *AnnouncementReadUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AnnouncementReadUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AnnouncementReadCreateBulk is the builder for creating many AnnouncementRead entities in bulk.
+type AnnouncementReadCreateBulk struct {
+ config
+ err error
+ builders []*AnnouncementReadCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AnnouncementRead entities in the database.
+func (_c *AnnouncementReadCreateBulk) Save(ctx context.Context) ([]*AnnouncementRead, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AnnouncementRead, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AnnouncementReadMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AnnouncementReadCreateBulk) SaveX(ctx context.Context) []*AnnouncementRead {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AnnouncementReadCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AnnouncementReadCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AnnouncementRead.CreateBulk(builders...).
+// OnConflict(
+//	    // Update the row with the new values
+//	    // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AnnouncementReadUpsert) {
+// SetAnnouncementID(v+v).
+// }).
+// Exec(ctx)
+func (_c *AnnouncementReadCreateBulk) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertBulk {
+ _c.conflict = opts
+ return &AnnouncementReadUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AnnouncementReadCreateBulk) OnConflictColumns(columns ...string) *AnnouncementReadUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AnnouncementReadUpsertBulk{
+ create: _c,
+ }
+}
+
+// AnnouncementReadUpsertBulk is the builder for "upsert"-ing
+// a bulk of AnnouncementRead nodes.
+type AnnouncementReadUpsertBulk struct {
+ create *AnnouncementReadCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AnnouncementReadUpsertBulk) UpdateNewValues() *AnnouncementReadUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(announcementread.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AnnouncementRead.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AnnouncementReadUpsertBulk) Ignore() *AnnouncementReadUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AnnouncementReadUpsertBulk) DoNothing() *AnnouncementReadUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreateBulk.OnConflict
+// documentation for more info.
+func (u *AnnouncementReadUpsertBulk) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AnnouncementReadUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (u *AnnouncementReadUpsertBulk) SetAnnouncementID(v int64) *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetAnnouncementID(v)
+ })
+}
+
+// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertBulk) UpdateAnnouncementID() *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateAnnouncementID()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AnnouncementReadUpsertBulk) SetUserID(v int64) *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertBulk) UpdateUserID() *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetReadAt sets the "read_at" field.
+func (u *AnnouncementReadUpsertBulk) SetReadAt(v time.Time) *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.SetReadAt(v)
+ })
+}
+
+// UpdateReadAt sets the "read_at" field to the value that was provided on create.
+func (u *AnnouncementReadUpsertBulk) UpdateReadAt() *AnnouncementReadUpsertBulk {
+ return u.Update(func(s *AnnouncementReadUpsert) {
+ s.UpdateReadAt()
+ })
+}
+
+// Exec executes the query.
+func (u *AnnouncementReadUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AnnouncementReadCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AnnouncementReadCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AnnouncementReadUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/announcementread_delete.go b/backend/ent/announcementread_delete.go
new file mode 100644
index 00000000..a4da0821
--- /dev/null
+++ b/backend/ent/announcementread_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AnnouncementReadDelete is the builder for deleting an AnnouncementRead entity.
+type AnnouncementReadDelete struct {
+ config
+ hooks []Hook
+ mutation *AnnouncementReadMutation
+}
+
+// Where appends a list of predicates to the AnnouncementReadDelete builder.
+func (_d *AnnouncementReadDelete) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AnnouncementReadDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AnnouncementReadDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AnnouncementReadDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AnnouncementReadDeleteOne is the builder for deleting a single AnnouncementRead entity.
+type AnnouncementReadDeleteOne struct {
+ _d *AnnouncementReadDelete
+}
+
+// Where appends a list predicates to the AnnouncementReadDelete builder.
+func (_d *AnnouncementReadDeleteOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AnnouncementReadDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{announcementread.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AnnouncementReadDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/announcementread_query.go b/backend/ent/announcementread_query.go
new file mode 100644
index 00000000..108299fd
--- /dev/null
+++ b/backend/ent/announcementread_query.go
@@ -0,0 +1,718 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AnnouncementReadQuery is the builder for querying AnnouncementRead entities.
+type AnnouncementReadQuery struct {
+ config
+ ctx *QueryContext
+ order []announcementread.OrderOption
+ inters []Interceptor
+ predicates []predicate.AnnouncementRead
+ withAnnouncement *AnnouncementQuery
+ withUser *UserQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AnnouncementReadQuery builder.
+func (_q *AnnouncementReadQuery) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AnnouncementReadQuery) Limit(limit int) *AnnouncementReadQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AnnouncementReadQuery) Offset(offset int) *AnnouncementReadQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AnnouncementReadQuery) Unique(unique bool) *AnnouncementReadQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AnnouncementReadQuery) Order(o ...announcementread.OrderOption) *AnnouncementReadQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryAnnouncement chains the current query on the "announcement" edge.
+func (_q *AnnouncementReadQuery) QueryAnnouncement() *AnnouncementQuery {
+ query := (&AnnouncementClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
+ sqlgraph.To(announcement.Table, announcement.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AnnouncementReadQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AnnouncementRead entity from the query.
+// Returns a *NotFoundError when no AnnouncementRead was found.
+func (_q *AnnouncementReadQuery) First(ctx context.Context) (*AnnouncementRead, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{announcementread.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) FirstX(ctx context.Context) *AnnouncementRead {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AnnouncementRead ID from the query.
+// Returns a *NotFoundError when no AnnouncementRead ID was found.
+func (_q *AnnouncementReadQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{announcementread.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AnnouncementRead entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AnnouncementRead entity is found.
+// Returns a *NotFoundError when no AnnouncementRead entities are found.
+func (_q *AnnouncementReadQuery) Only(ctx context.Context) (*AnnouncementRead, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{announcementread.Label}
+ default:
+ return nil, &NotSingularError{announcementread.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) OnlyX(ctx context.Context) *AnnouncementRead {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AnnouncementRead ID in the query.
+// Returns a *NotSingularError when more than one AnnouncementRead ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AnnouncementReadQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{announcementread.Label}
+ default:
+ err = &NotSingularError{announcementread.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AnnouncementReads.
+func (_q *AnnouncementReadQuery) All(ctx context.Context) ([]*AnnouncementRead, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AnnouncementRead, *AnnouncementReadQuery]()
+ return withInterceptors[[]*AnnouncementRead](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) AllX(ctx context.Context) []*AnnouncementRead {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AnnouncementRead IDs.
+func (_q *AnnouncementReadQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(announcementread.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AnnouncementReadQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AnnouncementReadQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AnnouncementReadQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AnnouncementReadQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AnnouncementReadQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AnnouncementReadQuery) Clone() *AnnouncementReadQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AnnouncementReadQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]announcementread.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AnnouncementRead{}, _q.predicates...),
+ withAnnouncement: _q.withAnnouncement.Clone(),
+ withUser: _q.withUser.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithAnnouncement tells the query-builder to eager-load the nodes that are connected to
+// the "announcement" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AnnouncementReadQuery) WithAnnouncement(opts ...func(*AnnouncementQuery)) *AnnouncementReadQuery {
+ query := (&AnnouncementClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAnnouncement = query
+ return _q
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AnnouncementReadQuery) WithUser(opts ...func(*UserQuery)) *AnnouncementReadQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// AnnouncementID int64 `json:"announcement_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AnnouncementRead.Query().
+// GroupBy(announcementread.FieldAnnouncementID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AnnouncementReadQuery) GroupBy(field string, fields ...string) *AnnouncementReadGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AnnouncementReadGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = announcementread.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// AnnouncementID int64 `json:"announcement_id,omitempty"`
+// }
+//
+// client.AnnouncementRead.Query().
+// Select(announcementread.FieldAnnouncementID).
+// Scan(ctx, &v)
+func (_q *AnnouncementReadQuery) Select(fields ...string) *AnnouncementReadSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AnnouncementReadSelect{AnnouncementReadQuery: _q}
+ sbuild.label = announcementread.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AnnouncementReadSelect configured with the given aggregations.
+func (_q *AnnouncementReadQuery) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AnnouncementReadQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !announcementread.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AnnouncementReadQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AnnouncementRead, error) {
+ var (
+ nodes = []*AnnouncementRead{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withAnnouncement != nil,
+ _q.withUser != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AnnouncementRead).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AnnouncementRead{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withAnnouncement; query != nil {
+ if err := _q.loadAnnouncement(ctx, query, nodes, nil,
+ func(n *AnnouncementRead, e *Announcement) { n.Edges.Announcement = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AnnouncementRead, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AnnouncementReadQuery) loadAnnouncement(ctx context.Context, query *AnnouncementQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *Announcement)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AnnouncementRead)
+ for i := range nodes {
+ fk := nodes[i].AnnouncementID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(announcement.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "announcement_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AnnouncementReadQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AnnouncementRead)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AnnouncementReadQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AnnouncementReadQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
+ for i := range fields {
+ if fields[i] != announcementread.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withAnnouncement != nil {
+ _spec.Node.AddColumnOnce(announcementread.FieldAnnouncementID)
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(announcementread.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AnnouncementReadQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(announcementread.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = announcementread.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AnnouncementReadQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementReadQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AnnouncementReadQuery) ForShare(opts ...sql.LockOption) *AnnouncementReadQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AnnouncementReadGroupBy is the group-by builder for AnnouncementRead entities.
+type AnnouncementReadGroupBy struct {
+ selector
+ build *AnnouncementReadQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AnnouncementReadGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementReadGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AnnouncementReadGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AnnouncementReadGroupBy) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AnnouncementReadSelect is the builder for selecting fields of AnnouncementRead entities.
+type AnnouncementReadSelect struct {
+ *AnnouncementReadQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AnnouncementReadSelect) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AnnouncementReadSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadSelect](ctx, _s.AnnouncementReadQuery, _s, _s.inters, v)
+}
+
+func (_s *AnnouncementReadSelect) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/announcementread_update.go b/backend/ent/announcementread_update.go
new file mode 100644
index 00000000..55a4eef8
--- /dev/null
+++ b/backend/ent/announcementread_update.go
@@ -0,0 +1,456 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AnnouncementReadUpdate is the builder for updating AnnouncementRead entities.
+type AnnouncementReadUpdate struct {
+ config
+ hooks []Hook
+ mutation *AnnouncementReadMutation
+}
+
+// Where appends a list predicates to the AnnouncementReadUpdate builder.
+func (_u *AnnouncementReadUpdate) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (_u *AnnouncementReadUpdate) SetAnnouncementID(v int64) *AnnouncementReadUpdate {
+ _u.mutation.SetAnnouncementID(v)
+ return _u
+}
+
+// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
+func (_u *AnnouncementReadUpdate) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdate {
+ if v != nil {
+ _u.SetAnnouncementID(*v)
+ }
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AnnouncementReadUpdate) SetUserID(v int64) *AnnouncementReadUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AnnouncementReadUpdate) SetNillableUserID(v *int64) *AnnouncementReadUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetReadAt sets the "read_at" field.
+func (_u *AnnouncementReadUpdate) SetReadAt(v time.Time) *AnnouncementReadUpdate {
+ _u.mutation.SetReadAt(v)
+ return _u
+}
+
+// SetNillableReadAt sets the "read_at" field if the given value is not nil.
+func (_u *AnnouncementReadUpdate) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdate {
+ if v != nil {
+ _u.SetReadAt(*v)
+ }
+ return _u
+}
+
+// SetAnnouncement sets the "announcement" edge to the Announcement entity.
+func (_u *AnnouncementReadUpdate) SetAnnouncement(v *Announcement) *AnnouncementReadUpdate {
+ return _u.SetAnnouncementID(v.ID)
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AnnouncementReadUpdate) SetUser(v *User) *AnnouncementReadUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// Mutation returns the AnnouncementReadMutation object of the builder.
+func (_u *AnnouncementReadUpdate) Mutation() *AnnouncementReadMutation {
+ return _u.mutation
+}
+
+// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
+func (_u *AnnouncementReadUpdate) ClearAnnouncement() *AnnouncementReadUpdate {
+ _u.mutation.ClearAnnouncement()
+ return _u
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AnnouncementReadUpdate) ClearUser() *AnnouncementReadUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AnnouncementReadUpdate) Save(ctx context.Context) (int, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AnnouncementReadUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AnnouncementReadUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AnnouncementReadUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AnnouncementReadUpdate) check() error {
+ if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
+ }
+ return nil
+}
+
+func (_u *AnnouncementReadUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.ReadAt(); ok {
+ _spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
+ }
+ if _u.mutation.AnnouncementCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.AnnouncementTable,
+ Columns: []string{announcementread.AnnouncementColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.AnnouncementTable,
+ Columns: []string{announcementread.AnnouncementColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.UserTable,
+ Columns: []string{announcementread.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.UserTable,
+ Columns: []string{announcementread.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{announcementread.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AnnouncementReadUpdateOne is the builder for updating a single AnnouncementRead entity.
+type AnnouncementReadUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AnnouncementReadMutation
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (_u *AnnouncementReadUpdateOne) SetAnnouncementID(v int64) *AnnouncementReadUpdateOne {
+ _u.mutation.SetAnnouncementID(v)
+ return _u
+}
+
+// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
+func (_u *AnnouncementReadUpdateOne) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdateOne {
+ if v != nil {
+ _u.SetAnnouncementID(*v)
+ }
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AnnouncementReadUpdateOne) SetUserID(v int64) *AnnouncementReadUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AnnouncementReadUpdateOne) SetNillableUserID(v *int64) *AnnouncementReadUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetReadAt sets the "read_at" field.
+func (_u *AnnouncementReadUpdateOne) SetReadAt(v time.Time) *AnnouncementReadUpdateOne {
+ _u.mutation.SetReadAt(v)
+ return _u
+}
+
+// SetNillableReadAt sets the "read_at" field if the given value is not nil.
+func (_u *AnnouncementReadUpdateOne) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdateOne {
+ if v != nil {
+ _u.SetReadAt(*v)
+ }
+ return _u
+}
+
+// SetAnnouncement sets the "announcement" edge to the Announcement entity.
+func (_u *AnnouncementReadUpdateOne) SetAnnouncement(v *Announcement) *AnnouncementReadUpdateOne {
+ return _u.SetAnnouncementID(v.ID)
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AnnouncementReadUpdateOne) SetUser(v *User) *AnnouncementReadUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// Mutation returns the AnnouncementReadMutation object of the builder.
+func (_u *AnnouncementReadUpdateOne) Mutation() *AnnouncementReadMutation {
+ return _u.mutation
+}
+
+// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
+func (_u *AnnouncementReadUpdateOne) ClearAnnouncement() *AnnouncementReadUpdateOne {
+ _u.mutation.ClearAnnouncement()
+ return _u
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AnnouncementReadUpdateOne) ClearUser() *AnnouncementReadUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// Where appends a list of predicates to the AnnouncementReadUpdateOne builder.
+func (_u *AnnouncementReadUpdateOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AnnouncementReadUpdateOne) Select(field string, fields ...string) *AnnouncementReadUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AnnouncementRead entity.
+func (_u *AnnouncementReadUpdateOne) Save(ctx context.Context) (*AnnouncementRead, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AnnouncementReadUpdateOne) SaveX(ctx context.Context) *AnnouncementRead {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AnnouncementReadUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AnnouncementReadUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AnnouncementReadUpdateOne) check() error {
+ if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
+ }
+ return nil
+}
+
+func (_u *AnnouncementReadUpdateOne) sqlSave(ctx context.Context) (_node *AnnouncementRead, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AnnouncementRead.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
+ for _, f := range fields {
+ if !announcementread.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != announcementread.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.ReadAt(); ok {
+ _spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
+ }
+ if _u.mutation.AnnouncementCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.AnnouncementTable,
+ Columns: []string{announcementread.AnnouncementColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.AnnouncementTable,
+ Columns: []string{announcementread.AnnouncementColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.UserTable,
+ Columns: []string{announcementread.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: announcementread.UserTable,
+ Columns: []string{announcementread.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AnnouncementRead{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{announcementread.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index f6c13e84..a17721da 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -17,6 +17,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -46,6 +48,10 @@ type Client struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
+ // Announcement is the client for interacting with the Announcement builders.
+ Announcement *AnnouncementClient
+ // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
+ AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -86,6 +92,8 @@ func (c *Client) init() {
c.APIKey = NewAPIKeyClient(c.config)
c.Account = NewAccountClient(c.config)
c.AccountGroup = NewAccountGroupClient(c.config)
+ c.Announcement = NewAnnouncementClient(c.config)
+ c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.Group = NewGroupClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
@@ -194,6 +202,8 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -229,6 +239,8 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -271,10 +283,10 @@ func (c *Client) Close() error {
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
- c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
- c.UserSubscription,
+ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
+ c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
+ c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
+ c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -284,10 +296,10 @@ func (c *Client) Use(hooks ...Hook) {
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
- c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
- c.UserSubscription,
+ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
+ c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
+ c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
+ c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -302,6 +314,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Account.mutate(ctx, m)
case *AccountGroupMutation:
return c.AccountGroup.mutate(ctx, m)
+ case *AnnouncementMutation:
+ return c.Announcement.mutate(ctx, m)
+ case *AnnouncementReadMutation:
+ return c.AnnouncementRead.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *PromoCodeMutation:
@@ -831,6 +847,320 @@ func (c *AccountGroupClient) mutate(ctx context.Context, m *AccountGroupMutation
}
}
+// AnnouncementClient is a client for the Announcement schema.
+type AnnouncementClient struct {
+ config
+}
+
+// NewAnnouncementClient returns a client for the Announcement from the given config.
+func NewAnnouncementClient(c config) *AnnouncementClient {
+ return &AnnouncementClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` is equivalent to `announcement.Hooks(f(g(h())))`.
+func (c *AnnouncementClient) Use(hooks ...Hook) {
+ c.hooks.Announcement = append(c.hooks.Announcement, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` is equivalent to `announcement.Intercept(f(g(h())))`.
+func (c *AnnouncementClient) Intercept(interceptors ...Interceptor) {
+ c.inters.Announcement = append(c.inters.Announcement, interceptors...)
+}
+
+// Create returns a builder for creating an Announcement entity.
+func (c *AnnouncementClient) Create() *AnnouncementCreate {
+ mutation := newAnnouncementMutation(c.config, OpCreate)
+ return &AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of Announcement entities.
+func (c *AnnouncementClient) CreateBulk(builders ...*AnnouncementCreate) *AnnouncementCreateBulk {
+ return &AnnouncementCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AnnouncementClient) MapCreateBulk(slice any, setFunc func(*AnnouncementCreate, int)) *AnnouncementCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AnnouncementCreateBulk{err: fmt.Errorf("calling to AnnouncementClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AnnouncementCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AnnouncementCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for Announcement.
+func (c *AnnouncementClient) Update() *AnnouncementUpdate {
+ mutation := newAnnouncementMutation(c.config, OpUpdate)
+ return &AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AnnouncementClient) UpdateOne(_m *Announcement) *AnnouncementUpdateOne {
+ mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncement(_m))
+ return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AnnouncementClient) UpdateOneID(id int64) *AnnouncementUpdateOne {
+ mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncementID(id))
+ return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for Announcement.
+func (c *AnnouncementClient) Delete() *AnnouncementDelete {
+ mutation := newAnnouncementMutation(c.config, OpDelete)
+ return &AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AnnouncementClient) DeleteOne(_m *Announcement) *AnnouncementDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AnnouncementClient) DeleteOneID(id int64) *AnnouncementDeleteOne {
+ builder := c.Delete().Where(announcement.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AnnouncementDeleteOne{builder}
+}
+
+// Query returns a query builder for Announcement.
+func (c *AnnouncementClient) Query() *AnnouncementQuery {
+ return &AnnouncementQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAnnouncement},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns an Announcement entity by its id.
+func (c *AnnouncementClient) Get(ctx context.Context, id int64) (*Announcement, error) {
+ return c.Query().Where(announcement.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AnnouncementClient) GetX(ctx context.Context, id int64) *Announcement {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryReads queries the reads edge of an Announcement.
+func (c *AnnouncementClient) QueryReads(_m *Announcement) *AnnouncementReadQuery {
+ query := (&AnnouncementReadClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcement.Table, announcement.FieldID, id),
+ sqlgraph.To(announcementread.Table, announcementread.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AnnouncementClient) Hooks() []Hook {
+ return c.hooks.Announcement
+}
+
+// Interceptors returns the client interceptors.
+func (c *AnnouncementClient) Interceptors() []Interceptor {
+ return c.inters.Announcement
+}
+
+func (c *AnnouncementClient) mutate(ctx context.Context, m *AnnouncementMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown Announcement mutation op: %q", m.Op())
+ }
+}
+
+// AnnouncementReadClient is a client for the AnnouncementRead schema.
+type AnnouncementReadClient struct {
+ config
+}
+
+// NewAnnouncementReadClient returns a client for the AnnouncementRead from the given config.
+func NewAnnouncementReadClient(c config) *AnnouncementReadClient {
+ return &AnnouncementReadClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` is equivalent to `announcementread.Hooks(f(g(h())))`.
+func (c *AnnouncementReadClient) Use(hooks ...Hook) {
+ c.hooks.AnnouncementRead = append(c.hooks.AnnouncementRead, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` is equivalent to `announcementread.Intercept(f(g(h())))`.
+func (c *AnnouncementReadClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AnnouncementRead = append(c.inters.AnnouncementRead, interceptors...)
+}
+
+// Create returns a builder for creating an AnnouncementRead entity.
+func (c *AnnouncementReadClient) Create() *AnnouncementReadCreate {
+ mutation := newAnnouncementReadMutation(c.config, OpCreate)
+ return &AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AnnouncementRead entities.
+func (c *AnnouncementReadClient) CreateBulk(builders ...*AnnouncementReadCreate) *AnnouncementReadCreateBulk {
+ return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AnnouncementReadClient) MapCreateBulk(slice any, setFunc func(*AnnouncementReadCreate, int)) *AnnouncementReadCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AnnouncementReadCreateBulk{err: fmt.Errorf("calling to AnnouncementReadClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AnnouncementReadCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AnnouncementRead.
+func (c *AnnouncementReadClient) Update() *AnnouncementReadUpdate {
+ mutation := newAnnouncementReadMutation(c.config, OpUpdate)
+ return &AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AnnouncementReadClient) UpdateOne(_m *AnnouncementRead) *AnnouncementReadUpdateOne {
+ mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementRead(_m))
+ return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AnnouncementReadClient) UpdateOneID(id int64) *AnnouncementReadUpdateOne {
+ mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementReadID(id))
+ return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AnnouncementRead.
+func (c *AnnouncementReadClient) Delete() *AnnouncementReadDelete {
+ mutation := newAnnouncementReadMutation(c.config, OpDelete)
+ return &AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AnnouncementReadClient) DeleteOne(_m *AnnouncementRead) *AnnouncementReadDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AnnouncementReadClient) DeleteOneID(id int64) *AnnouncementReadDeleteOne {
+ builder := c.Delete().Where(announcementread.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AnnouncementReadDeleteOne{builder}
+}
+
+// Query returns a query builder for AnnouncementRead.
+func (c *AnnouncementReadClient) Query() *AnnouncementReadQuery {
+ return &AnnouncementReadQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAnnouncementRead},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns an AnnouncementRead entity by its id.
+func (c *AnnouncementReadClient) Get(ctx context.Context, id int64) (*AnnouncementRead, error) {
+ return c.Query().Where(announcementread.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AnnouncementReadClient) GetX(ctx context.Context, id int64) *AnnouncementRead {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryAnnouncement queries the announcement edge of an AnnouncementRead.
+func (c *AnnouncementReadClient) QueryAnnouncement(_m *AnnouncementRead) *AnnouncementQuery {
+ query := (&AnnouncementClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
+ sqlgraph.To(announcement.Table, announcement.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryUser queries the user edge of an AnnouncementRead.
+func (c *AnnouncementReadClient) QueryUser(_m *AnnouncementRead) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AnnouncementReadClient) Hooks() []Hook {
+ return c.hooks.AnnouncementRead
+}
+
+// Interceptors returns the client interceptors.
+func (c *AnnouncementReadClient) Interceptors() []Interceptor {
+ return c.inters.AnnouncementRead
+}
+
+func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementReadMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AnnouncementRead mutation op: %q", m.Op())
+ }
+}
+
// GroupClient is a client for the Group schema.
type GroupClient struct {
config
@@ -2375,6 +2705,22 @@ func (c *UserClient) QueryAssignedSubscriptions(_m *User) *UserSubscriptionQuery
return query
}
+// QueryAnnouncementReads queries the announcement_reads edge of a User.
+func (c *UserClient) QueryAnnouncementReads(_m *User) *AnnouncementReadQuery {
+ query := (&AnnouncementReadClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(announcementread.Table, announcementread.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryAllowedGroups queries the allowed_groups edge of a User.
func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery {
query := (&GroupClient{config: c.config}).Query()
@@ -3116,14 +3462,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
- UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
+ UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
+ UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
- UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
+ UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
+ UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 4bcc2642..05e30ba7 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -14,6 +14,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -91,6 +93,8 @@ func checkColumn(t, c string) error {
apikey.Table: apikey.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
group.Table: group.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index edd84f5e..1e653c77 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -45,6 +45,30 @@ func (f AccountGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AccountGroupMutation", m)
}
+// The AnnouncementFunc type is an adapter to allow the use of ordinary
+// function as Announcement mutator.
+type AnnouncementFunc func(context.Context, *ent.AnnouncementMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AnnouncementFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AnnouncementMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementMutation", m)
+}
+
+// The AnnouncementReadFunc type is an adapter to allow the use of ordinary
+// function as AnnouncementRead mutator.
+type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AnnouncementReadMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
+}
+
// The GroupFunc type is an adapter to allow the use of ordinary
// function as Group mutator.
type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error)
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index f18c0624..a37be48f 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -10,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -164,6 +166,60 @@ func (f TraverseAccountGroup) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.AccountGroupQuery", q)
}
+// The AnnouncementFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AnnouncementFunc func(context.Context, *ent.AnnouncementQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AnnouncementFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AnnouncementQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
+}
+
+// The TraverseAnnouncement type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAnnouncement func(context.Context, *ent.AnnouncementQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAnnouncement) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAnnouncement) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AnnouncementQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
+}
+
+// The AnnouncementReadFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AnnouncementReadFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AnnouncementReadQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
+}
+
+// The TraverseAnnouncementRead type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAnnouncementRead func(context.Context, *ent.AnnouncementReadQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAnnouncementRead) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AnnouncementReadQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
+}
+
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
@@ -524,6 +580,10 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AccountQuery, predicate.Account, account.OrderOption]{typ: ent.TypeAccount, tq: q}, nil
case *ent.AccountGroupQuery:
return &query[*ent.AccountGroupQuery, predicate.AccountGroup, accountgroup.OrderOption]{typ: ent.TypeAccountGroup, tq: q}, nil
+ case *ent.AnnouncementQuery:
+ return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
+ case *ent.AnnouncementReadQuery:
+ return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.PromoCodeQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index fe1f80a8..8df0cdb3 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -204,6 +204,98 @@ var (
},
},
}
+ // AnnouncementsColumns holds the columns for the "announcements" table.
+ AnnouncementsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "title", Type: field.TypeString, Size: 200},
+ {Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
+ {Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "created_by", Type: field.TypeInt64, Nullable: true},
+ {Name: "updated_by", Type: field.TypeInt64, Nullable: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ }
+ // AnnouncementsTable holds the schema information for the "announcements" table.
+ AnnouncementsTable = &schema.Table{
+ Name: "announcements",
+ Columns: AnnouncementsColumns,
+ PrimaryKey: []*schema.Column{AnnouncementsColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "announcement_status",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementsColumns[3]},
+ },
+ {
+ Name: "announcement_created_at",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementsColumns[9]},
+ },
+ {
+ Name: "announcement_starts_at",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementsColumns[5]},
+ },
+ {
+ Name: "announcement_ends_at",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementsColumns[6]},
+ },
+ },
+ }
+ // AnnouncementReadsColumns holds the columns for the "announcement_reads" table.
+ AnnouncementReadsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "read_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "announcement_id", Type: field.TypeInt64},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AnnouncementReadsTable holds the schema information for the "announcement_reads" table.
+ AnnouncementReadsTable = &schema.Table{
+ Name: "announcement_reads",
+ Columns: AnnouncementReadsColumns,
+ PrimaryKey: []*schema.Column{AnnouncementReadsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "announcement_reads_announcements_reads",
+ Columns: []*schema.Column{AnnouncementReadsColumns[3]},
+ RefColumns: []*schema.Column{AnnouncementsColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ {
+ Symbol: "announcement_reads_users_announcement_reads",
+ Columns: []*schema.Column{AnnouncementReadsColumns[4]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "announcementread_announcement_id",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementReadsColumns[3]},
+ },
+ {
+ Name: "announcementread_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementReadsColumns[4]},
+ },
+ {
+ Name: "announcementread_read_at",
+ Unique: false,
+ Columns: []*schema.Column{AnnouncementReadsColumns[1]},
+ },
+ {
+ Name: "announcementread_announcement_id_user_id",
+ Unique: true,
+ Columns: []*schema.Column{AnnouncementReadsColumns[3], AnnouncementReadsColumns[4]},
+ },
+ },
+ }
// GroupsColumns holds the columns for the "groups" table.
GroupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -615,6 +707,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "totp_enabled", Type: field.TypeBool, Default: false},
+ {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -842,6 +937,8 @@ var (
APIKeysTable,
AccountsTable,
AccountGroupsTable,
+ AnnouncementsTable,
+ AnnouncementReadsTable,
GroupsTable,
PromoCodesTable,
PromoCodeUsagesTable,
@@ -873,6 +970,14 @@ func init() {
AccountGroupsTable.Annotation = &entsql.Annotation{
Table: "account_groups",
}
+ AnnouncementsTable.Annotation = &entsql.Annotation{
+ Table: "announcements",
+ }
+ AnnouncementReadsTable.ForeignKeys[0].RefTable = AnnouncementsTable
+ AnnouncementReadsTable.ForeignKeys[1].RefTable = UsersTable
+ AnnouncementReadsTable.Annotation = &entsql.Annotation{
+ Table: "announcement_reads",
+ }
GroupsTable.Annotation = &entsql.Annotation{
Table: "groups",
}
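
The migration above gives `announcement_reads` a composite unique index (`announcementread_announcement_id_user_id`) over `(announcement_id, user_id)`, so each user can have at most one read record per announcement, which is what makes a "mark as read" endpoint naturally idempotent. A minimal in-memory model of that constraint (illustrative only, not the repository code):

```go
package main

import "fmt"

// readKey mirrors the composite unique index announcementread_announcement_id_user_id:
// at most one read record may exist per (announcement, user) pair.
type readKey struct {
	AnnouncementID int64
	UserID         int64
}

// markRead reports whether a new read record was created. A second call with
// the same pair is a no-op, matching the behavior the unique index enforces
// at the database level (a duplicate INSERT would be rejected).
func markRead(reads map[readKey]bool, announcementID, userID int64) bool {
	k := readKey{announcementID, userID}
	if reads[k] {
		return false // duplicate would violate the unique index
	}
	reads[k] = true
	return true
}

func main() {
	reads := map[readKey]bool{}
	fmt.Println(markRead(reads, 1, 42)) // true: first read
	fmt.Println(markRead(reads, 1, 42)) // false: already read
	fmt.Println(markRead(reads, 2, 42)) // true: different announcement
}
```

The single-column indexes on `announcement_id`, `user_id`, and `read_at` support the common lookups (reads per announcement, reads per user, recency), while the composite index carries the uniqueness guarantee.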
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index b3d1e410..f12ccb4f 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -14,6 +14,8 @@ import (
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -29,6 +31,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
const (
@@ -43,6 +46,8 @@ const (
TypeAPIKey = "APIKey"
TypeAccount = "Account"
TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
TypeGroup = "Group"
TypePromoCode = "PromoCode"
TypePromoCodeUsage = "PromoCodeUsage"
@@ -3833,6 +3838,1671 @@ func (m *AccountGroupMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AccountGroup edge %s", name)
}
+// AnnouncementMutation represents an operation that mutates the Announcement nodes in the graph.
+type AnnouncementMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ title *string
+ content *string
+ status *string
+ targeting *domain.AnnouncementTargeting
+ starts_at *time.Time
+ ends_at *time.Time
+ created_by *int64
+ addcreated_by *int64
+ updated_by *int64
+ addupdated_by *int64
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ reads map[int64]struct{}
+ removedreads map[int64]struct{}
+ clearedreads bool
+ done bool
+ oldValue func(context.Context) (*Announcement, error)
+ predicates []predicate.Announcement
+}
+
+var _ ent.Mutation = (*AnnouncementMutation)(nil)
+
+// announcementOption allows management of the mutation configuration using functional options.
+type announcementOption func(*AnnouncementMutation)
+
+// newAnnouncementMutation creates new mutation for the Announcement entity.
+func newAnnouncementMutation(c config, op Op, opts ...announcementOption) *AnnouncementMutation {
+ m := &AnnouncementMutation{
+ config: c,
+ op: op,
+ typ: TypeAnnouncement,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAnnouncementID sets the ID field of the mutation.
+func withAnnouncementID(id int64) announcementOption {
+ return func(m *AnnouncementMutation) {
+ var (
+ err error
+ once sync.Once
+ value *Announcement
+ )
+ m.oldValue = func(ctx context.Context) (*Announcement, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().Announcement.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAnnouncement sets the old Announcement of the mutation.
+func withAnnouncement(node *Announcement) announcementOption {
+ return func(m *AnnouncementMutation) {
+ m.oldValue = func(context.Context) (*Announcement, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AnnouncementMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AnnouncementMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AnnouncementMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *AnnouncementMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().Announcement.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetTitle sets the "title" field.
+func (m *AnnouncementMutation) SetTitle(s string) {
+ m.title = &s
+}
+
+// Title returns the value of the "title" field in the mutation.
+func (m *AnnouncementMutation) Title() (r string, exists bool) {
+ v := m.title
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTitle returns the old "title" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldTitle(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTitle is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTitle requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTitle: %w", err)
+ }
+ return oldValue.Title, nil
+}
+
+// ResetTitle resets all changes to the "title" field.
+func (m *AnnouncementMutation) ResetTitle() {
+ m.title = nil
+}
+
+// SetContent sets the "content" field.
+func (m *AnnouncementMutation) SetContent(s string) {
+ m.content = &s
+}
+
+// Content returns the value of the "content" field in the mutation.
+func (m *AnnouncementMutation) Content() (r string, exists bool) {
+ v := m.content
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldContent returns the old "content" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldContent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldContent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldContent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldContent: %w", err)
+ }
+ return oldValue.Content, nil
+}
+
+// ResetContent resets all changes to the "content" field.
+func (m *AnnouncementMutation) ResetContent() {
+ m.content = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *AnnouncementMutation) SetStatus(s string) {
+ m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *AnnouncementMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *AnnouncementMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetTargeting sets the "targeting" field.
+func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
+ m.targeting = &dt
+}
+
+// Targeting returns the value of the "targeting" field in the mutation.
+func (m *AnnouncementMutation) Targeting() (r domain.AnnouncementTargeting, exists bool) {
+ v := m.targeting
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargeting returns the old "targeting" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldTargeting(ctx context.Context) (v domain.AnnouncementTargeting, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargeting is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargeting requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargeting: %w", err)
+ }
+ return oldValue.Targeting, nil
+}
+
+// ClearTargeting clears the value of the "targeting" field.
+func (m *AnnouncementMutation) ClearTargeting() {
+ m.targeting = nil
+ m.clearedFields[announcement.FieldTargeting] = struct{}{}
+}
+
+// TargetingCleared returns if the "targeting" field was cleared in this mutation.
+func (m *AnnouncementMutation) TargetingCleared() bool {
+ _, ok := m.clearedFields[announcement.FieldTargeting]
+ return ok
+}
+
+// ResetTargeting resets all changes to the "targeting" field.
+func (m *AnnouncementMutation) ResetTargeting() {
+ m.targeting = nil
+ delete(m.clearedFields, announcement.FieldTargeting)
+}
+
+// SetStartsAt sets the "starts_at" field.
+func (m *AnnouncementMutation) SetStartsAt(t time.Time) {
+ m.starts_at = &t
+}
+
+// StartsAt returns the value of the "starts_at" field in the mutation.
+func (m *AnnouncementMutation) StartsAt() (r time.Time, exists bool) {
+ v := m.starts_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStartsAt returns the old "starts_at" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldStartsAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStartsAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStartsAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStartsAt: %w", err)
+ }
+ return oldValue.StartsAt, nil
+}
+
+// ClearStartsAt clears the value of the "starts_at" field.
+func (m *AnnouncementMutation) ClearStartsAt() {
+ m.starts_at = nil
+ m.clearedFields[announcement.FieldStartsAt] = struct{}{}
+}
+
+// StartsAtCleared returns if the "starts_at" field was cleared in this mutation.
+func (m *AnnouncementMutation) StartsAtCleared() bool {
+ _, ok := m.clearedFields[announcement.FieldStartsAt]
+ return ok
+}
+
+// ResetStartsAt resets all changes to the "starts_at" field.
+func (m *AnnouncementMutation) ResetStartsAt() {
+ m.starts_at = nil
+ delete(m.clearedFields, announcement.FieldStartsAt)
+}
+
+// SetEndsAt sets the "ends_at" field.
+func (m *AnnouncementMutation) SetEndsAt(t time.Time) {
+ m.ends_at = &t
+}
+
+// EndsAt returns the value of the "ends_at" field in the mutation.
+func (m *AnnouncementMutation) EndsAt() (r time.Time, exists bool) {
+ v := m.ends_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEndsAt returns the old "ends_at" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldEndsAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEndsAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEndsAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEndsAt: %w", err)
+ }
+ return oldValue.EndsAt, nil
+}
+
+// ClearEndsAt clears the value of the "ends_at" field.
+func (m *AnnouncementMutation) ClearEndsAt() {
+ m.ends_at = nil
+ m.clearedFields[announcement.FieldEndsAt] = struct{}{}
+}
+
+// EndsAtCleared returns if the "ends_at" field was cleared in this mutation.
+func (m *AnnouncementMutation) EndsAtCleared() bool {
+ _, ok := m.clearedFields[announcement.FieldEndsAt]
+ return ok
+}
+
+// ResetEndsAt resets all changes to the "ends_at" field.
+func (m *AnnouncementMutation) ResetEndsAt() {
+ m.ends_at = nil
+ delete(m.clearedFields, announcement.FieldEndsAt)
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (m *AnnouncementMutation) SetCreatedBy(i int64) {
+ m.created_by = &i
+ m.addcreated_by = nil
+}
+
+// CreatedBy returns the value of the "created_by" field in the mutation.
+func (m *AnnouncementMutation) CreatedBy() (r int64, exists bool) {
+ v := m.created_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedBy returns the old "created_by" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldCreatedBy(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err)
+ }
+ return oldValue.CreatedBy, nil
+}
+
+// AddCreatedBy adds i to the "created_by" field.
+func (m *AnnouncementMutation) AddCreatedBy(i int64) {
+ if m.addcreated_by != nil {
+ *m.addcreated_by += i
+ } else {
+ m.addcreated_by = &i
+ }
+}
+
+// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation.
+func (m *AnnouncementMutation) AddedCreatedBy() (r int64, exists bool) {
+ v := m.addcreated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearCreatedBy clears the value of the "created_by" field.
+func (m *AnnouncementMutation) ClearCreatedBy() {
+ m.created_by = nil
+ m.addcreated_by = nil
+ m.clearedFields[announcement.FieldCreatedBy] = struct{}{}
+}
+
+// CreatedByCleared returns if the "created_by" field was cleared in this mutation.
+func (m *AnnouncementMutation) CreatedByCleared() bool {
+ _, ok := m.clearedFields[announcement.FieldCreatedBy]
+ return ok
+}
+
+// ResetCreatedBy resets all changes to the "created_by" field.
+func (m *AnnouncementMutation) ResetCreatedBy() {
+ m.created_by = nil
+ m.addcreated_by = nil
+ delete(m.clearedFields, announcement.FieldCreatedBy)
+}
+
+// SetUpdatedBy sets the "updated_by" field.
+func (m *AnnouncementMutation) SetUpdatedBy(i int64) {
+ m.updated_by = &i
+ m.addupdated_by = nil
+}
+
+// UpdatedBy returns the value of the "updated_by" field in the mutation.
+func (m *AnnouncementMutation) UpdatedBy() (r int64, exists bool) {
+ v := m.updated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedBy returns the old "updated_by" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldUpdatedBy(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedBy: %w", err)
+ }
+ return oldValue.UpdatedBy, nil
+}
+
+// AddUpdatedBy adds i to the "updated_by" field.
+func (m *AnnouncementMutation) AddUpdatedBy(i int64) {
+ if m.addupdated_by != nil {
+ *m.addupdated_by += i
+ } else {
+ m.addupdated_by = &i
+ }
+}
+
+// AddedUpdatedBy returns the value that was added to the "updated_by" field in this mutation.
+func (m *AnnouncementMutation) AddedUpdatedBy() (r int64, exists bool) {
+ v := m.addupdated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearUpdatedBy clears the value of the "updated_by" field.
+func (m *AnnouncementMutation) ClearUpdatedBy() {
+ m.updated_by = nil
+ m.addupdated_by = nil
+ m.clearedFields[announcement.FieldUpdatedBy] = struct{}{}
+}
+
+// UpdatedByCleared returns if the "updated_by" field was cleared in this mutation.
+func (m *AnnouncementMutation) UpdatedByCleared() bool {
+ _, ok := m.clearedFields[announcement.FieldUpdatedBy]
+ return ok
+}
+
+// ResetUpdatedBy resets all changes to the "updated_by" field.
+func (m *AnnouncementMutation) ResetUpdatedBy() {
+ m.updated_by = nil
+ m.addupdated_by = nil
+ delete(m.clearedFields, announcement.FieldUpdatedBy)
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AnnouncementMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AnnouncementMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AnnouncementMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AnnouncementMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AnnouncementMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the Announcement entity.
+// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AnnouncementMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by ids.
+func (m *AnnouncementMutation) AddReadIDs(ids ...int64) {
+ if m.reads == nil {
+ m.reads = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.reads[ids[i]] = struct{}{}
+ }
+}
+
+// ClearReads clears the "reads" edge to the AnnouncementRead entity.
+func (m *AnnouncementMutation) ClearReads() {
+ m.clearedreads = true
+}
+
+// ReadsCleared reports if the "reads" edge to the AnnouncementRead entity was cleared.
+func (m *AnnouncementMutation) ReadsCleared() bool {
+ return m.clearedreads
+}
+
+// RemoveReadIDs removes the "reads" edge to the AnnouncementRead entity by IDs.
+func (m *AnnouncementMutation) RemoveReadIDs(ids ...int64) {
+ if m.removedreads == nil {
+ m.removedreads = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.reads, ids[i])
+ m.removedreads[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedReadsIDs returns the removed IDs of the "reads" edge to the AnnouncementRead entity.
+func (m *AnnouncementMutation) RemovedReadsIDs() (ids []int64) {
+ for id := range m.removedreads {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ReadsIDs returns the "reads" edge IDs in the mutation.
+func (m *AnnouncementMutation) ReadsIDs() (ids []int64) {
+ for id := range m.reads {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetReads resets all changes to the "reads" edge.
+func (m *AnnouncementMutation) ResetReads() {
+ m.reads = nil
+ m.clearedreads = false
+ m.removedreads = nil
+}
+
+// Where appends a list of predicates to the AnnouncementMutation builder.
+func (m *AnnouncementMutation) Where(ps ...predicate.Announcement) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AnnouncementMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AnnouncementMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.Announcement, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AnnouncementMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AnnouncementMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (Announcement).
+func (m *AnnouncementMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AnnouncementMutation) Fields() []string {
+ fields := make([]string, 0, 10)
+ if m.title != nil {
+ fields = append(fields, announcement.FieldTitle)
+ }
+ if m.content != nil {
+ fields = append(fields, announcement.FieldContent)
+ }
+ if m.status != nil {
+ fields = append(fields, announcement.FieldStatus)
+ }
+ if m.targeting != nil {
+ fields = append(fields, announcement.FieldTargeting)
+ }
+ if m.starts_at != nil {
+ fields = append(fields, announcement.FieldStartsAt)
+ }
+ if m.ends_at != nil {
+ fields = append(fields, announcement.FieldEndsAt)
+ }
+ if m.created_by != nil {
+ fields = append(fields, announcement.FieldCreatedBy)
+ }
+ if m.updated_by != nil {
+ fields = append(fields, announcement.FieldUpdatedBy)
+ }
+ if m.created_at != nil {
+ fields = append(fields, announcement.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, announcement.FieldUpdatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case announcement.FieldTitle:
+ return m.Title()
+ case announcement.FieldContent:
+ return m.Content()
+ case announcement.FieldStatus:
+ return m.Status()
+ case announcement.FieldTargeting:
+ return m.Targeting()
+ case announcement.FieldStartsAt:
+ return m.StartsAt()
+ case announcement.FieldEndsAt:
+ return m.EndsAt()
+ case announcement.FieldCreatedBy:
+ return m.CreatedBy()
+ case announcement.FieldUpdatedBy:
+ return m.UpdatedBy()
+ case announcement.FieldCreatedAt:
+ return m.CreatedAt()
+ case announcement.FieldUpdatedAt:
+ return m.UpdatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case announcement.FieldTitle:
+ return m.OldTitle(ctx)
+ case announcement.FieldContent:
+ return m.OldContent(ctx)
+ case announcement.FieldStatus:
+ return m.OldStatus(ctx)
+ case announcement.FieldTargeting:
+ return m.OldTargeting(ctx)
+ case announcement.FieldStartsAt:
+ return m.OldStartsAt(ctx)
+ case announcement.FieldEndsAt:
+ return m.OldEndsAt(ctx)
+ case announcement.FieldCreatedBy:
+ return m.OldCreatedBy(ctx)
+ case announcement.FieldUpdatedBy:
+ return m.OldUpdatedBy(ctx)
+ case announcement.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case announcement.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown Announcement field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the value type does not match the
+// field type.
+func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case announcement.FieldTitle:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTitle(v)
+ return nil
+ case announcement.FieldContent:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetContent(v)
+ return nil
+ case announcement.FieldStatus:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case announcement.FieldTargeting:
+ v, ok := value.(domain.AnnouncementTargeting)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargeting(v)
+ return nil
+ case announcement.FieldStartsAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStartsAt(v)
+ return nil
+ case announcement.FieldEndsAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEndsAt(v)
+ return nil
+ case announcement.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedBy(v)
+ return nil
+ case announcement.FieldUpdatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedBy(v)
+ return nil
+ case announcement.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case announcement.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown Announcement field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AnnouncementMutation) AddedFields() []string {
+ var fields []string
+ if m.addcreated_by != nil {
+ fields = append(fields, announcement.FieldCreatedBy)
+ }
+ if m.addupdated_by != nil {
+ fields = append(fields, announcement.FieldUpdatedBy)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AnnouncementMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case announcement.FieldCreatedBy:
+ return m.AddedCreatedBy()
+ case announcement.FieldUpdatedBy:
+ return m.AddedUpdatedBy()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AnnouncementMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case announcement.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCreatedBy(v)
+ return nil
+ case announcement.FieldUpdatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddUpdatedBy(v)
+ return nil
+ }
+ return fmt.Errorf("unknown Announcement numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AnnouncementMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(announcement.FieldTargeting) {
+ fields = append(fields, announcement.FieldTargeting)
+ }
+ if m.FieldCleared(announcement.FieldStartsAt) {
+ fields = append(fields, announcement.FieldStartsAt)
+ }
+ if m.FieldCleared(announcement.FieldEndsAt) {
+ fields = append(fields, announcement.FieldEndsAt)
+ }
+ if m.FieldCleared(announcement.FieldCreatedBy) {
+ fields = append(fields, announcement.FieldCreatedBy)
+ }
+ if m.FieldCleared(announcement.FieldUpdatedBy) {
+ fields = append(fields, announcement.FieldUpdatedBy)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AnnouncementMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AnnouncementMutation) ClearField(name string) error {
+ switch name {
+ case announcement.FieldTargeting:
+ m.ClearTargeting()
+ return nil
+ case announcement.FieldStartsAt:
+ m.ClearStartsAt()
+ return nil
+ case announcement.FieldEndsAt:
+ m.ClearEndsAt()
+ return nil
+ case announcement.FieldCreatedBy:
+ m.ClearCreatedBy()
+ return nil
+ case announcement.FieldUpdatedBy:
+ m.ClearUpdatedBy()
+ return nil
+ }
+ return fmt.Errorf("unknown Announcement nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AnnouncementMutation) ResetField(name string) error {
+ switch name {
+ case announcement.FieldTitle:
+ m.ResetTitle()
+ return nil
+ case announcement.FieldContent:
+ m.ResetContent()
+ return nil
+ case announcement.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case announcement.FieldTargeting:
+ m.ResetTargeting()
+ return nil
+ case announcement.FieldStartsAt:
+ m.ResetStartsAt()
+ return nil
+ case announcement.FieldEndsAt:
+ m.ResetEndsAt()
+ return nil
+ case announcement.FieldCreatedBy:
+ m.ResetCreatedBy()
+ return nil
+ case announcement.FieldUpdatedBy:
+ m.ResetUpdatedBy()
+ return nil
+ case announcement.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case announcement.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown Announcement field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AnnouncementMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.reads != nil {
+ edges = append(edges, announcement.EdgeReads)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AnnouncementMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case announcement.EdgeReads:
+ ids := make([]ent.Value, 0, len(m.reads))
+ for id := range m.reads {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AnnouncementMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedreads != nil {
+ edges = append(edges, announcement.EdgeReads)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AnnouncementMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case announcement.EdgeReads:
+ ids := make([]ent.Value, 0, len(m.removedreads))
+ for id := range m.removedreads {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AnnouncementMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedreads {
+ edges = append(edges, announcement.EdgeReads)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AnnouncementMutation) EdgeCleared(name string) bool {
+ switch name {
+ case announcement.EdgeReads:
+ return m.clearedreads
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AnnouncementMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown Announcement unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AnnouncementMutation) ResetEdge(name string) error {
+ switch name {
+ case announcement.EdgeReads:
+ m.ResetReads()
+ return nil
+ }
+ return fmt.Errorf("unknown Announcement edge %s", name)
+}
+
+// AnnouncementReadMutation represents an operation that mutates the AnnouncementRead nodes in the graph.
+type AnnouncementReadMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ read_at *time.Time
+ created_at *time.Time
+ clearedFields map[string]struct{}
+ announcement *int64
+ clearedannouncement bool
+ user *int64
+ cleareduser bool
+ done bool
+ oldValue func(context.Context) (*AnnouncementRead, error)
+ predicates []predicate.AnnouncementRead
+}
+
+var _ ent.Mutation = (*AnnouncementReadMutation)(nil)
+
+// announcementreadOption allows management of the mutation configuration using functional options.
+type announcementreadOption func(*AnnouncementReadMutation)
+
+// newAnnouncementReadMutation creates new mutation for the AnnouncementRead entity.
+func newAnnouncementReadMutation(c config, op Op, opts ...announcementreadOption) *AnnouncementReadMutation {
+ m := &AnnouncementReadMutation{
+ config: c,
+ op: op,
+ typ: TypeAnnouncementRead,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAnnouncementReadID sets the ID field of the mutation.
+func withAnnouncementReadID(id int64) announcementreadOption {
+ return func(m *AnnouncementReadMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AnnouncementRead
+ )
+ m.oldValue = func(ctx context.Context) (*AnnouncementRead, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AnnouncementRead.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAnnouncementRead sets the old AnnouncementRead of the mutation.
+func withAnnouncementRead(node *AnnouncementRead) announcementreadOption {
+ return func(m *AnnouncementReadMutation) {
+ m.oldValue = func(context.Context) (*AnnouncementRead, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AnnouncementReadMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AnnouncementReadMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AnnouncementReadMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *AnnouncementReadMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AnnouncementRead.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetAnnouncementID sets the "announcement_id" field.
+func (m *AnnouncementReadMutation) SetAnnouncementID(i int64) {
+ m.announcement = &i
+}
+
+// AnnouncementID returns the value of the "announcement_id" field in the mutation.
+func (m *AnnouncementReadMutation) AnnouncementID() (r int64, exists bool) {
+ v := m.announcement
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAnnouncementID returns the old "announcement_id" field's value of the AnnouncementRead entity.
+// If the AnnouncementRead object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementReadMutation) OldAnnouncementID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAnnouncementID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAnnouncementID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAnnouncementID: %w", err)
+ }
+ return oldValue.AnnouncementID, nil
+}
+
+// ResetAnnouncementID resets all changes to the "announcement_id" field.
+func (m *AnnouncementReadMutation) ResetAnnouncementID() {
+ m.announcement = nil
+}
+
+// SetUserID sets the "user_id" field.
+func (m *AnnouncementReadMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AnnouncementReadMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the AnnouncementRead entity.
+// If the AnnouncementRead object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementReadMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AnnouncementReadMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetReadAt sets the "read_at" field.
+func (m *AnnouncementReadMutation) SetReadAt(t time.Time) {
+ m.read_at = &t
+}
+
+// ReadAt returns the value of the "read_at" field in the mutation.
+func (m *AnnouncementReadMutation) ReadAt() (r time.Time, exists bool) {
+ v := m.read_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldReadAt returns the old "read_at" field's value of the AnnouncementRead entity.
+// If the AnnouncementRead object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementReadMutation) OldReadAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldReadAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldReadAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldReadAt: %w", err)
+ }
+ return oldValue.ReadAt, nil
+}
+
+// ResetReadAt resets all changes to the "read_at" field.
+func (m *AnnouncementReadMutation) ResetReadAt() {
+ m.read_at = nil
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AnnouncementReadMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AnnouncementReadMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AnnouncementRead entity.
+// If the AnnouncementRead object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AnnouncementReadMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AnnouncementReadMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
+func (m *AnnouncementReadMutation) ClearAnnouncement() {
+ m.clearedannouncement = true
+ m.clearedFields[announcementread.FieldAnnouncementID] = struct{}{}
+}
+
+// AnnouncementCleared reports if the "announcement" edge to the Announcement entity was cleared.
+func (m *AnnouncementReadMutation) AnnouncementCleared() bool {
+ return m.clearedannouncement
+}
+
+// AnnouncementIDs returns the "announcement" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AnnouncementID instead. It exists only for internal usage by the builders.
+func (m *AnnouncementReadMutation) AnnouncementIDs() (ids []int64) {
+ if id := m.announcement; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAnnouncement resets all changes to the "announcement" edge.
+func (m *AnnouncementReadMutation) ResetAnnouncement() {
+ m.announcement = nil
+ m.clearedannouncement = false
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *AnnouncementReadMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[announcementread.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AnnouncementReadMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AnnouncementReadMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *AnnouncementReadMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// Where appends a list of predicates to the AnnouncementReadMutation builder.
+func (m *AnnouncementReadMutation) Where(ps ...predicate.AnnouncementRead) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AnnouncementReadMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AnnouncementReadMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AnnouncementRead, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AnnouncementReadMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AnnouncementReadMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AnnouncementRead).
+func (m *AnnouncementReadMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AnnouncementReadMutation) Fields() []string {
+ fields := make([]string, 0, 4)
+ if m.announcement != nil {
+ fields = append(fields, announcementread.FieldAnnouncementID)
+ }
+ if m.user != nil {
+ fields = append(fields, announcementread.FieldUserID)
+ }
+ if m.read_at != nil {
+ fields = append(fields, announcementread.FieldReadAt)
+ }
+ if m.created_at != nil {
+ fields = append(fields, announcementread.FieldCreatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AnnouncementReadMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case announcementread.FieldAnnouncementID:
+ return m.AnnouncementID()
+ case announcementread.FieldUserID:
+ return m.UserID()
+ case announcementread.FieldReadAt:
+ return m.ReadAt()
+ case announcementread.FieldCreatedAt:
+ return m.CreatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AnnouncementReadMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case announcementread.FieldAnnouncementID:
+ return m.OldAnnouncementID(ctx)
+ case announcementread.FieldUserID:
+ return m.OldUserID(ctx)
+ case announcementread.FieldReadAt:
+ return m.OldReadAt(ctx)
+ case announcementread.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown AnnouncementRead field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AnnouncementReadMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case announcementread.FieldAnnouncementID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAnnouncementID(v)
+ return nil
+ case announcementread.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case announcementread.FieldReadAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetReadAt(v)
+ return nil
+ case announcementread.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AnnouncementRead field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AnnouncementReadMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AnnouncementReadMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AnnouncementReadMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AnnouncementRead numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AnnouncementReadMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AnnouncementReadMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AnnouncementReadMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AnnouncementRead nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AnnouncementReadMutation) ResetField(name string) error {
+ switch name {
+ case announcementread.FieldAnnouncementID:
+ m.ResetAnnouncementID()
+ return nil
+ case announcementread.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case announcementread.FieldReadAt:
+ m.ResetReadAt()
+ return nil
+ case announcementread.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown AnnouncementRead field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AnnouncementReadMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.announcement != nil {
+ edges = append(edges, announcementread.EdgeAnnouncement)
+ }
+ if m.user != nil {
+ edges = append(edges, announcementread.EdgeUser)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AnnouncementReadMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case announcementread.EdgeAnnouncement:
+ if id := m.announcement; id != nil {
+ return []ent.Value{*id}
+ }
+ case announcementread.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AnnouncementReadMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AnnouncementReadMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AnnouncementReadMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedannouncement {
+ edges = append(edges, announcementread.EdgeAnnouncement)
+ }
+ if m.cleareduser {
+ edges = append(edges, announcementread.EdgeUser)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AnnouncementReadMutation) EdgeCleared(name string) bool {
+ switch name {
+ case announcementread.EdgeAnnouncement:
+ return m.clearedannouncement
+ case announcementread.EdgeUser:
+ return m.cleareduser
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AnnouncementReadMutation) ClearEdge(name string) error {
+ switch name {
+ case announcementread.EdgeAnnouncement:
+ m.ClearAnnouncement()
+ return nil
+ case announcementread.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AnnouncementRead unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AnnouncementReadMutation) ResetEdge(name string) error {
+ switch name {
+ case announcementread.EdgeAnnouncement:
+ m.ResetAnnouncement()
+ return nil
+ case announcementread.EdgeUser:
+ m.ResetUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AnnouncementRead edge %s", name)
+}
+
// GroupMutation represents an operation that mutates the Group nodes in the graph.
type GroupMutation struct {
config
@@ -14861,6 +16531,9 @@ type UserMutation struct {
status *string
username *string
notes *string
+ totp_secret_encrypted *string
+ totp_enabled *bool
+ totp_enabled_at *time.Time
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -14874,6 +16547,9 @@ type UserMutation struct {
assigned_subscriptions map[int64]struct{}
removedassigned_subscriptions map[int64]struct{}
clearedassigned_subscriptions bool
+ announcement_reads map[int64]struct{}
+ removedannouncement_reads map[int64]struct{}
+ clearedannouncement_reads bool
allowed_groups map[int64]struct{}
removedallowed_groups map[int64]struct{}
clearedallowed_groups bool
@@ -15438,6 +17114,140 @@ func (m *UserMutation) ResetNotes() {
m.notes = nil
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (m *UserMutation) SetTotpSecretEncrypted(s string) {
+ m.totp_secret_encrypted = &s
+}
+
+// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
+func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
+ v := m.totp_secret_encrypted
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
+ }
+ return oldValue.TotpSecretEncrypted, nil
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (m *UserMutation) ClearTotpSecretEncrypted() {
+ m.totp_secret_encrypted = nil
+ m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
+}
+
+// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
+func (m *UserMutation) TotpSecretEncryptedCleared() bool {
+ _, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
+ return ok
+}
+
+// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
+func (m *UserMutation) ResetTotpSecretEncrypted() {
+ m.totp_secret_encrypted = nil
+ delete(m.clearedFields, user.FieldTotpSecretEncrypted)
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (m *UserMutation) SetTotpEnabled(b bool) {
+ m.totp_enabled = &b
+}
+
+// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
+func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
+ v := m.totp_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
+ }
+ return oldValue.TotpEnabled, nil
+}
+
+// ResetTotpEnabled resets all changes to the "totp_enabled" field.
+func (m *UserMutation) ResetTotpEnabled() {
+ m.totp_enabled = nil
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
+ m.totp_enabled_at = &t
+}
+
+// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
+func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
+ v := m.totp_enabled_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
+ }
+ return oldValue.TotpEnabledAt, nil
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (m *UserMutation) ClearTotpEnabledAt() {
+ m.totp_enabled_at = nil
+ m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
+}
+
+// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
+func (m *UserMutation) TotpEnabledAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldTotpEnabledAt]
+ return ok
+}
+
+// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
+func (m *UserMutation) ResetTotpEnabledAt() {
+ m.totp_enabled_at = nil
+ delete(m.clearedFields, user.FieldTotpEnabledAt)
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -15654,6 +17464,60 @@ func (m *UserMutation) ResetAssignedSubscriptions() {
m.removedassigned_subscriptions = nil
}
+// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by ids.
+func (m *UserMutation) AddAnnouncementReadIDs(ids ...int64) {
+ if m.announcement_reads == nil {
+ m.announcement_reads = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.announcement_reads[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAnnouncementReads clears the "announcement_reads" edge to the AnnouncementRead entity.
+func (m *UserMutation) ClearAnnouncementReads() {
+ m.clearedannouncement_reads = true
+}
+
+// AnnouncementReadsCleared reports if the "announcement_reads" edge to the AnnouncementRead entity was cleared.
+func (m *UserMutation) AnnouncementReadsCleared() bool {
+ return m.clearedannouncement_reads
+}
+
+// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to the AnnouncementRead entity by IDs.
+func (m *UserMutation) RemoveAnnouncementReadIDs(ids ...int64) {
+ if m.removedannouncement_reads == nil {
+ m.removedannouncement_reads = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.announcement_reads, ids[i])
+ m.removedannouncement_reads[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAnnouncementReads returns the removed IDs of the "announcement_reads" edge to the AnnouncementRead entity.
+func (m *UserMutation) RemovedAnnouncementReadsIDs() (ids []int64) {
+ for id := range m.removedannouncement_reads {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AnnouncementReadsIDs returns the "announcement_reads" edge IDs in the mutation.
+func (m *UserMutation) AnnouncementReadsIDs() (ids []int64) {
+ for id := range m.announcement_reads {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAnnouncementReads resets all changes to the "announcement_reads" edge.
+func (m *UserMutation) ResetAnnouncementReads() {
+ m.announcement_reads = nil
+ m.clearedannouncement_reads = false
+ m.removedannouncement_reads = nil
+}
+
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by ids.
func (m *UserMutation) AddAllowedGroupIDs(ids ...int64) {
if m.allowed_groups == nil {
@@ -15904,7 +17768,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 11)
+ fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -15938,6 +17802,15 @@ func (m *UserMutation) Fields() []string {
if m.notes != nil {
fields = append(fields, user.FieldNotes)
}
+ if m.totp_secret_encrypted != nil {
+ fields = append(fields, user.FieldTotpSecretEncrypted)
+ }
+ if m.totp_enabled != nil {
+ fields = append(fields, user.FieldTotpEnabled)
+ }
+ if m.totp_enabled_at != nil {
+ fields = append(fields, user.FieldTotpEnabledAt)
+ }
return fields
}
@@ -15968,6 +17841,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Username()
case user.FieldNotes:
return m.Notes()
+ case user.FieldTotpSecretEncrypted:
+ return m.TotpSecretEncrypted()
+ case user.FieldTotpEnabled:
+ return m.TotpEnabled()
+ case user.FieldTotpEnabledAt:
+ return m.TotpEnabledAt()
}
return nil, false
}
@@ -15999,6 +17878,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldUsername(ctx)
case user.FieldNotes:
return m.OldNotes(ctx)
+ case user.FieldTotpSecretEncrypted:
+ return m.OldTotpSecretEncrypted(ctx)
+ case user.FieldTotpEnabled:
+ return m.OldTotpEnabled(ctx)
+ case user.FieldTotpEnabledAt:
+ return m.OldTotpEnabledAt(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -16085,6 +17970,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetNotes(v)
return nil
+ case user.FieldTotpSecretEncrypted:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpSecretEncrypted(v)
+ return nil
+ case user.FieldTotpEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpEnabled(v)
+ return nil
+ case user.FieldTotpEnabledAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpEnabledAt(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -16145,6 +18051,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldDeletedAt) {
fields = append(fields, user.FieldDeletedAt)
}
+ if m.FieldCleared(user.FieldTotpSecretEncrypted) {
+ fields = append(fields, user.FieldTotpSecretEncrypted)
+ }
+ if m.FieldCleared(user.FieldTotpEnabledAt) {
+ fields = append(fields, user.FieldTotpEnabledAt)
+ }
return fields
}
@@ -16162,6 +18074,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldDeletedAt:
m.ClearDeletedAt()
return nil
+ case user.FieldTotpSecretEncrypted:
+ m.ClearTotpSecretEncrypted()
+ return nil
+ case user.FieldTotpEnabledAt:
+ m.ClearTotpEnabledAt()
+ return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -16203,13 +18121,22 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldNotes:
m.ResetNotes()
return nil
+ case user.FieldTotpSecretEncrypted:
+ m.ResetTotpSecretEncrypted()
+ return nil
+ case user.FieldTotpEnabled:
+ m.ResetTotpEnabled()
+ return nil
+ case user.FieldTotpEnabledAt:
+ m.ResetTotpEnabledAt()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 8)
+ edges := make([]string, 0, 9)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -16222,6 +18149,9 @@ func (m *UserMutation) AddedEdges() []string {
if m.assigned_subscriptions != nil {
edges = append(edges, user.EdgeAssignedSubscriptions)
}
+ if m.announcement_reads != nil {
+ edges = append(edges, user.EdgeAnnouncementReads)
+ }
if m.allowed_groups != nil {
edges = append(edges, user.EdgeAllowedGroups)
}
@@ -16265,6 +18195,12 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAnnouncementReads:
+ ids := make([]ent.Value, 0, len(m.announcement_reads))
+ for id := range m.announcement_reads {
+ ids = append(ids, id)
+ }
+ return ids
case user.EdgeAllowedGroups:
ids := make([]ent.Value, 0, len(m.allowed_groups))
for id := range m.allowed_groups {
@@ -16295,7 +18231,7 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 8)
+ edges := make([]string, 0, 9)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -16308,6 +18244,9 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedassigned_subscriptions != nil {
edges = append(edges, user.EdgeAssignedSubscriptions)
}
+ if m.removedannouncement_reads != nil {
+ edges = append(edges, user.EdgeAnnouncementReads)
+ }
if m.removedallowed_groups != nil {
edges = append(edges, user.EdgeAllowedGroups)
}
@@ -16351,6 +18290,12 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAnnouncementReads:
+ ids := make([]ent.Value, 0, len(m.removedannouncement_reads))
+ for id := range m.removedannouncement_reads {
+ ids = append(ids, id)
+ }
+ return ids
case user.EdgeAllowedGroups:
ids := make([]ent.Value, 0, len(m.removedallowed_groups))
for id := range m.removedallowed_groups {
@@ -16381,7 +18326,7 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 8)
+ edges := make([]string, 0, 9)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -16394,6 +18339,9 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedassigned_subscriptions {
edges = append(edges, user.EdgeAssignedSubscriptions)
}
+ if m.clearedannouncement_reads {
+ edges = append(edges, user.EdgeAnnouncementReads)
+ }
if m.clearedallowed_groups {
edges = append(edges, user.EdgeAllowedGroups)
}
@@ -16421,6 +18369,8 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedsubscriptions
case user.EdgeAssignedSubscriptions:
return m.clearedassigned_subscriptions
+ case user.EdgeAnnouncementReads:
+ return m.clearedannouncement_reads
case user.EdgeAllowedGroups:
return m.clearedallowed_groups
case user.EdgeUsageLogs:
@@ -16457,6 +18407,9 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgeAssignedSubscriptions:
m.ResetAssignedSubscriptions()
return nil
+ case user.EdgeAnnouncementReads:
+ m.ResetAnnouncementReads()
+ return nil
case user.EdgeAllowedGroups:
m.ResetAllowedGroups()
return nil
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index 785cb4e6..613c5913 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -15,6 +15,12 @@ type Account func(*sql.Selector)
// AccountGroup is the predicate function for accountgroup builders.
type AccountGroup func(*sql.Selector)
+// Announcement is the predicate function for announcement builders.
+type Announcement func(*sql.Selector)
+
+// AnnouncementRead is the predicate function for announcementread builders.
+type AnnouncementRead func(*sql.Selector)
+
// Group is the predicate function for group builders.
type Group func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 15b02ad1..aeced47a 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -7,6 +7,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -210,6 +212,56 @@ func init() {
accountgroupDescCreatedAt := accountgroupFields[3].Descriptor()
// accountgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
accountgroup.DefaultCreatedAt = accountgroupDescCreatedAt.Default.(func() time.Time)
+ announcementFields := schema.Announcement{}.Fields()
+ _ = announcementFields
+ // announcementDescTitle is the schema descriptor for title field.
+ announcementDescTitle := announcementFields[0].Descriptor()
+ // announcement.TitleValidator is a validator for the "title" field. It is called by the builders before save.
+ announcement.TitleValidator = func() func(string) error {
+ validators := announcementDescTitle.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(title string) error {
+ for _, fn := range fns {
+ if err := fn(title); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // announcementDescContent is the schema descriptor for content field.
+ announcementDescContent := announcementFields[1].Descriptor()
+ // announcement.ContentValidator is a validator for the "content" field. It is called by the builders before save.
+ announcement.ContentValidator = announcementDescContent.Validators[0].(func(string) error)
+ // announcementDescStatus is the schema descriptor for status field.
+ announcementDescStatus := announcementFields[2].Descriptor()
+ // announcement.DefaultStatus holds the default value on creation for the status field.
+ announcement.DefaultStatus = announcementDescStatus.Default.(string)
+ // announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
+ // announcementDescCreatedAt is the schema descriptor for created_at field.
+ announcementDescCreatedAt := announcementFields[8].Descriptor()
+ // announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
+ announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
+ // announcementDescUpdatedAt is the schema descriptor for updated_at field.
+ announcementDescUpdatedAt := announcementFields[9].Descriptor()
+ // announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
+ // announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ announcement.UpdateDefaultUpdatedAt = announcementDescUpdatedAt.UpdateDefault.(func() time.Time)
+ announcementreadFields := schema.AnnouncementRead{}.Fields()
+ _ = announcementreadFields
+ // announcementreadDescReadAt is the schema descriptor for read_at field.
+ announcementreadDescReadAt := announcementreadFields[2].Descriptor()
+ // announcementread.DefaultReadAt holds the default value on creation for the read_at field.
+ announcementread.DefaultReadAt = announcementreadDescReadAt.Default.(func() time.Time)
+ // announcementreadDescCreatedAt is the schema descriptor for created_at field.
+ announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
+ // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
+ announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0]
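The runtime wiring above folds several field validators into a single closure (e.g. `announcement.TitleValidator` combines the `NotEmpty` and `MaxLen(200)` checks). A self-contained sketch of that composition pattern, using hypothetical `notEmpty`/`maxLen` helpers rather than the generated descriptor plumbing:

```go
package main

import (
	"errors"
	"fmt"
)

// compose chains field validators the way the generated
// announcement.TitleValidator closure does: run each validator in
// order and stop at the first error.
func compose(fns ...func(string) error) func(string) error {
	return func(s string) error {
		for _, fn := range fns {
			if err := fn(s); err != nil {
				return err
			}
		}
		return nil
	}
}

// notEmpty stands in for the validator produced by NotEmpty().
func notEmpty(s string) error {
	if s == "" {
		return errors.New("value is empty")
	}
	return nil
}

// maxLen stands in for the validator produced by MaxLen(n).
func maxLen(n int) func(string) error {
	return func(s string) error {
		if len(s) > n {
			return fmt.Errorf("value longer than %d", n)
		}
		return nil
	}
}

// titleValidator mirrors the title field's NotEmpty() + MaxLen(200).
var titleValidator = compose(notEmpty, maxLen(200))

func main() {
	fmt.Println(titleValidator("Maintenance window tonight")) // <nil>
	fmt.Println(titleValidator(""))
}
```

The builders call the composed closure once before save, so each field pays one function call regardless of how many validators the schema declares.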
@@ -740,6 +792,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
+ // userDescTotpEnabled is the schema descriptor for totp_enabled field.
+ userDescTotpEnabled := userFields[9].Descriptor()
+ // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
+ user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go
index dd79ba96..1cfecc2d 100644
--- a/backend/ent/schema/account.go
+++ b/backend/ent/schema/account.go
@@ -4,7 +4,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -111,7 +111,7 @@ func (Account) Fields() []ent.Field {
// status: account status, e.g. "active", "error", "disabled"
field.String("status").
MaxLen(20).
- Default(service.StatusActive),
+ Default(domain.StatusActive),
// error_message: 错误信息,记录账户异常时的详细信息
field.String("error_message").
diff --git a/backend/ent/schema/announcement.go b/backend/ent/schema/announcement.go
new file mode 100644
index 00000000..1568778f
--- /dev/null
+++ b/backend/ent/schema/announcement.go
@@ -0,0 +1,90 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// Announcement holds the schema definition for the Announcement entity.
+//
+// Deletion policy: hard delete (read records are removed via foreign-key cascade).
+type Announcement struct {
+ ent.Schema
+}
+
+func (Announcement) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "announcements"},
+ }
+}
+
+func (Announcement) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("title").
+ MaxLen(200).
+ NotEmpty().
+ Comment("Announcement title"),
+ field.String("content").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ NotEmpty().
+ Comment("Announcement content (Markdown supported)"),
+ field.String("status").
+ MaxLen(20).
+ Default(domain.AnnouncementStatusDraft).
+ Comment("Status: draft, active, archived"),
+ field.JSON("targeting", domain.AnnouncementTargeting{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
+ Comment("Display targeting rules (JSON)"),
+ field.Time("starts_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
+ Comment("Display start time (nil means effective immediately)"),
+ field.Time("ends_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
+ Comment("Display end time (nil means no expiry)"),
+ field.Int64("created_by").
+ Optional().
+ Nillable().
+ Comment("Creator user ID (admin)"),
+ field.Int64("updated_by").
+ Optional().
+ Nillable().
+ Comment("Last updater user ID (admin)"),
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("updated_at").
+ Default(time.Now).
+ UpdateDefault(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (Announcement) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.To("reads", AnnouncementRead.Type),
+ }
+}
+
+func (Announcement) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("status"),
+ index.Fields("created_at"),
+ index.Fields("starts_at"),
+ index.Fields("ends_at"),
+ }
+}
diff --git a/backend/ent/schema/announcement_read.go b/backend/ent/schema/announcement_read.go
new file mode 100644
index 00000000..e0b50777
--- /dev/null
+++ b/backend/ent/schema/announcement_read.go
@@ -0,0 +1,65 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AnnouncementRead holds the schema definition for the AnnouncementRead entity.
+//
+// Records a user's read status for an announcement (first-read timestamp).
+type AnnouncementRead struct {
+ ent.Schema
+}
+
+func (AnnouncementRead) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "announcement_reads"},
+ }
+}
+
+func (AnnouncementRead) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("announcement_id"),
+ field.Int64("user_id"),
+ field.Time("read_at").
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
+ Comment("Time the user first read the announcement"),
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (AnnouncementRead) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("announcement", Announcement.Type).
+ Ref("reads").
+ Field("announcement_id").
+ Unique().
+ Required(),
+ edge.From("user", User.Type).
+ Ref("announcement_reads").
+ Field("user_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (AnnouncementRead) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("announcement_id"),
+ index.Fields("user_id"),
+ index.Fields("read_at"),
+ index.Fields("announcement_id", "user_id").Unique(),
+ }
+}
diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go
index 1b206089..1c2d4bd4 100644
--- a/backend/ent/schema/api_key.go
+++ b/backend/ent/schema/api_key.go
@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
@@ -45,7 +45,7 @@ func (APIKey) Fields() []ent.Field {
Nillable(),
field.String("status").
MaxLen(20).
- Default(service.StatusActive),
+ Default(domain.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index 7fa04b8a..65b57754 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,15 +49,15 @@ func (Group) Fields() []ent.Field {
Default(false),
field.String("status").
MaxLen(20).
- Default(service.StatusActive),
+ Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
- Default(service.PlatformAnthropic),
+ Default(domain.PlatformAnthropic),
field.String("subscription_type").
MaxLen(20).
- Default(service.SubscriptionTypeStandard),
+ Default(domain.SubscriptionTypeStandard),
field.Float("daily_limit_usd").
Optional().
Nillable().
diff --git a/backend/ent/schema/promo_code.go b/backend/ent/schema/promo_code.go
index c3bb824b..3dd08c0e 100644
--- a/backend/ent/schema/promo_code.go
+++ b/backend/ent/schema/promo_code.go
@@ -3,7 +3,7 @@ package schema
import (
"time"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,7 +49,7 @@ func (PromoCode) Fields() []ent.Field {
Comment("Number of times used"),
field.String("status").
MaxLen(20).
- Default(service.PromoCodeStatusActive).
+ Default(domain.PromoCodeStatusActive).
Comment("Status: active, disabled"),
field.Time("expires_at").
Optional().
diff --git a/backend/ent/schema/redeem_code.go b/backend/ent/schema/redeem_code.go
index b4664e06..6fb86148 100644
--- a/backend/ent/schema/redeem_code.go
+++ b/backend/ent/schema/redeem_code.go
@@ -3,7 +3,7 @@ package schema
import (
"time"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -41,13 +41,13 @@ func (RedeemCode) Fields() []ent.Field {
Unique(),
field.String("type").
MaxLen(20).
- Default(service.RedeemTypeBalance),
+ Default(domain.RedeemTypeBalance),
field.Float("value").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
field.String("status").
MaxLen(20).
- Default(service.StatusUnused),
+ Default(domain.StatusUnused),
field.Int64("used_by").
Optional().
Nillable(),
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index 79dc2286..d443ef45 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -43,7 +43,7 @@ func (User) Fields() []ent.Field {
NotEmpty(),
field.String("role").
MaxLen(20).
- Default(service.RoleUser),
+ Default(domain.RoleUser),
field.Float("balance").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
@@ -51,7 +51,7 @@ func (User) Fields() []ent.Field {
Default(5),
field.String("status").
MaxLen(20).
- Default(service.StatusActive),
+ Default(domain.StatusActive),
// Optional profile fields (added later; default '' in DB migration)
field.String("username").
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
+
+ // TOTP two-factor authentication fields
+ field.String("totp_secret_encrypted").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Optional().
+ Nillable(),
+ field.Bool("totp_enabled").
+ Default(false),
+ field.Time("totp_enabled_at").
+ Optional().
+ Nillable(),
}
}
@@ -70,6 +81,7 @@ func (User) Edges() []ent.Edge {
edge.To("redeem_codes", RedeemCode.Type),
edge.To("subscriptions", UserSubscription.Type),
edge.To("assigned_subscriptions", UserSubscription.Type),
+ edge.To("announcement_reads", AnnouncementRead.Type),
edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type),
diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go
index b21f4083..fa13612b 100644
--- a/backend/ent/schema/user_subscription.go
+++ b/backend/ent/schema/user_subscription.go
@@ -4,7 +4,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
- "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -44,7 +44,7 @@ func (UserSubscription) Fields() []ent.Field {
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.String("status").
MaxLen(20).
- Default(service.SubscriptionStatusActive),
+ Default(domain.SubscriptionStatusActive),
field.Time("daily_window_start").
Optional().
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index 7ff16ec8..702bdf90 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -20,6 +20,10 @@ type Tx struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
+ // Announcement is the client for interacting with the Announcement builders.
+ Announcement *AnnouncementClient
+ // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
+ AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -180,6 +184,8 @@ func (tx *Tx) init() {
tx.APIKey = NewAPIKeyClient(tx.config)
tx.Account = NewAccountClient(tx.config)
tx.AccountGroup = NewAccountGroupClient(tx.config)
+ tx.Announcement = NewAnnouncementClient(tx.config)
+ tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 0b9a48cc..2435aa1b 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
+ // TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
+ TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
+ // TotpEnabled holds the value of the "totp_enabled" field.
+ TotpEnabled bool `json:"totp_enabled,omitempty"`
+ // TotpEnabledAt holds the value of the "totp_enabled_at" field.
+ TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -55,6 +61,8 @@ type UserEdges struct {
Subscriptions []*UserSubscription `json:"subscriptions,omitempty"`
// AssignedSubscriptions holds the value of the assigned_subscriptions edge.
AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"`
+ // AnnouncementReads holds the value of the announcement_reads edge.
+ AnnouncementReads []*AnnouncementRead `json:"announcement_reads,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
@@ -67,7 +75,7 @@ type UserEdges struct {
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [9]bool
+ loadedTypes [10]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -106,10 +114,19 @@ func (e UserEdges) AssignedSubscriptionsOrErr() ([]*UserSubscription, error) {
return nil, &NotLoadedError{edge: "assigned_subscriptions"}
}
+// AnnouncementReadsOrErr returns the AnnouncementReads value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AnnouncementReadsOrErr() ([]*AnnouncementRead, error) {
+ if e.loadedTypes[4] {
+ return e.AnnouncementReads, nil
+ }
+ return nil, &NotLoadedError{edge: "announcement_reads"}
+}
+
// AllowedGroupsOrErr returns the AllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
- if e.loadedTypes[4] {
+ if e.loadedTypes[5] {
return e.AllowedGroups, nil
}
return nil, &NotLoadedError{edge: "allowed_groups"}
@@ -118,7 +135,7 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
- if e.loadedTypes[5] {
+ if e.loadedTypes[6] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
@@ -127,7 +144,7 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
- if e.loadedTypes[6] {
+ if e.loadedTypes[7] {
return e.AttributeValues, nil
}
return nil, &NotLoadedError{edge: "attribute_values"}
@@ -136,7 +153,7 @@ func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
// PromoCodeUsagesOrErr returns the PromoCodeUsages value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
- if e.loadedTypes[7] {
+ if e.loadedTypes[8] {
return e.PromoCodeUsages, nil
}
return nil, &NotLoadedError{edge: "promo_code_usages"}
@@ -145,7 +162,7 @@ func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[8] {
+ if e.loadedTypes[9] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -156,13 +173,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case user.FieldTotpEnabled:
+ values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +271,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
+ case user.FieldTotpSecretEncrypted:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
+ } else if value.Valid {
+ _m.TotpSecretEncrypted = new(string)
+ *_m.TotpSecretEncrypted = value.String
+ }
+ case user.FieldTotpEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
+ } else if value.Valid {
+ _m.TotpEnabled = value.Bool
+ }
+ case user.FieldTotpEnabledAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
+ } else if value.Valid {
+ _m.TotpEnabledAt = new(time.Time)
+ *_m.TotpEnabledAt = value.Time
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -285,6 +324,11 @@ func (_m *User) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return NewUserClient(_m.config).QueryAssignedSubscriptions(_m)
}
+// QueryAnnouncementReads queries the "announcement_reads" edge of the User entity.
+func (_m *User) QueryAnnouncementReads() *AnnouncementReadQuery {
+ return NewUserClient(_m.config).QueryAnnouncementReads(_m)
+}
+
// QueryAllowedGroups queries the "allowed_groups" edge of the User entity.
func (_m *User) QueryAllowedGroups() *GroupQuery {
return NewUserClient(_m.config).QueryAllowedGroups(_m)
@@ -367,6 +411,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
+ builder.WriteString(", ")
+ if v := _m.TotpSecretEncrypted; v != nil {
+ builder.WriteString("totp_secret_encrypted=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("totp_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
+ builder.WriteString(", ")
+ if v := _m.TotpEnabledAt; v != nil {
+ builder.WriteString("totp_enabled_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 1be1d871..ae9418ff 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
+ // FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
+ FieldTotpSecretEncrypted = "totp_secret_encrypted"
+ // FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
+ FieldTotpEnabled = "totp_enabled"
+ // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
+ FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -45,6 +51,8 @@ const (
EdgeSubscriptions = "subscriptions"
// EdgeAssignedSubscriptions holds the string denoting the assigned_subscriptions edge name in mutations.
EdgeAssignedSubscriptions = "assigned_subscriptions"
+ // EdgeAnnouncementReads holds the string denoting the announcement_reads edge name in mutations.
+ EdgeAnnouncementReads = "announcement_reads"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
@@ -85,6 +93,13 @@ const (
AssignedSubscriptionsInverseTable = "user_subscriptions"
// AssignedSubscriptionsColumn is the table column denoting the assigned_subscriptions relation/edge.
AssignedSubscriptionsColumn = "assigned_by"
+ // AnnouncementReadsTable is the table that holds the announcement_reads relation/edge.
+ AnnouncementReadsTable = "announcement_reads"
+ // AnnouncementReadsInverseTable is the table name for the AnnouncementRead entity.
+ // It exists in this package in order to avoid circular dependency with the "announcementread" package.
+ AnnouncementReadsInverseTable = "announcement_reads"
+ // AnnouncementReadsColumn is the table column denoting the announcement_reads relation/edge.
+ AnnouncementReadsColumn = "user_id"
// AllowedGroupsTable is the table that holds the allowed_groups relation/edge. The primary key declared below.
AllowedGroupsTable = "user_allowed_groups"
// AllowedGroupsInverseTable is the table name for the Group entity.
@@ -134,6 +149,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
+ FieldTotpSecretEncrypted,
+ FieldTotpEnabled,
+ FieldTotpEnabledAt,
}
var (
@@ -188,6 +206,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
+ // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
+ DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +273,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
+// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
+func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
+}
+
+// ByTotpEnabled orders the results by the totp_enabled field.
+func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
+}
+
+// ByTotpEnabledAt orders the results by the totp_enabled_at field.
+func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -309,6 +344,20 @@ func ByAssignedSubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOp
}
}
+// ByAnnouncementReadsCount orders the results by announcement_reads count.
+func ByAnnouncementReadsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAnnouncementReadsStep(), opts...)
+ }
+}
+
+// ByAnnouncementReads orders the results by announcement_reads terms.
+func ByAnnouncementReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAnnouncementReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByAllowedGroupsCount orders the results by allowed_groups count.
func ByAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -406,6 +455,13 @@ func newAssignedSubscriptionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, AssignedSubscriptionsTable, AssignedSubscriptionsColumn),
)
}
+func newAnnouncementReadsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AnnouncementReadsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
+ )
+}
func newAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 6a460f10..1de61037 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
+// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
+func TotpSecretEncrypted(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
+func TotpEnabled(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
+func TotpEnabledAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
+// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
+}
+
+// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
+}
+
+// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
+}
+
+// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
+}
+
+// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
+}
+
+// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
+func TotpEnabledEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
+func TotpEnabledNEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
+func TotpEnabledAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
+func TotpEnabledAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
+}
+
+// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
+}
+
+// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
+func TotpEnabledAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
+func TotpEnabledAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
+func TotpEnabledAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
+func TotpEnabledAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
+func TotpEnabledAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
+}
+
+// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
@@ -802,6 +952,29 @@ func HasAssignedSubscriptionsWith(preds ...predicate.UserSubscription) predicate
})
}
+// HasAnnouncementReads applies the HasEdge predicate on the "announcement_reads" edge.
+func HasAnnouncementReads() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAnnouncementReadsWith applies the HasEdge predicate on the "announcement_reads" edge with the given conditions (other predicates).
+func HasAnnouncementReadsWith(preds ...predicate.AnnouncementRead) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAnnouncementReadsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasAllowedGroups applies the HasEdge predicate on the "allowed_groups" edge.
func HasAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index e12e476c..f862a580 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -167,6 +168,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
+ _c.mutation.SetTotpSecretEncrypted(v)
+ return _c
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
+ if v != nil {
+ _c.SetTotpSecretEncrypted(*v)
+ }
+ return _c
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
+ _c.mutation.SetTotpEnabled(v)
+ return _c
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
+ if v != nil {
+ _c.SetTotpEnabled(*v)
+ }
+ return _c
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
+ _c.mutation.SetTotpEnabledAt(v)
+ return _c
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetTotpEnabledAt(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -227,6 +270,21 @@ func (_c *UserCreate) AddAssignedSubscriptions(v ...*UserSubscription) *UserCrea
return _c.AddAssignedSubscriptionIDs(ids...)
}
+// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
+func (_c *UserCreate) AddAnnouncementReadIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAnnouncementReadIDs(ids...)
+ return _c
+}
+
+// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
+func (_c *UserCreate) AddAnnouncementReads(v ...*AnnouncementRead) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAnnouncementReadIDs(ids...)
+}
+
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_c *UserCreate) AddAllowedGroupIDs(ids ...int64) *UserCreate {
_c.mutation.AddAllowedGroupIDs(ids...)
@@ -362,6 +420,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
+ if _, ok := _c.mutation.TotpEnabled(); !ok {
+ v := user.DefaultTotpEnabled
+ _c.mutation.SetTotpEnabled(v)
+ }
return nil
}
@@ -422,6 +484,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
+ if _, ok := _c.mutation.TotpEnabled(); !ok {
+ return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
+ }
return nil
}
@@ -493,6 +558,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
+ if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ _node.TotpSecretEncrypted = &value
+ }
+ if value, ok := _c.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ _node.TotpEnabled = value
+ }
+ if value, ok := _c.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ _node.TotpEnabledAt = &value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -557,6 +634,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
if nodes := _c.mutation.AllowedGroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -815,6 +908,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
+ u.Set(user.FieldTotpSecretEncrypted, v)
+ return u
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
+ u.SetExcluded(user.FieldTotpSecretEncrypted)
+ return u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
+ u.SetNull(user.FieldTotpSecretEncrypted)
+ return u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
+ u.Set(user.FieldTotpEnabled, v)
+ return u
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
+ u.SetExcluded(user.FieldTotpEnabled)
+ return u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldTotpEnabledAt, v)
+ return u
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
+ u.SetExcluded(user.FieldTotpEnabledAt)
+ return u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
+ u.SetNull(user.FieldTotpEnabledAt)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1162,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpSecretEncrypted(v)
+ })
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpSecretEncrypted()
+ })
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpSecretEncrypted()
+ })
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabled(v)
+ })
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabled()
+ })
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabledAt(v)
+ })
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabledAt()
+ })
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpEnabledAt()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1590,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpSecretEncrypted(v)
+ })
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpSecretEncrypted()
+ })
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpSecretEncrypted()
+ })
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabled(v)
+ })
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabled()
+ })
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabledAt(v)
+ })
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabledAt()
+ })
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpEnabledAt()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
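The `SetNillable*` helpers added throughout this file let callers forward optional values (for example, pointer fields decoded from a JSON request) without writing a nil check at every call site. A minimal sketch of the convention (the tiny builder here is a stand-in, not the generated `UserCreate`):

```go
package main

import "fmt"

// userCreate is a stand-in for the generated builder.
type userCreate struct {
	totpEnabled    bool
	totpEnabledSet bool
}

// SetTotpEnabled records an explicit value.
func (c *userCreate) SetTotpEnabled(v bool) *userCreate {
	c.totpEnabled = v
	c.totpEnabledSet = true
	return c
}

// SetNillableTotpEnabled applies the value only when the pointer is
// non-nil, so optional API fields can be chained directly.
func (c *userCreate) SetNillableTotpEnabled(v *bool) *userCreate {
	if v != nil {
		c.SetTotpEnabled(*v)
	}
	return c
}

func main() {
	enabled := true
	a := (&userCreate{}).SetNillableTotpEnabled(&enabled)
	b := (&userCreate{}).SetNillableTotpEnabled(nil)
	fmt.Println(a.totpEnabledSet, a.totpEnabled) // set with value
	fmt.Println(b.totpEnabledSet)                // left unset
}
```

When the field is left unset, the generated `defaults()` above fills in `user.DefaultTotpEnabled`, and `check()` then enforces that the required field is present.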
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index e66e2dc8..4b56e16f 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -13,6 +13,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -36,6 +37,7 @@ type UserQuery struct {
withRedeemCodes *RedeemCodeQuery
withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery
+ withAnnouncementReads *AnnouncementReadQuery
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
@@ -166,6 +168,28 @@ func (_q *UserQuery) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return query
}
+// QueryAnnouncementReads chains the current query on the "announcement_reads" edge.
+func (_q *UserQuery) QueryAnnouncementReads() *AnnouncementReadQuery {
+ query := (&AnnouncementReadClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(announcementread.Table, announcementread.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryAllowedGroups chains the current query on the "allowed_groups" edge.
func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
@@ -472,6 +496,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withRedeemCodes: _q.withRedeemCodes.Clone(),
withSubscriptions: _q.withSubscriptions.Clone(),
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
+ withAnnouncementReads: _q.withAnnouncementReads.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
@@ -527,6 +552,17 @@ func (_q *UserQuery) WithAssignedSubscriptions(opts ...func(*UserSubscriptionQue
return _q
}
+// WithAnnouncementReads tells the query-builder to eager-load the nodes that are connected to
+// the "announcement_reads" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAnnouncementReads(opts ...func(*AnnouncementReadQuery)) *UserQuery {
+ query := (&AnnouncementReadClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAnnouncementReads = query
+ return _q
+}
+
// WithAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
@@ -660,11 +696,12 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [9]bool{
+ loadedTypes = [10]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil,
+ _q.withAnnouncementReads != nil,
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
@@ -723,6 +760,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withAnnouncementReads; query != nil {
+ if err := _q.loadAnnouncementReads(ctx, query, nodes,
+ func(n *User) { n.Edges.AnnouncementReads = []*AnnouncementRead{} },
+ func(n *User, e *AnnouncementRead) { n.Edges.AnnouncementReads = append(n.Edges.AnnouncementReads, e) }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withAllowedGroups; query != nil {
if err := _q.loadAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.AllowedGroups = []*Group{} },
@@ -887,6 +931,36 @@ func (_q *UserQuery) loadAssignedSubscriptions(ctx context.Context, query *UserS
}
return nil
}
+func (_q *UserQuery) loadAnnouncementReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*User, init func(*User), assign func(*User, *AnnouncementRead)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(announcementread.FieldUserID)
+ }
+ query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AnnouncementReadsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index cf189fea..80222c92 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -187,6 +188,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
+ _u.mutation.SetTotpSecretEncrypted(v)
+ return _u
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetTotpSecretEncrypted(*v)
+ }
+ return _u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
+ _u.mutation.ClearTotpSecretEncrypted()
+ return _u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
+ _u.mutation.SetTotpEnabled(v)
+ return _u
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
+ if v != nil {
+ _u.SetTotpEnabled(*v)
+ }
+ return _u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
+ _u.mutation.SetTotpEnabledAt(v)
+ return _u
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetTotpEnabledAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
+ _u.mutation.ClearTotpEnabledAt()
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -247,6 +302,21 @@ func (_u *UserUpdate) AddAssignedSubscriptions(v ...*UserSubscription) *UserUpda
return _u.AddAssignedSubscriptionIDs(ids...)
}
+// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
+func (_u *UserUpdate) AddAnnouncementReadIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAnnouncementReadIDs(ids...)
+ return _u
+}
+
+// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
+func (_u *UserUpdate) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAnnouncementReadIDs(ids...)
+}
+
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdate) AddAllowedGroupIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -396,6 +466,27 @@ func (_u *UserUpdate) RemoveAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
+// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
+func (_u *UserUpdate) ClearAnnouncementReads() *UserUpdate {
+ _u.mutation.ClearAnnouncementReads()
+ return _u
+}
+
+// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
+func (_u *UserUpdate) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAnnouncementReadIDs(ids...)
+ return _u
+}
+
+// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
+func (_u *UserUpdate) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAnnouncementReadIDs(ids...)
+}
+
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdate) ClearAllowedGroups() *UserUpdate {
_u.mutation.ClearAllowedGroups()
@@ -603,6 +694,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ }
+ if _u.mutation.TotpSecretEncryptedCleared() {
+ _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
+ }
+ if value, ok := _u.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpEnabledAtCleared() {
+ _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -783,6 +889,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AnnouncementReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -1147,6 +1298,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
+ _u.mutation.SetTotpSecretEncrypted(v)
+ return _u
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpSecretEncrypted(*v)
+ }
+ return _u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
+ _u.mutation.ClearTotpSecretEncrypted()
+ return _u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
+ _u.mutation.SetTotpEnabled(v)
+ return _u
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpEnabled(*v)
+ }
+ return _u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetTotpEnabledAt(v)
+ return _u
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpEnabledAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
+ _u.mutation.ClearTotpEnabledAt()
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1207,6 +1412,21 @@ func (_u *UserUpdateOne) AddAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.AddAssignedSubscriptionIDs(ids...)
}
+// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
+func (_u *UserUpdateOne) AddAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAnnouncementReadIDs(ids...)
+ return _u
+}
+
+// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
+func (_u *UserUpdateOne) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAnnouncementReadIDs(ids...)
+}
+
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdateOne) AddAllowedGroupIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -1356,6 +1576,27 @@ func (_u *UserUpdateOne) RemoveAssignedSubscriptions(v ...*UserSubscription) *Us
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
+// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
+func (_u *UserUpdateOne) ClearAnnouncementReads() *UserUpdateOne {
+ _u.mutation.ClearAnnouncementReads()
+ return _u
+}
+
+// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
+func (_u *UserUpdateOne) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAnnouncementReadIDs(ids...)
+ return _u
+}
+
+// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
+func (_u *UserUpdateOne) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAnnouncementReadIDs(ids...)
+}
+
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdateOne) ClearAllowedGroups() *UserUpdateOne {
_u.mutation.ClearAllowedGroups()
@@ -1593,6 +1834,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ }
+ if _u.mutation.TotpSecretEncryptedCleared() {
+ _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
+ }
+ if value, ok := _u.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpEnabledAtCleared() {
+ _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1773,6 +2029,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AnnouncementReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AnnouncementReadsTable,
+ Columns: []string{user.AnnouncementReadsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
diff --git a/backend/go.mod b/backend/go.mod
index fd429b07..4c3e6246 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -1,6 +1,6 @@
module github.com/Wei-Shaw/sub2api
-go 1.25.5
+go 1.25.6
require (
entgo.io/ent v0.14.5
@@ -37,6 +37,7 @@ require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
+ github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -106,6 +107,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
+ github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index aa10718c..0addb5bb 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
+github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
+github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 7d1b10e8..1cab3111 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
+ Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -479,6 +480,8 @@ type RedisConfig struct {
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: minimum number of idle connections; keeps warm connections to reduce cold-start latency
MinIdleConns int `mapstructure:"min_idle_conns"`
+ // EnableTLS: whether to enable TLS/SSL for the connection
+ EnableTLS bool `mapstructure:"enable_tls"`
}
func (r *RedisConfig) Address() string {
@@ -531,6 +534,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
+// TotpConfig holds the TOTP two-factor authentication configuration
+type TotpConfig struct {
+ // EncryptionKey is the AES-256 key used to encrypt TOTP secrets (32 bytes, hex-encoded)
+ // If empty, a random key is generated automatically (suitable for development only)
+ EncryptionKey string `mapstructure:"encryption_key"`
+ // EncryptionKeyConfigured marks whether the encryption key was configured manually (not auto-generated)
+ // Enabling TOTP in the admin console is only allowed when the key was configured manually
+ EncryptionKeyConfigured bool `mapstructure:"-"`
+}
+
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -691,6 +704,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
+ // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
+ cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
+ if cfg.Totp.EncryptionKey == "" {
+ key, err := generateJWTSecret(32) // Reuse the same random generation function
+ if err != nil {
+ return nil, fmt.Errorf("generate totp encryption key error: %w", err)
+ }
+ cfg.Totp.EncryptionKey = key
+ cfg.Totp.EncryptionKeyConfigured = false
+ log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
+ } else {
+ cfg.Totp.EncryptionKeyConfigured = true
+ }
+
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -802,6 +829,7 @@ func setDefaults() {
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
+ viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
@@ -821,6 +849,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
+ // TOTP
+ viper.SetDefault("totp.encryption_key", "")
+
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
diff --git a/backend/internal/domain/announcement.go b/backend/internal/domain/announcement.go
new file mode 100644
index 00000000..7dc9a9cc
--- /dev/null
+++ b/backend/internal/domain/announcement.go
@@ -0,0 +1,226 @@
+package domain
+
+import (
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const (
+ AnnouncementStatusDraft = "draft"
+ AnnouncementStatusActive = "active"
+ AnnouncementStatusArchived = "archived"
+)
+
+const (
+ AnnouncementConditionTypeSubscription = "subscription"
+ AnnouncementConditionTypeBalance = "balance"
+)
+
+const (
+ AnnouncementOperatorIn = "in"
+ AnnouncementOperatorGT = "gt"
+ AnnouncementOperatorGTE = "gte"
+ AnnouncementOperatorLT = "lt"
+ AnnouncementOperatorLTE = "lte"
+ AnnouncementOperatorEQ = "eq"
+)
+
+var (
+ ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
+ ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
+)
+
+type AnnouncementTargeting struct {
+ // AnyOf is an OR: the announcement is shown if any one condition group is satisfied.
+ AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
+}
+
+type AnnouncementConditionGroup struct {
+ // AllOf is an AND: a group matches only when every condition in it is satisfied.
+ AllOf []AnnouncementCondition `json:"all_of,omitempty"`
+}
+
+type AnnouncementCondition struct {
+ // Type: subscription | balance
+ Type string `json:"type"`
+
+ // Operator:
+ // - subscription: in
+ // - balance: gt/gte/lt/lte/eq
+ Operator string `json:"operator"`
+
+ // subscription condition: subscription plans (group_id) to match against
+ GroupIDs []int64 `json:"group_ids,omitempty"`
+
+ // balance condition: threshold to compare against
+ Value float64 `json:"value,omitempty"`
+}
+
+func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
+ // Empty targeting: show to all users
+ if len(t.AnyOf) == 0 {
+ return true
+ }
+
+ for _, group := range t.AnyOf {
+ if len(group.AllOf) == 0 {
+ // An empty condition group never matches (avoids an unconditional "match-all" inside the OR)
+ continue
+ }
+ allMatched := true
+ for _, cond := range group.AllOf {
+ if !cond.Matches(balance, activeSubscriptionGroupIDs) {
+ allMatched = false
+ break
+ }
+ }
+ if allMatched {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
+ switch c.Type {
+ case AnnouncementConditionTypeSubscription:
+ if c.Operator != AnnouncementOperatorIn {
+ return false
+ }
+ if len(c.GroupIDs) == 0 {
+ return false
+ }
+ if len(activeSubscriptionGroupIDs) == 0 {
+ return false
+ }
+ for _, gid := range c.GroupIDs {
+ if _, ok := activeSubscriptionGroupIDs[gid]; ok {
+ return true
+ }
+ }
+ return false
+
+ case AnnouncementConditionTypeBalance:
+ switch c.Operator {
+ case AnnouncementOperatorGT:
+ return balance > c.Value
+ case AnnouncementOperatorGTE:
+ return balance >= c.Value
+ case AnnouncementOperatorLT:
+ return balance < c.Value
+ case AnnouncementOperatorLTE:
+ return balance <= c.Value
+ case AnnouncementOperatorEQ:
+ return balance == c.Value
+ default:
+ return false
+ }
+
+ default:
+ return false
+ }
+}
+
+func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
+ normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
+
+ // Allow empty targeting (shown to all users)
+ if len(t.AnyOf) == 0 {
+ return normalized, nil
+ }
+
+ if len(t.AnyOf) > 50 {
+ return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
+ }
+
+ for _, g := range t.AnyOf {
+ if len(g.AllOf) == 0 {
+ return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
+ }
+ if len(g.AllOf) > 50 {
+ return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
+ }
+
+ group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
+ for _, c := range g.AllOf {
+ cond := AnnouncementCondition{
+ Type: strings.TrimSpace(c.Type),
+ Operator: strings.TrimSpace(c.Operator),
+ Value: c.Value,
+ }
+ for _, gid := range c.GroupIDs {
+ if gid <= 0 {
+ return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
+ }
+ cond.GroupIDs = append(cond.GroupIDs, gid)
+ }
+
+ if err := cond.validate(); err != nil {
+ return AnnouncementTargeting{}, err
+ }
+ group.AllOf = append(group.AllOf, cond)
+ }
+
+ normalized.AnyOf = append(normalized.AnyOf, group)
+ }
+
+ return normalized, nil
+}
+
+func (c AnnouncementCondition) validate() error {
+ switch c.Type {
+ case AnnouncementConditionTypeSubscription:
+ if c.Operator != AnnouncementOperatorIn {
+ return ErrAnnouncementInvalidTarget
+ }
+ if len(c.GroupIDs) == 0 {
+ return ErrAnnouncementInvalidTarget
+ }
+ return nil
+
+ case AnnouncementConditionTypeBalance:
+ switch c.Operator {
+ case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
+ return nil
+ default:
+ return ErrAnnouncementInvalidTarget
+ }
+
+ default:
+ return ErrAnnouncementInvalidTarget
+ }
+}
+
+type Announcement struct {
+ ID int64
+ Title string
+ Content string
+ Status string
+ Targeting AnnouncementTargeting
+ StartsAt *time.Time
+ EndsAt *time.Time
+ CreatedBy *int64
+ UpdatedBy *int64
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+func (a *Announcement) IsActiveAt(now time.Time) bool {
+ if a == nil {
+ return false
+ }
+ if a.Status != AnnouncementStatusActive {
+ return false
+ }
+ if a.StartsAt != nil && now.Before(*a.StartsAt) {
+ return false
+ }
+ if a.EndsAt != nil && !now.Before(*a.EndsAt) {
+ // ends_at semantics: the announcement goes offline at exactly this time
+ return false
+ }
+ return true
+}
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
new file mode 100644
index 00000000..1aee7777
--- /dev/null
+++ b/backend/internal/domain/constants.go
@@ -0,0 +1,66 @@
+package domain
+
+// Status constants
+const (
+ StatusActive = "active"
+ StatusDisabled = "disabled"
+ StatusError = "error"
+ StatusUnused = "unused"
+ StatusUsed = "used"
+ StatusExpired = "expired"
+)
+
+// Role constants
+const (
+ RoleAdmin = "admin"
+ RoleUser = "user"
+)
+
+// Platform constants
+const (
+ PlatformAnthropic = "anthropic"
+ PlatformOpenAI = "openai"
+ PlatformGemini = "gemini"
+ PlatformAntigravity = "antigravity"
+ PlatformSora = "sora"
+)
+
+// Account type constants
+const (
+ AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
+ AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
+ AccountTypeAPIKey = "apikey" // API Key account
+)
+
+// Redeem type constants
+const (
+ RedeemTypeBalance = "balance"
+ RedeemTypeConcurrency = "concurrency"
+ RedeemTypeSubscription = "subscription"
+ RedeemTypeInvitation = "invitation"
+)
+
+// PromoCode status constants
+const (
+ PromoCodeStatusActive = "active"
+ PromoCodeStatusDisabled = "disabled"
+)
+
+// Admin adjustment type constants
+const (
+ AdjustmentTypeAdminBalance = "admin_balance" // admin balance adjustment
+ AdjustmentTypeAdminConcurrency = "admin_concurrency" // admin concurrency adjustment
+)
+
+// Group subscription type constants
+const (
+ SubscriptionTypeStandard = "standard" // standard billing (deduct from balance)
+ SubscriptionTypeSubscription = "subscription" // subscription mode (quota-limited)
+)
+
+// Subscription status constants
+const (
+ SubscriptionStatusActive = "active"
+ SubscriptionStatusExpired = "expired"
+ SubscriptionStatusSuspended = "suspended"
+)
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 188aa0ec..bbf5d026 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
- // If fetching project_id failed, update the credentials first, then mark the account as error
+ // Special-case project_id: if the new value is empty but the old value is not, keep the old value
+ // This ensures project_id is not lost even when LoadCodeAssist fails
+ if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
+ if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
+ newCredentials["project_id"] = oldProjectID
+ }
+ }
+
+ // If fetching project_id failed, update the credentials but do not mark the account as error
+ // A LoadCodeAssist failure may be a transient network issue; let the next automatic refresh retry
if tokenInfo.ProjectIDMissing {
- // Update the credentials first
+ // Update the credentials first (the token itself refreshed successfully)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
- // Mark the account as error
- if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: account is missing project id and may not be able to use Antigravity"); setErr != nil {
- response.InternalError(c, "Failed to set account error: "+setErr.Error())
- return
- }
+ // Do not mark the account as error; just return a warning
response.Success(c, gin.H{
- "message": "Token refreshed but project_id is missing, account marked as error",
- "warning": "missing_project_id",
+ "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
+ "warning": "missing_project_id_temporary",
})
return
}
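The project_id preservation logic in the hunk above reduces to a small merge rule: keep the previous value when the refreshed credentials come back without one. A standalone sketch, with field names assumed from the diff (`preserveProjectID` is a hypothetical helper, not a function in the patch):

```go
package main

import (
	"fmt"
	"strings"
)

// preserveProjectID keeps the previous project_id when the refreshed
// credential map has an empty or missing one, mirroring the hunk above.
func preserveProjectID(newCreds map[string]any, oldProjectID string) {
	if v, _ := newCreds["project_id"].(string); v == "" {
		if old := strings.TrimSpace(oldProjectID); old != "" {
			newCreds["project_id"] = old
		}
	}
}

func main() {
	creds := map[string]any{"access_token": "new-token"}
	preserveProjectID(creds, "  proj-123  ")
	fmt.Println(creds["project_id"]) // proj-123
}
```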
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index b820a3fb..ea2ea963 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser
return &code, nil
}
+func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
+ return s.redeems, int64(len(s.redeems)), 100.0, nil
+}
+
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)
diff --git a/backend/internal/handler/admin/announcement_handler.go b/backend/internal/handler/admin/announcement_handler.go
new file mode 100644
index 00000000..0b5d0fbc
--- /dev/null
+++ b/backend/internal/handler/admin/announcement_handler.go
@@ -0,0 +1,246 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AnnouncementHandler handles admin announcement management
+type AnnouncementHandler struct {
+ announcementService *service.AnnouncementService
+}
+
+// NewAnnouncementHandler creates a new admin announcement handler
+func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
+ return &AnnouncementHandler{
+ announcementService: announcementService,
+ }
+}
+
+type CreateAnnouncementRequest struct {
+ Title string `json:"title" binding:"required"`
+ Content string `json:"content" binding:"required"`
+ Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
+ Targeting service.AnnouncementTargeting `json:"targeting"`
+ StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
+ EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
+}
+
+type UpdateAnnouncementRequest struct {
+ Title *string `json:"title"`
+ Content *string `json:"content"`
+ Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
+ Targeting *service.AnnouncementTargeting `json:"targeting"`
+ StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
+ EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
+}
+
+// List handles listing announcements with filters
+// GET /api/v1/admin/announcements
+func (h *AnnouncementHandler) List(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ status := strings.TrimSpace(c.Query("status"))
+ search := strings.TrimSpace(c.Query("search"))
+ if len(search) > 200 {
+ search = search[:200]
+ }
+
+ params := pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ }
+
+ items, paginationResult, err := h.announcementService.List(
+ c.Request.Context(),
+ params,
+ service.AnnouncementListFilters{Status: status, Search: search},
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]dto.Announcement, 0, len(items))
+ for i := range items {
+ out = append(out, *dto.AnnouncementFromService(&items[i]))
+ }
+ response.Paginated(c, out, paginationResult.Total, page, pageSize)
+}
+
+// GetByID handles getting an announcement by ID
+// GET /api/v1/admin/announcements/:id
+func (h *AnnouncementHandler) GetByID(c *gin.Context) {
+ announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || announcementID <= 0 {
+ response.BadRequest(c, "Invalid announcement ID")
+ return
+ }
+
+ item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.AnnouncementFromService(item))
+}
+
+// Create handles creating a new announcement
+// POST /api/v1/admin/announcements
+func (h *AnnouncementHandler) Create(c *gin.Context) {
+ var req CreateAnnouncementRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not found in context")
+ return
+ }
+
+ input := &service.CreateAnnouncementInput{
+ Title: req.Title,
+ Content: req.Content,
+ Status: req.Status,
+ Targeting: req.Targeting,
+ ActorID: &subject.UserID,
+ }
+
+ if req.StartsAt != nil && *req.StartsAt > 0 {
+ t := time.Unix(*req.StartsAt, 0)
+ input.StartsAt = &t
+ }
+ if req.EndsAt != nil && *req.EndsAt > 0 {
+ t := time.Unix(*req.EndsAt, 0)
+ input.EndsAt = &t
+ }
+
+ created, err := h.announcementService.Create(c.Request.Context(), input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.AnnouncementFromService(created))
+}
+
+// Update handles updating an announcement
+// PUT /api/v1/admin/announcements/:id
+func (h *AnnouncementHandler) Update(c *gin.Context) {
+ announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || announcementID <= 0 {
+ response.BadRequest(c, "Invalid announcement ID")
+ return
+ }
+
+ var req UpdateAnnouncementRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not found in context")
+ return
+ }
+
+ input := &service.UpdateAnnouncementInput{
+ Title: req.Title,
+ Content: req.Content,
+ Status: req.Status,
+ Targeting: req.Targeting,
+ ActorID: &subject.UserID,
+ }
+
+ if req.StartsAt != nil {
+ if *req.StartsAt == 0 {
+ var cleared *time.Time // pointer-to-nil signals "clear the field"
+ input.StartsAt = &cleared
+ } else {
+ t := time.Unix(*req.StartsAt, 0)
+ ptr := &t
+ input.StartsAt = &ptr
+ }
+ }
+
+ if req.EndsAt != nil {
+ if *req.EndsAt == 0 {
+ var cleared *time.Time // pointer-to-nil signals "clear the field"
+ input.EndsAt = &cleared
+ } else {
+ t := time.Unix(*req.EndsAt, 0)
+ ptr := &t
+ input.EndsAt = &ptr
+ }
+ }
+
+ updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.AnnouncementFromService(updated))
+}
+
+// Delete handles deleting an announcement
+// DELETE /api/v1/admin/announcements/:id
+func (h *AnnouncementHandler) Delete(c *gin.Context) {
+ announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || announcementID <= 0 {
+ response.BadRequest(c, "Invalid announcement ID")
+ return
+ }
+
+ if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Announcement deleted successfully"})
+}
+
+// ListReadStatus handles listing users' read status for an announcement
+// GET /api/v1/admin/announcements/:id/read-status
+func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
+ announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || announcementID <= 0 {
+ response.BadRequest(c, "Invalid announcement ID")
+ return
+ }
+
+ page, pageSize := response.ParsePagination(c)
+ params := pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ }
+ search := strings.TrimSpace(c.Query("search"))
+ if len(search) > 200 {
+ search = search[:200]
+ }
+
+ items, paginationResult, err := h.announcementService.ListUserReadStatus(
+ c.Request.Context(),
+ announcementID,
+ params,
+ search,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Paginated(c, items, paginationResult.Total, page, pageSize)
+}
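The `UpdateAnnouncementRequest` above distinguishes three cases per timestamp field: field absent (leave unchanged), `0` (clear), and a positive Unix value (set). The Update handler encodes this by mapping a `*int64` onto a `**time.Time`. A minimal standalone sketch of that mapping — `applyStartsAt` is an illustrative helper, not part of the handler's actual API:

```go
package main

import (
	"fmt"
	"time"
)

// applyStartsAt mirrors the handler's tri-state logic:
//   req == nil -> no change (outer pointer stays nil)
//   *req == 0  -> clear (outer pointer set, inner pointer nil)
//   *req > 0   -> set to that Unix timestamp
func applyStartsAt(req *int64) **time.Time {
	if req == nil {
		return nil
	}
	if *req == 0 {
		var cleared *time.Time
		return &cleared
	}
	t := time.Unix(*req, 0)
	ptr := &t
	return &ptr
}

func main() {
	fmt.Println(applyStartsAt(nil) == nil) // field absent: no change

	zero := int64(0)
	out := applyStartsAt(&zero)
	fmt.Println(out != nil && *out == nil) // explicit clear

	ts := int64(1700000000)
	out = applyStartsAt(&ts)
	fmt.Println((*out).Unix() == ts) // set to the given time
}
```

The double pointer is what lets the service layer tell "caller did not mention this field" apart from "caller asked to null it out", which a single `*time.Time` cannot express.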
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index 328c8fce..f7f6c893 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -47,6 +47,8 @@ type CreateGroupRequest struct {
// Model routing config (anthropic platform only)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
+ // Copy accounts from the specified groups (bound automatically after creation)
+ CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// UpdateGroupRequest represents update group request
@@ -74,6 +76,8 @@ type UpdateGroupRequest struct {
// Model routing config (anthropic platform only)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
+ // Copy accounts from the specified groups (sync operation: clears this group's current account bindings, then binds the source groups' accounts)
+ CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// List handles listing all groups with pagination
@@ -183,6 +187,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
+ CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -229,6 +234,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
+ CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index f1b68334..e229385f 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
- Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
+ Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // required for subscription type
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription type only; defaults to 30 days, max 100 years
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 0e3e0a2f..1e723ee5 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -48,6 +48,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ InvitationCodeEnabled: settings.InvitationCodeEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -70,6 +74,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -89,9 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest represents the update-settings request
type UpdateSettingsRequest struct {
// Registration settings
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ InvitationCodeEnabled bool `json:"invitation_code_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
// Email service (SMTP) settings
SMTPHost string `json:"smtp_host"`
@@ -114,14 +123,16 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM settings
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo"`
- SiteSubtitle string `json:"site_subtitle"`
- APIBaseURL string `json:"api_base_url"`
- ContactInfo string `json:"contact_info"`
- DocURL string `json:"doc_url"`
- HomeContent string `json:"home_content"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
// Default configuration
DefaultConcurrency int `json:"default_concurrency"`
@@ -198,6 +209,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ // TOTP two-factor authentication validation
+ // TOTP may only be enabled after the encryption key has been configured manually
+ if req.TotpEnabled && !previousSettings.TotpEnabled {
+ // Attempting to enable TOTP: check that the encryption key is configured
+ if !h.settingService.IsTotpEncryptionKeyConfigured() {
+ response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
+ return
+ }
+ }
+
// LinuxDo Connect parameter validation
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -227,6 +248,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ // "Purchase subscription" page configuration validation
+ purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
+ if req.PurchaseSubscriptionEnabled != nil {
+ purchaseEnabled = *req.PurchaseSubscriptionEnabled
+ }
+ purchaseURL := previousSettings.PurchaseSubscriptionURL
+ if req.PurchaseSubscriptionURL != nil {
+ purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
+ }
+
+ // - When enabled: the URL must be non-empty and a valid absolute http(s) URL
+ // - When disabled: an empty URL is allowed; a provided URL is still validated to catch misconfiguration
+ if purchaseEnabled {
+ if purchaseURL == "" {
+ response.BadRequest(c, "Purchase Subscription URL is required when enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
+ response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
+ return
+ }
+ } else if purchaseURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
+ response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
+ return
+ }
+ }
+
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -240,40 +289,45 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
- RegistrationEnabled: req.RegistrationEnabled,
- EmailVerifyEnabled: req.EmailVerifyEnabled,
- PromoCodeEnabled: req.PromoCodeEnabled,
- SMTPHost: req.SMTPHost,
- SMTPPort: req.SMTPPort,
- SMTPUsername: req.SMTPUsername,
- SMTPPassword: req.SMTPPassword,
- SMTPFrom: req.SMTPFrom,
- SMTPFromName: req.SMTPFromName,
- SMTPUseTLS: req.SMTPUseTLS,
- TurnstileEnabled: req.TurnstileEnabled,
- TurnstileSiteKey: req.TurnstileSiteKey,
- TurnstileSecretKey: req.TurnstileSecretKey,
- LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: req.LinuxDoConnectClientID,
- LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
- LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
- SiteName: req.SiteName,
- SiteLogo: req.SiteLogo,
- SiteSubtitle: req.SiteSubtitle,
- APIBaseURL: req.APIBaseURL,
- ContactInfo: req.ContactInfo,
- DocURL: req.DocURL,
- HomeContent: req.HomeContent,
- HideCcsImportButton: req.HideCcsImportButton,
- DefaultConcurrency: req.DefaultConcurrency,
- DefaultBalance: req.DefaultBalance,
- EnableModelFallback: req.EnableModelFallback,
- FallbackModelAnthropic: req.FallbackModelAnthropic,
- FallbackModelOpenAI: req.FallbackModelOpenAI,
- FallbackModelGemini: req.FallbackModelGemini,
- FallbackModelAntigravity: req.FallbackModelAntigravity,
- EnableIdentityPatch: req.EnableIdentityPatch,
- IdentityPatchPrompt: req.IdentityPatchPrompt,
+ RegistrationEnabled: req.RegistrationEnabled,
+ EmailVerifyEnabled: req.EmailVerifyEnabled,
+ PromoCodeEnabled: req.PromoCodeEnabled,
+ PasswordResetEnabled: req.PasswordResetEnabled,
+ InvitationCodeEnabled: req.InvitationCodeEnabled,
+ TotpEnabled: req.TotpEnabled,
+ SMTPHost: req.SMTPHost,
+ SMTPPort: req.SMTPPort,
+ SMTPUsername: req.SMTPUsername,
+ SMTPPassword: req.SMTPPassword,
+ SMTPFrom: req.SMTPFrom,
+ SMTPFromName: req.SMTPFromName,
+ SMTPUseTLS: req.SMTPUseTLS,
+ TurnstileEnabled: req.TurnstileEnabled,
+ TurnstileSiteKey: req.TurnstileSiteKey,
+ TurnstileSecretKey: req.TurnstileSecretKey,
+ LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: req.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
+ LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ SiteName: req.SiteName,
+ SiteLogo: req.SiteLogo,
+ SiteSubtitle: req.SiteSubtitle,
+ APIBaseURL: req.APIBaseURL,
+ ContactInfo: req.ContactInfo,
+ DocURL: req.DocURL,
+ HomeContent: req.HomeContent,
+ HideCcsImportButton: req.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: purchaseEnabled,
+ PurchaseSubscriptionURL: purchaseURL,
+ DefaultConcurrency: req.DefaultConcurrency,
+ DefaultBalance: req.DefaultBalance,
+ EnableModelFallback: req.EnableModelFallback,
+ FallbackModelAnthropic: req.FallbackModelAnthropic,
+ FallbackModelOpenAI: req.FallbackModelOpenAI,
+ FallbackModelGemini: req.FallbackModelGemini,
+ FallbackModelAntigravity: req.FallbackModelAntigravity,
+ EnableIdentityPatch: req.EnableIdentityPatch,
+ IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -318,6 +372,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
+ PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
+ InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
+ TotpEnabled: updatedSettings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -340,6 +398,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -384,6 +444,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
+ if before.PasswordResetEnabled != after.PasswordResetEnabled {
+ changed = append(changed, "password_reset_enabled")
+ }
+ if before.TotpEnabled != after.TotpEnabled {
+ changed = append(changed, "totp_enabled")
+ }
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
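The purchase-subscription fields in `UpdateSettingsRequest` are pointers (`*bool`, `*string`) so that a field omitted from the JSON body falls back to the previously stored value instead of being zeroed. A minimal sketch of that merge, with hypothetical names (`mergePurchaseSettings` is illustrative, not the handler's actual API):

```go
package main

import (
	"fmt"
	"strings"
)

// mergePurchaseSettings mirrors the handler's pattern: a nil pointer means
// "field omitted, keep the previous value"; a non-nil pointer overrides it,
// with the URL trimmed the same way the handler trims it.
func mergePurchaseSettings(prevEnabled bool, prevURL string, reqEnabled *bool, reqURL *string) (bool, string) {
	enabled := prevEnabled
	if reqEnabled != nil {
		enabled = *reqEnabled
	}
	url := prevURL
	if reqURL != nil {
		url = strings.TrimSpace(*reqURL)
	}
	return enabled, url
}

func main() {
	// Omitted fields: previous values survive.
	e, u := mergePurchaseSettings(true, "https://example.com/buy", nil, nil)
	fmt.Println(e, u)

	// Explicit override, with whitespace trimmed.
	off := false
	raw := "  https://example.com/new  "
	e, u = mergePurchaseSettings(true, "https://example.com/buy", &off, &raw)
	fmt.Println(e, u)
}
```

This is why only these two fields use pointers while older settings fields are plain values: plain `bool`/`string` fields cannot distinguish "omitted" from "explicitly false/empty", so partial updates would silently reset them.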
diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go
index a0d1456f..51995ab1 100644
--- a/backend/internal/handler/admin/subscription_handler.go
+++ b/backend/internal/handler/admin/subscription_handler.go
@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
- subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
+ // Parse sorting parameters
+ sortBy := c.DefaultQuery("sort_by", "created_at")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
+
+ subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 9a5a691f..ac76689d 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
response.Success(c, stats)
}
+
+// GetBalanceHistory handles getting user's balance/concurrency change history
+// GET /api/v1/admin/users/:id/balance-history
+// Query params:
+// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
+func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ page, pageSize := response.ParsePagination(c)
+ codeType := c.Query("type")
+
+ codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Convert to admin DTO (includes notes field for admin visibility)
+ out := make([]dto.AdminRedeemCode, 0, len(codes))
+ for i := range codes {
+ out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
+ }
+
+ // Custom response with total_recharged alongside pagination
+ pages := int((total + int64(pageSize) - 1) / int64(pageSize))
+ if pages < 1 {
+ pages = 1
+ }
+ response.Success(c, gin.H{
+ "items": out,
+ "total": total,
+ "page": page,
+ "page_size": pageSize,
+ "pages": pages,
+ "total_recharged": totalRecharged,
+ })
+}
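`GetBalanceHistory` computes `pages` with integer ceiling division, clamped to a minimum of 1 so an empty result still reports one page. The arithmetic in isolation:

```go
package main

import "fmt"

// pageCount returns ceil(total / pageSize) with a floor of 1,
// matching the handler's inline computation.
func pageCount(total int64, pageSize int) int {
	pages := int((total + int64(pageSize) - 1) / int64(pageSize))
	if pages < 1 {
		pages = 1
	}
	return pages
}

func main() {
	fmt.Println(pageCount(0, 20))   // 1: empty result still reports one page
	fmt.Println(pageCount(20, 20))  // 1: exact fit
	fmt.Println(pageCount(21, 20))  // 2: one overflow row adds a page
	fmt.Println(pageCount(100, 20)) // 5
}
```

Adding `pageSize - 1` before dividing is the standard integer-ceiling idiom; it avoids floating point and a branch on the remainder.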
diff --git a/backend/internal/handler/announcement_handler.go b/backend/internal/handler/announcement_handler.go
new file mode 100644
index 00000000..72823eaf
--- /dev/null
+++ b/backend/internal/handler/announcement_handler.go
@@ -0,0 +1,81 @@
+package handler
+
+import (
+ "strconv"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AnnouncementHandler handles user announcement operations
+type AnnouncementHandler struct {
+ announcementService *service.AnnouncementService
+}
+
+// NewAnnouncementHandler creates a new user announcement handler
+func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
+ return &AnnouncementHandler{
+ announcementService: announcementService,
+ }
+}
+
+// List handles listing announcements visible to current user
+// GET /api/v1/announcements
+func (h *AnnouncementHandler) List(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not found in context")
+ return
+ }
+
+ unreadOnly := parseBoolQuery(c.Query("unread_only"))
+
+ items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]dto.UserAnnouncement, 0, len(items))
+ for i := range items {
+ out = append(out, *dto.UserAnnouncementFromService(&items[i]))
+ }
+ response.Success(c, out)
+}
+
+// MarkRead marks an announcement as read for current user
+// POST /api/v1/announcements/:id/read
+func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not found in context")
+ return
+ }
+
+ announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || announcementID <= 0 {
+ response.BadRequest(c, "Invalid announcement ID")
+ return
+ }
+
+ if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "ok"})
+}
+
+func parseBoolQuery(v string) bool {
+ switch strings.TrimSpace(strings.ToLower(v)) {
+ case "1", "true", "yes", "y", "on":
+ return true
+ default:
+ return false
+ }
+}
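`parseBoolQuery` accepts a small allowlist of truthy spellings, case-insensitively and ignoring surrounding whitespace, and treats everything else — including the empty string — as false. Reproduced standalone:

```go
package main

import (
	"fmt"
	"strings"
)

// parseBoolQuery matches the handler's helper: an allowlist of truthy
// values, case-insensitive, with surrounding whitespace ignored.
func parseBoolQuery(v string) bool {
	switch strings.TrimSpace(strings.ToLower(v)) {
	case "1", "true", "yes", "y", "on":
		return true
	default:
		return false
	}
}

func main() {
	for _, v := range []string{"1", " TRUE ", "on", "0", "false", ""} {
		fmt.Printf("%q -> %v\n", v, parseBoolQuery(v))
	}
}
```

Defaulting unknown values to false means a missing `unread_only` query parameter naturally yields the full list, with no error path needed.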
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index 89f34aae..75ea9f08 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,6 +1,8 @@
package handler
import (
+ "log/slog"
+
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -13,21 +15,25 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
- cfg *config.Config
- authService *service.AuthService
- userService *service.UserService
- settingSvc *service.SettingService
- promoService *service.PromoService
+ cfg *config.Config
+ authService *service.AuthService
+ userService *service.UserService
+ settingSvc *service.SettingService
+ promoService *service.PromoService
+ redeemService *service.RedeemService
+ totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
-func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
+func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
- cfg: cfg,
- authService: authService,
- userService: userService,
- settingSvc: settingService,
- promoService: promoService,
+ cfg: cfg,
+ authService: authService,
+ userService: userService,
+ settingSvc: settingService,
+ promoService: promoService,
+ redeemService: redeemService,
+ totpService: totpService,
}
}
@@ -37,7 +43,8 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
- PromoCode string `json:"promo_code"` // registration promo code
+ PromoCode string `json:"promo_code"` // registration promo code
+ InvitationCode string `json:"invitation_code"` // invitation code
}
// SendVerifyCodeRequest represents a send-verification-code request
@@ -83,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
- token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
+ token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -144,6 +151,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
+ // Check if TOTP 2FA is enabled for this user
+ if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
+ // Create a temporary login session for 2FA
+ tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
+ if err != nil {
+ response.InternalError(c, "Failed to create 2FA session")
+ return
+ }
+
+ response.Success(c, TotpLoginResponse{
+ Requires2FA: true,
+ TempToken: tempToken,
+ UserEmailMasked: service.MaskEmail(user.Email),
+ })
+ return
+ }
+
+ response.Success(c, AuthResponse{
+ AccessToken: token,
+ TokenType: "Bearer",
+ User: dto.UserFromService(user),
+ })
+}
+
+// TotpLoginResponse represents the response when 2FA is required
+type TotpLoginResponse struct {
+ Requires2FA bool `json:"requires_2fa"`
+ TempToken string `json:"temp_token,omitempty"`
+ UserEmailMasked string `json:"user_email_masked,omitempty"`
+}
+
+// Login2FARequest represents the 2FA login request
+type Login2FARequest struct {
+ TempToken string `json:"temp_token" binding:"required"`
+ TotpCode string `json:"totp_code" binding:"required,len=6"`
+}
+
+// Login2FA completes the login with 2FA verification
+// POST /api/v1/auth/login/2fa
+func (h *AuthHandler) Login2FA(c *gin.Context) {
+ var req Login2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ slog.Debug("login_2fa_request",
+ "temp_token_len", len(req.TempToken),
+ "totp_code_len", len(req.TotpCode))
+
+ // Get the login session
+ session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
+ if err != nil || session == nil {
+ tokenPrefix := ""
+ if len(req.TempToken) >= 8 {
+ tokenPrefix = req.TempToken[:8]
+ }
+ slog.Debug("login_2fa_session_invalid",
+ "temp_token_prefix", tokenPrefix,
+ "error", err)
+ response.BadRequest(c, "Invalid or expired 2FA session")
+ return
+ }
+
+ slog.Debug("login_2fa_session_found",
+ "user_id", session.UserID,
+ "email", session.Email)
+
+ // Verify the TOTP code
+ if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
+ slog.Debug("login_2fa_verify_failed",
+ "user_id", session.UserID,
+ "error", err)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Delete the login session
+ _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+
+ // Get the user
+ user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Generate the JWT token
+ token, err := h.authService.GenerateToken(user)
+ if err != nil {
+ response.InternalError(c, "Failed to generate token")
+ return
+ }
+
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -247,3 +348,146 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
+
+// ValidateInvitationCodeRequest represents an invitation code validation request
+type ValidateInvitationCodeRequest struct {
+ Code string `json:"code" binding:"required"`
+}
+
+// ValidateInvitationCodeResponse represents an invitation code validation response
+type ValidateInvitationCodeResponse struct {
+ Valid bool `json:"valid"`
+ ErrorCode string `json:"error_code,omitempty"`
+}
+
+// ValidateInvitationCode validates an invitation code (public endpoint, called before registration)
+// POST /api/v1/auth/validate-invitation-code
+func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
+ // Check whether the invitation code feature is enabled
+ if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
+ response.Success(c, ValidateInvitationCodeResponse{
+ Valid: false,
+ ErrorCode: "INVITATION_CODE_DISABLED",
+ })
+ return
+ }
+
+ var req ValidateInvitationCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ // Look up the invitation code
+ redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
+ if err != nil {
+ response.Success(c, ValidateInvitationCodeResponse{
+ Valid: false,
+ ErrorCode: "INVITATION_CODE_NOT_FOUND",
+ })
+ return
+ }
+
+ // Check type and status
+ if redeemCode.Type != service.RedeemTypeInvitation {
+ response.Success(c, ValidateInvitationCodeResponse{
+ Valid: false,
+ ErrorCode: "INVITATION_CODE_INVALID",
+ })
+ return
+ }
+
+ if redeemCode.Status != service.StatusUnused {
+ response.Success(c, ValidateInvitationCodeResponse{
+ Valid: false,
+ ErrorCode: "INVITATION_CODE_USED",
+ })
+ return
+ }
+
+ response.Success(c, ValidateInvitationCodeResponse{
+ Valid: true,
+ })
+}
+
+// ForgotPasswordRequest represents a forgot-password request
+type ForgotPasswordRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token"`
+}
+
+// ForgotPasswordResponse represents a forgot-password response
+type ForgotPasswordResponse struct {
+ Message string `json:"message"`
+}
+
+// ForgotPassword requests a password reset email
+// POST /api/v1/auth/forgot-password
+func (h *AuthHandler) ForgotPassword(c *gin.Context) {
+ var req ForgotPasswordRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ // Turnstile verification
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Build frontend base URL from request
+ scheme := "https"
+ if c.Request.TLS == nil {
+ // Check X-Forwarded-Proto header (common in reverse proxy setups)
+ if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
+ scheme = proto
+ } else {
+ scheme = "http"
+ }
+ }
+ frontendBaseURL := scheme + "://" + c.Request.Host
+
+ // Request password reset (async)
+ // Note: This returns success even if email doesn't exist (to prevent enumeration)
+ if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, ForgotPasswordResponse{
+ Message: "If your email is registered, you will receive a password reset link shortly.",
+ })
+}
+
+// ResetPasswordRequest represents a reset-password request
+type ResetPasswordRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Token string `json:"token" binding:"required"`
+ NewPassword string `json:"new_password" binding:"required,min=6"`
+}
+
+// ResetPasswordResponse represents a reset-password response
+type ResetPasswordResponse struct {
+ Message string `json:"message"`
+}
+
+// ResetPassword resets the user's password using a reset token
+// POST /api/v1/auth/reset-password
+func (h *AuthHandler) ResetPassword(c *gin.Context) {
+ var req ResetPasswordRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ // Reset password
+ if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, ResetPasswordResponse{
+ Message: "Your password has been reset successfully. You can now log in with your new password.",
+ })
+}
diff --git a/backend/internal/handler/dto/announcement.go b/backend/internal/handler/dto/announcement.go
new file mode 100644
index 00000000..bc0db1b2
--- /dev/null
+++ b/backend/internal/handler/dto/announcement.go
@@ -0,0 +1,74 @@
+package dto
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type Announcement struct {
+ ID int64 `json:"id"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Status string `json:"status"`
+
+ Targeting service.AnnouncementTargeting `json:"targeting"`
+
+ StartsAt *time.Time `json:"starts_at,omitempty"`
+ EndsAt *time.Time `json:"ends_at,omitempty"`
+
+ CreatedBy *int64 `json:"created_by,omitempty"`
+ UpdatedBy *int64 `json:"updated_by,omitempty"`
+
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+type UserAnnouncement struct {
+ ID int64 `json:"id"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+
+ StartsAt *time.Time `json:"starts_at,omitempty"`
+ EndsAt *time.Time `json:"ends_at,omitempty"`
+
+ ReadAt *time.Time `json:"read_at,omitempty"`
+
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+func AnnouncementFromService(a *service.Announcement) *Announcement {
+ if a == nil {
+ return nil
+ }
+ return &Announcement{
+ ID: a.ID,
+ Title: a.Title,
+ Content: a.Content,
+ Status: a.Status,
+ Targeting: a.Targeting,
+ StartsAt: a.StartsAt,
+ EndsAt: a.EndsAt,
+ CreatedBy: a.CreatedBy,
+ UpdatedBy: a.UpdatedBy,
+ CreatedAt: a.CreatedAt,
+ UpdatedAt: a.UpdatedAt,
+ }
+}
+
+func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
+ if a == nil {
+ return nil
+ }
+ return &UserAnnouncement{
+ ID: a.Announcement.ID,
+ Title: a.Announcement.Title,
+ Content: a.Announcement.Content,
+ StartsAt: a.Announcement.StartsAt,
+ EndsAt: a.Announcement.EndsAt,
+ ReadAt: a.ReadAt,
+ CreatedAt: a.Announcement.CreatedAt,
+ UpdatedAt: a.Announcement.UpdatedAt,
+ }
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 58a4ad86..04d1385d 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -208,6 +208,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
+ if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
+ out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
+ now := time.Now()
+ for scope, remainingSec := range scopeLimits {
+ out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
+ ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
+ RemainingSec: remainingSec,
+ }
+ }
+ }
+
return out
}
@@ -325,7 +336,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
- return RedeemCode{
+ out := RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -339,6 +350,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
+
+ // For admin_balance/admin_concurrency types, include notes so users can see
+ // why they were charged or credited by the admin
+ if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
+ out.Notes = &rc.Notes
+ }
+
+ return out
}
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
@@ -362,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
+ ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 01f39478..be94bc16 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -2,9 +2,13 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ InvitationCodeEnabled bool `json:"invitation_code_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
+ TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // whether the TOTP encryption key is configured
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -23,14 +27,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo"`
- SiteSubtitle string `json:"site_subtitle"`
- APIBaseURL string `json:"api_base_url"`
- ContactInfo string `json:"contact_info"`
- DocURL string `json:"doc_url"`
- HomeContent string `json:"home_content"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -54,21 +60,26 @@ type SystemSettings struct {
}
type PublicSettings struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
- TurnstileEnabled bool `json:"turnstile_enabled"`
- TurnstileSiteKey string `json:"turnstile_site_key"`
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo"`
- SiteSubtitle string `json:"site_subtitle"`
- APIBaseURL string `json:"api_base_url"`
- ContactInfo string `json:"contact_info"`
- DocURL string `json:"doc_url"`
- HomeContent string `json:"home_content"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
- LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
- Version string `json:"version"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ InvitationCodeEnabled bool `json:"invitation_code_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
+ TurnstileEnabled bool `json:"turnstile_enabled"`
+ TurnstileSiteKey string `json:"turnstile_site_key"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
+ LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ Version string `json:"version"`
}
// StreamTimeoutSettings is the stream-timeout handling config DTO
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 505f9dd4..f2c7f5f1 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -2,6 +2,11 @@ package dto
import "time"
+type ScopeRateLimitInfo struct {
+ ResetAt time.Time `json:"reset_at"`
+ RemainingSec int64 `json:"remaining_sec"`
+}
+
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -114,6 +119,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
+ // Antigravity scope-level rate-limit state (extracted from extra)
+ ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
+
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
@@ -204,6 +212,10 @@ type RedeemCode struct {
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
+ // Notes is only populated for admin_balance/admin_concurrency types
+ // so users can see why they were charged or credited
+ Notes *string `json:"notes,omitempty"`
+
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
@@ -224,6 +236,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
+ // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
+ // nil means not provided / not applicable.
+ ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index a7b98940..673f6369 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -30,6 +30,7 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
billingCacheService *service.BillingCacheService
+ usageService *service.UsageService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -44,6 +45,7 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
+ usageService *service.UsageService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -64,6 +66,7 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
+ usageService: usageService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -537,7 +540,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
-// Usage handles getting account balance for CC Switch integration
+// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
@@ -552,7 +555,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
- // 订阅模式:返回订阅限额信息
+ // Best-effort: fetch usage statistics; a failure does not affect the base response
+ var usageData gin.H
+ if h.usageService != nil {
+ dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
+ if err == nil && dashStats != nil {
+ usageData = gin.H{
+ "today": gin.H{
+ "requests": dashStats.TodayRequests,
+ "input_tokens": dashStats.TodayInputTokens,
+ "output_tokens": dashStats.TodayOutputTokens,
+ "cache_creation_tokens": dashStats.TodayCacheCreationTokens,
+ "cache_read_tokens": dashStats.TodayCacheReadTokens,
+ "total_tokens": dashStats.TodayTokens,
+ "cost": dashStats.TodayCost,
+ "actual_cost": dashStats.TodayActualCost,
+ },
+ "total": gin.H{
+ "requests": dashStats.TotalRequests,
+ "input_tokens": dashStats.TotalInputTokens,
+ "output_tokens": dashStats.TotalOutputTokens,
+ "cache_creation_tokens": dashStats.TotalCacheCreationTokens,
+ "cache_read_tokens": dashStats.TotalCacheReadTokens,
+ "total_tokens": dashStats.TotalTokens,
+ "cost": dashStats.TotalCost,
+ "actual_cost": dashStats.TotalActualCost,
+ },
+ "average_duration_ms": dashStats.AverageDurationMs,
+ "rpm": dashStats.Rpm,
+ "tpm": dashStats.Tpm,
+ }
+ }
+ }
+
+ // Subscription mode: return subscription limit info plus usage statistics
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
@@ -561,28 +597,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
- c.JSON(http.StatusOK, gin.H{
+ resp := gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
- })
+ "subscription": gin.H{
+ "daily_usage_usd": subscription.DailyUsageUSD,
+ "weekly_usage_usd": subscription.WeeklyUsageUSD,
+ "monthly_usage_usd": subscription.MonthlyUsageUSD,
+ "daily_limit_usd": apiKey.Group.DailyLimitUSD,
+ "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
+ "monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
+ "expires_at": subscription.ExpiresAt,
+ },
+ }
+ if usageData != nil {
+ resp["usage"] = usageData
+ }
+ c.JSON(http.StatusOK, resp)
return
}
- // 余额模式:返回钱包余额
+ // Balance mode: return wallet balance plus usage statistics
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
- c.JSON(http.StatusOK, gin.H{
+ resp := gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
- })
+ "balance": latestUser.Balance,
+ }
+ if usageData != nil {
+ resp["usage"] = usageData
+ }
+ c.JSON(http.StatusOK, resp)
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
@@ -738,6 +792,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
+ // Detect whether the client is Claude Code and record it in the request context
+ SetClaudeCodeClientContext(c, body)
+
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
diff --git a/backend/internal/handler/gemini_cli_session_test.go b/backend/internal/handler/gemini_cli_session_test.go
new file mode 100644
index 00000000..0b37f5f2
--- /dev/null
+++ b/backend/internal/handler/gemini_cli_session_test.go
@@ -0,0 +1,122 @@
+//go:build unit
+
+package handler
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestExtractGeminiCLISessionHash(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ privilegedUserID string
+ wantEmpty bool
+ wantHash string
+ }{
+ {
+ name: "with privileged-user-id and tmp dir",
+ body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: false,
+ wantHash: func() string {
+ combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
+ hash := sha256.Sum256([]byte(combined))
+ return hex.EncodeToString(hash[:])
+ }(),
+ },
+ {
+ name: "without privileged-user-id but with tmp dir",
+ body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
+ privilegedUserID: "",
+ wantEmpty: false,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "without tmp dir",
+ body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: true,
+ },
+ {
+ name: "empty body",
+ body: "",
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // build the test context
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest("POST", "/test", nil)
+ if tt.privilegedUserID != "" {
+ c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
+ }
+
+ // call the function under test
+ result := extractGeminiCLISessionHash(c, []byte(tt.body))
+
+ // verify the result
+ if tt.wantEmpty {
+ require.Empty(t, result, "expected empty session hash")
+ } else {
+ require.NotEmpty(t, result, "expected non-empty session hash")
+ require.Equal(t, tt.wantHash, result, "session hash mismatch")
+ }
+ })
+ }
+}
+
+func TestGeminiCLITmpDirRegex(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantMatch bool
+ wantHash string
+ }{
+ {
+ name: "valid tmp dir path",
+ input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ wantMatch: true,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "valid tmp dir path in text",
+ input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
+ wantMatch: true,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "invalid hash length",
+ input: "/Users/ianshaw/.gemini/tmp/abc123",
+ wantMatch: false,
+ },
+ {
+ name: "no tmp dir",
+ input: "Hello world",
+ wantMatch: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
+ if tt.wantMatch {
+ require.NotNil(t, match, "expected regex to match")
+ require.Len(t, match, 2, "expected 2 capture groups")
+ require.Equal(t, tt.wantHash, match[1], "hash mismatch")
+ } else {
+ require.Nil(t, match, "expected regex not to match")
+ }
+ })
+ }
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index c7646b38..d1b19ede 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -1,11 +1,15 @@
package handler
import (
+ "bytes"
"context"
+ "crypto/sha256"
+ "encoding/hex"
"errors"
"io"
"log"
"net/http"
+ "regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
+// geminiCLITmpDirRegex extracts the tmp-directory hash from a Gemini CLI request body.
+// Matches paths of the form: /Users/xxx/.gemini/tmp/[64-character hex hash]
+var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
+
+func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
+ if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
+ return true
+ }
+ return geminiCLITmpDirRegex.Match(body)
+}
+
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
- parsedReq, _ := service.ParseGatewayRequest(body)
- sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
+ // Prefer the Gemini CLI session identifier (privileged-user-id + tmp directory hash)
+ sessionHash := extractGeminiCLISessionHash(c, body)
+ if sessionHash == "" {
+ // Fallback: use the generic session-hash logic (for other clients)
+ parsedReq, _ := service.ParseGatewayRequest(body)
+ sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
+ }
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
+
+ // Look up the account ID bound to the sticky session (used to detect account switches)
+ var sessionBoundAccountID int64
+ if sessionKey != "" {
+ sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
+ }
+ isCLI := isGeminiCLIRequest(c, body)
+ cleanedForUnknownBinding := false
+
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
+ // Detect an account switch: if the sticky session's bound account differs from the selected one, clean thoughtSignature.
+ // Note: the Gemini native API's thoughtSignature is tied to a specific upstream account; passing it across accounts causes 400s.
+ if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
+ log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
+ body = service.CleanGeminiNativeThoughtSignatures(body)
+ sessionBoundAccountID = account.ID
+ } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
+ // No cached binding but the request already carries a thoughtSignature: typically the cache was lost or its TTL expired while the CLI kept sending the old signature.
+ // To avoid a 400 on the very first forward, deterministically clean it once so the new account can rebuild the signature chain.
+ log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
+ body = service.CleanGeminiNativeThoughtSignatures(body)
+ cleanedForUnknownBinding = true
+ sessionBoundAccountID = account.ID
+ } else if sessionBoundAccountID == 0 {
+ // Record the first account selected for this request so a switch can be detected on failover within the same request.
+ sessionBoundAccountID = account.ID
+ }
+
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -319,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
- // 6) record usage async
+ // 6) record usage async (Gemini bills long context at a double rate)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: usedAccount,
- Subscription: subscription,
- UserAgent: ua,
- IPAddress: ip,
+
+ if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: usedAccount,
+ Subscription: subscription,
+ UserAgent: ua,
+ IPAddress: ip,
+ LongContextThreshold: 200000, // Gemini 200K threshold
+ LongContextMultiplier: 2.0, // double-rate billing for the excess
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
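The `RecordUsageWithLongContext` call above passes a 200K threshold and a 2.0 multiplier. The service implementation is not part of this diff, but the intended pricing rule can be sketched as follows (`effectiveBilledTokens` is a hypothetical helper illustrating the parameters, not the actual service code):

```go
package main

import "fmt"

// effectiveBilledTokens sketches the long-context rule implied by the
// parameters in the handler: tokens up to the threshold bill at 1x, tokens
// beyond it at the multiplier. This is an illustration of the 200K / 2.0
// inputs, not the real RecordUsageWithLongContext implementation.
func effectiveBilledTokens(inputTokens, threshold int64, multiplier float64) float64 {
	if inputTokens <= threshold {
		return float64(inputTokens)
	}
	excess := inputTokens - threshold
	return float64(threshold) + float64(excess)*multiplier
}

func main() {
	// 250K input tokens with a 200K threshold and a 2.0 multiplier:
	// 200_000 * 1x + 50_000 * 2x = 300_000 effective billed tokens.
	fmt.Println(effectiveBilledTokens(250_000, 200_000, 2.0)) // 300000
}
```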
@@ -433,3 +483,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
+
+// extractGeminiCLISessionHash extracts the session identifier from a Gemini CLI
+// request, combining the x-gemini-api-privileged-user-id header with the
+// tmp-directory hash found in the request body.
+//
+// Session identifier strategy:
+//  1. Extract the tmp-directory hash (64 hex characters) from the request body.
+//  2. Extract the privileged-user-id (a UUID) from the header.
+//  3. Combine the two and take a SHA-256 hash as the final session identifier.
+//
+// If no tmp-directory hash is found, it returns an empty string (no sticky session).
+func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
+ // 1. Extract the tmp-directory hash from the request body
+ match := geminiCLITmpDirRegex.FindSubmatch(body)
+ if len(match) < 2 {
+ return "" // no tmp directory found; skip sticky session
+ }
+ tmpDirHash := string(match[1])
+
+ // 2. Extract the privileged-user-id header
+ privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
+
+ // 3. Combine the two to produce the final session hash
+ if privilegedUserID != "" {
+ // Combine both identifiers: privileged-user-id + tmp-directory hash
+ combined := privilegedUserID + ":" + tmpDirHash
+ hash := sha256.Sum256([]byte(combined))
+ return hex.EncodeToString(hash[:])
+ }
+
+ // Without a privileged-user-id, use the tmp-directory hash directly
+ return tmpDirHash
+}
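The combination rule above (and the expected values in `gemini_cli_session_test.go`) reduces to a few lines. A self-contained sketch, with `sessionHash` as an illustrative standalone name for the same logic:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// sessionHash reproduces the rule from extractGeminiCLISessionHash:
// sha256("<privileged-user-id>:<tmp-dir-hash>") when the header is present,
// otherwise the 64-character tmp-directory hash alone.
func sessionHash(privilegedUserID, tmpDirHash string) string {
	if privilegedUserID == "" {
		return tmpDirHash
	}
	sum := sha256.Sum256([]byte(privilegedUserID + ":" + tmpDirHash))
	return hex.EncodeToString(sum[:])
}

func main() {
	h := sessionHash(
		"90785f52-8bbe-4b17-b111-a1ddea1636c3",
		"f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
	)
	fmt.Println(len(h)) // 64: hex-encoded SHA-256 digest
}
```

Either way the result is a stable 64-character key, so the `"gemini:"+sessionHash` cache key format is uniform across both branches.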
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index d7014a22..ec0fb99d 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
+ Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -33,11 +34,13 @@ type Handlers struct {
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
+ Announcement *AnnouncementHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler
+ Totp *TotpHandler
}
// BuildInfo contains build-time information
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index f62e6b3e..36ffde63 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
- case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
+ case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
}
}
+ // Check if invalid/missing API key errors should be ignored (user misconfiguration)
+ if settings.IgnoreInvalidApiKeyErrors {
+ if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
+ return true
+ }
+ }
+
return false
}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 8723c746..2029f116 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -32,20 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- PromoCodeEnabled: settings.PromoCodeEnabled,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- HomeContent: settings.HomeContent,
- HideCcsImportButton: settings.HideCcsImportButton,
- LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
- Version: h.version,
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ InvitationCodeEnabled: settings.InvitationCodeEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ HomeContent: settings.HomeContent,
+ HideCcsImportButton: settings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ Version: h.version,
})
}
diff --git a/backend/internal/handler/totp_handler.go b/backend/internal/handler/totp_handler.go
new file mode 100644
index 00000000..5c5eb567
--- /dev/null
+++ b/backend/internal/handler/totp_handler.go
@@ -0,0 +1,181 @@
+package handler
+
+import (
+ "github.com/gin-gonic/gin"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// TotpHandler handles TOTP-related requests
+type TotpHandler struct {
+ totpService *service.TotpService
+}
+
+// NewTotpHandler creates a new TotpHandler
+func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
+ return &TotpHandler{
+ totpService: totpService,
+ }
+}
+
+// TotpStatusResponse represents the TOTP status response
+type TotpStatusResponse struct {
+ Enabled bool `json:"enabled"`
+ EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
+ FeatureEnabled bool `json:"feature_enabled"`
+}
+
+// GetStatus returns the TOTP status for the current user
+// GET /api/v1/user/totp/status
+func (h *TotpHandler) GetStatus(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ resp := TotpStatusResponse{
+ Enabled: status.Enabled,
+ FeatureEnabled: status.FeatureEnabled,
+ }
+
+ if status.EnabledAt != nil {
+ ts := status.EnabledAt.Unix()
+ resp.EnabledAt = &ts
+ }
+
+ response.Success(c, resp)
+}
+
+// TotpSetupRequest represents the request to initiate TOTP setup
+type TotpSetupRequest struct {
+ EmailCode string `json:"email_code"`
+ Password string `json:"password"`
+}
+
+// TotpSetupResponse represents the TOTP setup response
+type TotpSetupResponse struct {
+ Secret string `json:"secret"`
+ QRCodeURL string `json:"qr_code_url"`
+ SetupToken string `json:"setup_token"`
+ Countdown int `json:"countdown"`
+}
+
+// InitiateSetup starts the TOTP setup process
+// POST /api/v1/user/totp/setup
+func (h *TotpHandler) InitiateSetup(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpSetupRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ // Allow empty body (optional params)
+ req = TotpSetupRequest{}
+ }
+
+ result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, TotpSetupResponse{
+ Secret: result.Secret,
+ QRCodeURL: result.QRCodeURL,
+ SetupToken: result.SetupToken,
+ Countdown: result.Countdown,
+ })
+}
+
+// TotpEnableRequest represents the request to enable TOTP
+type TotpEnableRequest struct {
+ TotpCode string `json:"totp_code" binding:"required,len=6"`
+ SetupToken string `json:"setup_token" binding:"required"`
+}
+
+// Enable completes the TOTP setup
+// POST /api/v1/user/totp/enable
+func (h *TotpHandler) Enable(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpEnableRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
+
+// TotpDisableRequest represents the request to disable TOTP
+type TotpDisableRequest struct {
+ EmailCode string `json:"email_code"`
+ Password string `json:"password"`
+}
+
+// Disable disables TOTP for the current user
+// POST /api/v1/user/totp/disable
+func (h *TotpHandler) Disable(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpDisableRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
+
+// GetVerificationMethod returns the verification method for TOTP operations
+// GET /api/v1/user/totp/verification-method
+func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
+ method := h.totpService.GetVerificationMethod(c.Request.Context())
+ response.Success(c, method)
+}
+
+// SendVerifyCode sends an email verification code for TOTP operations
+// POST /api/v1/user/totp/send-code
+func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index c20b7fbc..3d268c93 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
+ announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler,
Group: groupHandler,
Account: accountHandler,
+ Announcement: announcementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -66,11 +68,13 @@ func ProvideHandlers(
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
+ announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler,
+ totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -79,11 +83,13 @@ func ProvideHandlers(
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
+ Announcement: announcementHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler,
+ Totp: totpHandler,
}
}
@@ -96,9 +102,11 @@ var ProviderSet = wire.NewSet(
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
+ NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
+ NewTotpHandler,
ProvideSettingHandler,
// Admin handlers
@@ -106,6 +114,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
+ admin.NewAnnouncementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go
index ee2a6c1a..c7d657b9 100644
--- a/backend/internal/pkg/antigravity/oauth.go
+++ b/backend/internal/pkg/antigravity/oauth.go
@@ -33,7 +33,7 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent (kept consistent with Antigravity-Manager)
- UserAgent = "antigravity/1.11.9 windows/amd64"
+ UserAgent = "antigravity/1.15.8 windows/amd64"
// Session expiration time
SessionTTL = 30 * time.Minute
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index 1b21bd58..63f6ee7c 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
- // Preserve the original signature (Claude models need a valid signature)
- if block.Signature != "" {
+ // signature handling:
+ // - Claude models (allowDummyThought=false): must be a real signature returned upstream (dummy counts as missing)
+ // - Gemini models (allowDummyThought=true): prefer passing through a real signature; fall back to the dummy signature when missing
+ if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude models need a valid signature; when it is missing, downgrade to plain text and disable thinking mode at the upper layer.
@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use signature handling:
- // - Gemini models: use the dummy signature (skips thought_signature validation)
- // - Claude models: pass through the real signature returned upstream (Vertex/Google need the full signature chain)
- if allowDummyThought {
- part.ThoughtSignature = dummyThoughtSignature
- } else if block.Signature != "" && block.Signature != dummyThoughtSignature {
+ // - Claude models (allowDummyThought=false): must be a real signature returned upstream (dummy counts as missing)
+ // - Gemini models (allowDummyThought=true): prefer passing through a real signature; fall back to the dummy signature when missing
+ if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
+ } else if allowDummyThought {
+ part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
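Both branches above now share one selection rule. A minimal standalone sketch of that rule (the `pickThoughtSignature` helper name and the dummy constant's value are illustrative, not taken from the codebase):

```go
package main

import "fmt"

// dummySignature stands in for the package's dummyThoughtSignature constant;
// the actual value in the codebase may differ.
const dummySignature = "skip-thought-signature-validation"

// pickThoughtSignature mirrors the branch logic: Claude (allowDummy=false)
// only accepts a real upstream signature (the dummy counts as missing), while
// Gemini (allowDummy=true) prefers a real signature and falls back to dummy.
func pickThoughtSignature(sig string, allowDummy bool) string {
	if sig != "" && (allowDummy || sig != dummySignature) {
		return sig
	}
	if allowDummy {
		return dummySignature
	}
	return "" // Claude with no valid signature: caller downgrades to plain text
}

func main() {
	fmt.Println(pickThoughtSignature("sig_real", false)) // real signature passes through
	fmt.Println(pickThoughtSignature("", true))          // Gemini falls back to dummy
}
```

The same predicate serving both thinking blocks and tool_use parts is what keeps the two code paths from diverging again.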
diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go
index 60ee6f63..9d62a4a1 100644
--- a/backend/internal/pkg/antigravity/request_transformer_test.go
+++ b/backend/internal/pkg/antigravity/request_transformer_test.go
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
- t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
+ t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
+ if parts[0].ThoughtSignature != "sig_tool_abc" {
+ t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
+ }
+ })
+
+ t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
+ contentNoSig := `[
+ {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
+ ]`
+ toolIDToName := make(map[string]string)
+ parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
+ if err != nil {
+ t.Fatalf("buildParts() error = %v", err)
+ }
+ if len(parts) != 1 || parts[0].FunctionCall == nil {
+ t.Fatalf("expected 1 functionCall part, got %+v", parts)
+ }
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index d1a56a84..8b3441dc 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
+ BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader is the default anthropic-beta header of the Claude Code client
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
+// MessageBetaHeaderNoTools is the anthropic-beta header for /v1/messages requests without tools
+//
+// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
+// Claude Code for non-Claude-Code clients, we must include the claude-code beta
+// even if the request doesn't use tools, otherwise upstream may reject the
+// request as a non-Claude-Code API request.
+const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
+
+// MessageBetaHeaderWithTools is the anthropic-beta header for /v1/messages requests with tools
+const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
+
+// CountTokensBetaHeader is the anthropic-beta header used for count_tokens requests
+const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
+
// HaikuBetaHeader is the anthropic-beta header for Haiku models (the claude-code beta is not required)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders are the default request headers of the Claude Code client.
var DefaultHeaders = map[string]string{
- "User-Agent": "claude-cli/2.0.62 (external, cli)",
+ // Keep these in sync with recent Claude CLI traffic to reduce the chance
+ // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
+ "User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js",
- "X-Stainless-Package-Version": "0.52.0",
+ "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
- "X-Stainless-Arch": "x64",
+ "X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node",
- "X-Stainless-Runtime-Version": "v22.14.0",
+ "X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0",
- "X-Stainless-Timeout": "60",
+ "X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel is the default model used in tests
const DefaultTestModel = "claude-sonnet-4-5-20250929"
+
+// ModelIDOverrides maps short model names to the model IDs required by Claude OAuth requests
+var ModelIDOverrides = map[string]string{
+ "claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
+ "claude-opus-4-5": "claude-opus-4-5-20251101",
+ "claude-haiku-4-5": "claude-haiku-4-5-20251001",
+}
+
+// ModelIDReverseOverrides restores upstream model IDs to their short names
+var ModelIDReverseOverrides = map[string]string{
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
+ "claude-opus-4-5-20251101": "claude-opus-4-5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4-5",
+}
+
+// NormalizeModelID maps a model ID according to Claude OAuth rules
+func NormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
+
+// DenormalizeModelID converts an upstream model ID to its short name
+func DenormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDReverseOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
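The two override maps must stay mirror images of each other. A standalone sketch of the same normalize/denormalize round trip (map contents copied from the diff; deriving the reverse map from the forward one, rather than maintaining it by hand, is one way to keep them from drifting):

```go
package main

import "fmt"

// Short-name → pinned-ID mapping, as in ModelIDOverrides above.
var overrides = map[string]string{
	"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
	"claude-opus-4-5":   "claude-opus-4-5-20251101",
	"claude-haiku-4-5":  "claude-haiku-4-5-20251001",
}

// reverse is derived from overrides so the two maps cannot disagree.
var reverse = func() map[string]string {
	m := make(map[string]string, len(overrides))
	for short, full := range overrides {
		m[full] = short
	}
	return m
}()

// normalize maps a short name to its pinned ID; unmapped IDs pass through.
func normalize(id string) string {
	if full, ok := overrides[id]; ok {
		return full
	}
	return id
}

// denormalize maps a pinned upstream ID back to its short name.
func denormalize(id string) string {
	if short, ok := reverse[id]; ok {
		return short
	}
	return id
}

func main() {
	fmt.Println(normalize("claude-sonnet-4-5"))           // → claude-sonnet-4-5-20250929
	fmt.Println(denormalize(normalize("claude-opus-4-5"))) // round trip → claude-opus-4-5
	fmt.Println(normalize("unknown-model"))                // unmapped IDs are returned unchanged
}
```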
diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go
index 43fe12d4..c5b41d6e 100644
--- a/backend/internal/pkg/response/response.go
+++ b/backend/internal/pkg/response/response.go
@@ -2,6 +2,7 @@
package response
import (
+ "log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
+
+ // Log internal errors with full details for debugging
+ if statusCode >= 500 && c.Request != nil {
+ log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
+ }
+
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
new file mode 100644
index 00000000..eea74fcc
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
@@ -0,0 +1,278 @@
+//go:build integration
+
+// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
+//
+// Integration tests for verifying TLS fingerprint correctness.
+// These tests make actual network requests to external services and should be run manually.
+//
+// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
+package tlsfingerprint
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+)
+
+// skipIfExternalServiceUnavailable checks if the external service is available.
+// If not, it skips the test instead of failing.
+func skipIfExternalServiceUnavailable(t *testing.T, err error) {
+ t.Helper()
+ if err != nil {
+ // Check for common network/TLS errors that indicate external service issues
+ errStr := err.Error()
+ if strings.Contains(errStr, "certificate has expired") ||
+ strings.Contains(errStr, "certificate is not yet valid") ||
+ strings.Contains(errStr, "connection refused") ||
+ strings.Contains(errStr, "no such host") ||
+ strings.Contains(errStr, "network is unreachable") ||
+ strings.Contains(errStr, "timeout") {
+ t.Skipf("skipping test: external service unavailable: %v", err)
+ }
+ t.Fatalf("failed to get fingerprint: %v", err)
+ }
+}
+
+// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
+// This test uses tls.peet.ws to verify the fingerprint.
+// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
+// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
+func TestJA3Fingerprint(t *testing.T) {
+ // Skip if network is unavailable or if running in short mode
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ profile := &Profile{
+ Name: "Claude CLI Test",
+ EnableGREASE: false,
+ }
+ dialer := NewDialer(profile, nil)
+
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ // Use tls.peet.ws fingerprint detection API
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ skipIfExternalServiceUnavailable(t, err)
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ }
+
+ // Log all fingerprint information
+ t.Logf("JA3: %s", fpResp.TLS.JA3)
+ t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
+ t.Logf("JA4: %s", fpResp.TLS.JA4)
+ t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
+ t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
+
+ // Verify JA3 hash matches expected value
+ expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
+ if fpResp.TLS.JA3Hash == expectedJA3Hash {
+ t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
+ } else {
+ t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
+ }
+
+ // Verify JA4 fingerprint
+ // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
+ // Expected: t13d5911h1 (d=domain) or t13i5911h1 (i=IP)
+ // The suffix _a33745022dd6_1f22a2ca17c4 should match
+ expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
+ if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
+ t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
+ } else {
+ t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
+ }
+
+ // Verify JA4 prefix (t13d5911h1 or t13i5911h1)
+ // d = domain (SNI present), i = IP (no SNI)
+ // Since we connect to tls.peet.ws (domain), we expect 'd'
+ expectedJA4Prefix := "t13d5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
+ t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
+ } else {
+ // Also accept 'i' variant for IP connections
+ altPrefix := "t13i5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
+ t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
+ } else {
+ t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
+ }
+ }
+
+ // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
+ if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
+ t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
+ } else {
+ t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
+ }
+
+ // Verify extension list (should be 11 extensions including SNI)
+ // Expected: 0-11-10-35-16-22-23-13-43-45-51
+ expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
+ if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
+ t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
+ } else {
+ t.Logf("Warning: JA3 extension list may differ")
+ }
+}
+
+// TestProfileExpectation defines expected fingerprint values for a profile.
+type TestProfileExpectation struct {
+ Profile *Profile
+ ExpectedJA3 string // Expected JA3 hash (empty = don't check)
+ ExpectedJA4 string // Expected full JA4 (empty = don't check)
+ JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
+}
+
+// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
+// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
+func TestAllProfiles(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ // Define all profiles to test with their expected fingerprints
+ // These profiles are from config.yaml gateway.tls_fingerprint.profiles
+ profiles := []TestProfileExpectation{
+ {
+ // Linux x64 Node.js v22.17.1
+ // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
+ // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
+ Profile: &Profile{
+ Name: "linux_x64_node_v22171",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part
+ },
+ {
+ // MacOS arm64 Node.js v22.18.0
+ // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
+ // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
+ Profile: &Profile{
+ Name: "macos_arm64_node_v22180",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
+ },
+ }
+
+ for _, tc := range profiles {
+ tc := tc // capture range variable
+ t.Run(tc.Profile.Name, func(t *testing.T) {
+ fp := fetchFingerprint(t, tc.Profile)
+ if fp == nil {
+ return // fetchFingerprint already called t.Fatal
+ }
+
+ t.Logf("Profile: %s", tc.Profile.Name)
+ t.Logf(" JA3: %s", fp.JA3)
+ t.Logf(" JA3 Hash: %s", fp.JA3Hash)
+ t.Logf(" JA4: %s", fp.JA4)
+ t.Logf(" PeetPrint: %s", fp.PeetPrint)
+ t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
+
+ // Verify expectations
+ if tc.ExpectedJA3 != "" {
+ if fp.JA3Hash == tc.ExpectedJA3 {
+ t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
+ } else {
+ t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
+ }
+ }
+
+ if tc.ExpectedJA4 != "" {
+ if fp.JA4 == tc.ExpectedJA4 {
+ t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
+ } else {
+ t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
+ }
+ }
+
+ // Check JA4 cipher hash (stable middle part)
+ // JA4 format: prefix_cipherHash_extHash
+ if tc.JA4CipherHash != "" {
+ if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
+ t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
+ } else {
+ t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
+ }
+ }
+ })
+ }
+}
+
+// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
+func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
+ t.Helper()
+
+ dialer := NewDialer(profile, nil)
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ return nil
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ skipIfExternalServiceUnavailable(t, err)
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ return nil
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ return nil
+ }
+
+ return &fpResp.TLS
+}
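The assertions above compare the JA4 prefix and the two hashed segments separately because only the hashes are stable across SNI/no-SNI connections. A small sketch of splitting a JA4 string into those three underscore-separated parts (format as described in the test comments: prefix_cipherHash_extHash):

```go
package main

import (
	"fmt"
	"strings"
)

// splitJA4 splits a JA4 fingerprint into its three segments: the prefix
// (TLS version, SNI kind, cipher/extension counts, ALPN), the cipher-suite
// hash, and the extension hash.
func splitJA4(ja4 string) (prefix, cipherHash, extHash string, ok bool) {
	parts := strings.Split(ja4, "_")
	if len(parts) != 3 {
		return "", "", "", false
	}
	return parts[0], parts[1], parts[2], true
}

func main() {
	prefix, cipher, ext, ok := splitJA4("t13d5911h1_a33745022dd6_1f22a2ca17c4")
	fmt.Println(ok, prefix, cipher, ext)
	// The 'd' in t13d... means the ClientHello carried an SNI (domain name);
	// an IP connection would report t13i... with the same two hashes.
}
```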
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go
index 31a59fc7..345067e5 100644
--- a/backend/internal/pkg/tlsfingerprint/dialer_test.go
+++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go
@@ -1,10 +1,11 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
-// Integration tests for verifying TLS fingerprint correctness.
-// These tests make actual network requests and should be run manually.
+// Unit tests for TLS fingerprint dialer.
+// Integration tests that require external network are in dialer_integration_test.go
+// and require the 'integration' build tag.
//
-// Run with: go test -v ./internal/pkg/tlsfingerprint/...
-// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
+// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
+// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 170e5de9..f38f0cfa 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -809,12 +809,21 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
return err
}
- path := "{antigravity_quota_scopes," + string(scope) + "}"
+ scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
- "UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
- path,
+ `UPDATE accounts SET
+ extra = jsonb_set(
+ jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
+ ARRAY['antigravity_quota_scopes', $1]::text[],
+ $2::jsonb,
+ true
+ ),
+ updated_at = NOW(),
+ last_used_at = NOW()
+ WHERE id = $3 AND deleted_at IS NULL`,
+ scopeKey,
raw,
id,
)
@@ -829,6 +838,7 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 {
return service.ErrAccountNotFound
}
+
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
@@ -849,12 +859,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return err
}
- path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
- "UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
- path,
+ `UPDATE accounts SET
+ extra = jsonb_set(
+ jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
+ ARRAY['model_rate_limits', $1]::text[],
+ $2::jsonb,
+ true
+ ),
+ updated_at = NOW()
+ WHERE id = $3 AND deleted_at IS NULL`,
+ scope,
raw,
id,
)
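The nested `jsonb_set` calls above first materialize the parent object (via COALESCE) so that setting a child key cannot silently no-op when the parent key is absent. The same two-step semantics, sketched with Go maps purely as an illustration of what the SQL does (this helper is not part of the repository code):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// setNested mimics the SQL: ensure extra[parent] exists as an object
// (COALESCE(extra->parent, '{}')), then set extra[parent][key] = value.
func setNested(extra map[string]any, parent, key string, value any) {
	child, ok := extra[parent].(map[string]any)
	if !ok {
		child = map[string]any{} // materialize the missing parent object
		extra[parent] = child
	}
	child[key] = value
}

func main() {
	extra := map[string]any{} // parent key absent, as on a fresh account row
	setNested(extra, "model_rate_limits", "claude-sonnet-4-5", map[string]any{"rpm": 10})
	out, _ := json.Marshal(extra)
	fmt.Println(string(out))
}
```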
diff --git a/backend/internal/repository/aes_encryptor.go b/backend/internal/repository/aes_encryptor.go
new file mode 100644
index 00000000..924e3698
--- /dev/null
+++ b/backend/internal/repository/aes_encryptor.go
@@ -0,0 +1,95 @@
+package repository
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "io"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// AESEncryptor implements SecretEncryptor using AES-256-GCM
+type AESEncryptor struct {
+ key []byte
+}
+
+// NewAESEncryptor creates a new AES encryptor
+func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
+ key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
+ if err != nil {
+ return nil, fmt.Errorf("invalid totp encryption key: %w", err)
+ }
+
+ if len(key) != 32 {
+ return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
+ }
+
+ return &AESEncryptor{key: key}, nil
+}
+
+// Encrypt encrypts plaintext using AES-256-GCM
+// Output format: base64(nonce + ciphertext + tag)
+func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
+ block, err := aes.NewCipher(e.key)
+ if err != nil {
+ return "", fmt.Errorf("create cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create gcm: %w", err)
+ }
+
+ // Generate a random nonce
+ nonce := make([]byte, gcm.NonceSize())
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ return "", fmt.Errorf("generate nonce: %w", err)
+ }
+
+ // Encrypt the plaintext
+ // Seal appends the ciphertext and tag to the nonce
+ ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
+
+ // Encode as base64
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+// Decrypt decrypts ciphertext using AES-256-GCM
+func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
+ // Decode from base64
+ data, err := base64.StdEncoding.DecodeString(ciphertext)
+ if err != nil {
+ return "", fmt.Errorf("decode base64: %w", err)
+ }
+
+ block, err := aes.NewCipher(e.key)
+ if err != nil {
+ return "", fmt.Errorf("create cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create gcm: %w", err)
+ }
+
+ nonceSize := gcm.NonceSize()
+ if len(data) < nonceSize {
+ return "", fmt.Errorf("ciphertext too short")
+ }
+
+ // Extract nonce and ciphertext
+ nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
+
+ // Decrypt
+ plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
+ if err != nil {
+ return "", fmt.Errorf("decrypt: %w", err)
+ }
+
+ return string(plaintext), nil
+}
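The encryptor's wire format is base64(nonce ‖ ciphertext ‖ tag). A self-contained round-trip sketch using the same format with crypto/cipher's standard GCM API (the all-zero key is for demonstration only; real keys come from config as above):

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"io"
)

// encrypt returns base64(nonce + ciphertext + tag), matching AESEncryptor.
func encrypt(key []byte, plaintext string) (string, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}
	// Seal appends ciphertext+tag after the nonce prefix.
	return base64.StdEncoding.EncodeToString(gcm.Seal(nonce, nonce, []byte(plaintext), nil)), nil
}

// decrypt reverses encrypt: split off the nonce, then open the sealed data.
func decrypt(key []byte, enc string) (string, error) {
	data, err := base64.StdEncoding.DecodeString(enc)
	if err != nil {
		return "", err
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	if len(data) < gcm.NonceSize() {
		return "", fmt.Errorf("ciphertext too short")
	}
	nonce, ct := data[:gcm.NonceSize()], data[gcm.NonceSize():]
	pt, err := gcm.Open(nil, nonce, ct, nil)
	if err != nil {
		return "", err // tampered or truncated data fails authentication here
	}
	return string(pt), nil
}

func main() {
	key := make([]byte, 32) // AES-256 requires a 32-byte key
	enc, err := encrypt(key, "totp-secret")
	if err != nil {
		panic(err)
	}
	dec, err := decrypt(key, enc)
	if err != nil {
		panic(err)
	}
	fmt.Println(dec == "totp-secret")
}
```

Because GCM is authenticated, any bit flip in the stored value makes `gcm.Open` fail rather than return garbage, which is why no separate MAC is needed.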
diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go
new file mode 100644
index 00000000..2dc346b1
--- /dev/null
+++ b/backend/internal/repository/announcement_read_repo.go
@@ -0,0 +1,83 @@
+package repository
+
+import (
+ "context"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/announcementread"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type announcementReadRepository struct {
+ client *dbent.Client
+}
+
+func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
+ return &announcementReadRepository{client: client}
+}
+
+func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
+ client := clientFromContext(ctx, r.client)
+ return client.AnnouncementRead.Create().
+ SetAnnouncementID(announcementID).
+ SetUserID(userID).
+ SetReadAt(readAt).
+ OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
+ DoNothing().
+ Exec(ctx)
+}
+
+func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
+ if len(announcementIDs) == 0 {
+ return map[int64]time.Time{}, nil
+ }
+
+ rows, err := r.client.AnnouncementRead.Query().
+ Where(
+ announcementread.UserIDEQ(userID),
+ announcementread.AnnouncementIDIn(announcementIDs...),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ out := make(map[int64]time.Time, len(rows))
+ for i := range rows {
+ out[rows[i].AnnouncementID] = rows[i].ReadAt
+ }
+ return out, nil
+}
+
+func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
+ if len(userIDs) == 0 {
+ return map[int64]time.Time{}, nil
+ }
+
+ rows, err := r.client.AnnouncementRead.Query().
+ Where(
+ announcementread.AnnouncementIDEQ(announcementID),
+ announcementread.UserIDIn(userIDs...),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ out := make(map[int64]time.Time, len(rows))
+ for i := range rows {
+ out[rows[i].UserID] = rows[i].ReadAt
+ }
+ return out, nil
+}
+
+func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
+ count, err := r.client.AnnouncementRead.Query().
+ Where(announcementread.AnnouncementIDEQ(announcementID)).
+ Count(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return int64(count), nil
+}
diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go
new file mode 100644
index 00000000..52029e4e
--- /dev/null
+++ b/backend/internal/repository/announcement_repo.go
@@ -0,0 +1,194 @@
+package repository
+
+import (
+ "context"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/announcement"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type announcementRepository struct {
+ client *dbent.Client
+}
+
+func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
+ return &announcementRepository{client: client}
+}
+
+func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
+ client := clientFromContext(ctx, r.client)
+ builder := client.Announcement.Create().
+ SetTitle(a.Title).
+ SetContent(a.Content).
+ SetStatus(a.Status).
+ SetTargeting(a.Targeting)
+
+ if a.StartsAt != nil {
+ builder.SetStartsAt(*a.StartsAt)
+ }
+ if a.EndsAt != nil {
+ builder.SetEndsAt(*a.EndsAt)
+ }
+ if a.CreatedBy != nil {
+ builder.SetCreatedBy(*a.CreatedBy)
+ }
+ if a.UpdatedBy != nil {
+ builder.SetUpdatedBy(*a.UpdatedBy)
+ }
+
+ created, err := builder.Save(ctx)
+ if err != nil {
+ return err
+ }
+
+ applyAnnouncementEntityToService(a, created)
+ return nil
+}
+
+func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
+ m, err := r.client.Announcement.Query().
+ Where(announcement.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
+ }
+ return announcementEntityToService(m), nil
+}
+
+func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
+ client := clientFromContext(ctx, r.client)
+ builder := client.Announcement.UpdateOneID(a.ID).
+ SetTitle(a.Title).
+ SetContent(a.Content).
+ SetStatus(a.Status).
+ SetTargeting(a.Targeting)
+
+ if a.StartsAt != nil {
+ builder.SetStartsAt(*a.StartsAt)
+ } else {
+ builder.ClearStartsAt()
+ }
+ if a.EndsAt != nil {
+ builder.SetEndsAt(*a.EndsAt)
+ } else {
+ builder.ClearEndsAt()
+ }
+ if a.CreatedBy != nil {
+ builder.SetCreatedBy(*a.CreatedBy)
+ } else {
+ builder.ClearCreatedBy()
+ }
+ if a.UpdatedBy != nil {
+ builder.SetUpdatedBy(*a.UpdatedBy)
+ } else {
+ builder.ClearUpdatedBy()
+ }
+
+ updated, err := builder.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
+ }
+
+ a.UpdatedAt = updated.UpdatedAt
+ return nil
+}
+
+func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
+ client := clientFromContext(ctx, r.client)
+ _, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
+ return err
+}
+
+func (r *announcementRepository) List(
+ ctx context.Context,
+ params pagination.PaginationParams,
+ filters service.AnnouncementListFilters,
+) ([]service.Announcement, *pagination.PaginationResult, error) {
+ q := r.client.Announcement.Query()
+
+ if filters.Status != "" {
+ q = q.Where(announcement.StatusEQ(filters.Status))
+ }
+ if filters.Search != "" {
+ q = q.Where(
+ announcement.Or(
+ announcement.TitleContainsFold(filters.Search),
+ announcement.ContentContainsFold(filters.Search),
+ ),
+ )
+ }
+
+ total, err := q.Count(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ items, err := q.
+ Offset(params.Offset()).
+ Limit(params.Limit()).
+ Order(dbent.Desc(announcement.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ out := announcementEntitiesToService(items)
+ return out, paginationResultFromTotal(int64(total), params), nil
+}
+
+func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
+ q := r.client.Announcement.Query().
+ Where(
+ announcement.StatusEQ(service.AnnouncementStatusActive),
+ announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
+ announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
+ ).
+ Order(dbent.Desc(announcement.FieldID))
+
+ items, err := q.All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return announcementEntitiesToService(items), nil
+}
+
+func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
+ if dst == nil || src == nil {
+ return
+ }
+ dst.ID = src.ID
+ dst.CreatedAt = src.CreatedAt
+ dst.UpdatedAt = src.UpdatedAt
+}
+
+func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
+ if m == nil {
+ return nil
+ }
+ return &service.Announcement{
+ ID: m.ID,
+ Title: m.Title,
+ Content: m.Content,
+ Status: m.Status,
+ Targeting: m.Targeting,
+ StartsAt: m.StartsAt,
+ EndsAt: m.EndsAt,
+ CreatedBy: m.CreatedBy,
+ UpdatedBy: m.UpdatedBy,
+ CreatedAt: m.CreatedAt,
+ UpdatedAt: m.UpdatedAt,
+ }
+}
+
+func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
+ out := make([]service.Announcement, 0, len(models))
+ for i := range models {
+ if s := announcementEntityToService(models[i]); s != nil {
+ out = append(out, *s)
+ }
+ }
+ return out
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index a020ee2b..25fb88b8 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -391,17 +391,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
}
diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go
index e00e35dd..8f2b8eca 100644
--- a/backend/internal/repository/email_cache.go
+++ b/backend/internal/repository/email_cache.go
@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
-const verifyCodeKeyPrefix = "verify_code:"
+const (
+ verifyCodeKeyPrefix = "verify_code:"
+ passwordResetKeyPrefix = "password_reset:"
+ passwordResetSentAtKeyPrefix = "password_reset_sent:"
+)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
+// passwordResetKey generates the Redis key for password reset token.
+func passwordResetKey(email string) string {
+ return passwordResetKeyPrefix + email
+}
+
+// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
+func passwordResetSentAtKey(email string) string {
+ return passwordResetSentAtKeyPrefix + email
+}
+
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
+
+// Password reset token methods
+
+func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
+ key := passwordResetKey(email)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ return nil, err
+ }
+ var data service.PasswordResetTokenData
+ if err := json.Unmarshal([]byte(val), &data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
+ key := passwordResetKey(email)
+ val, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, val, ttl).Err()
+}
+
+func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
+ key := passwordResetKey(email)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// Password reset email cooldown methods
+
+func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
+ key := passwordResetSentAtKey(email)
+ exists, err := c.rdb.Exists(ctx, key).Result()
+ return err == nil && exists > 0
+}
+
+func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
+ key := passwordResetSentAtKey(email)
+ return c.rdb.Set(ctx, key, "1", ttl).Err()
+}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 75684fc9..14e5cb86 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -433,3 +433,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil
}
+
+// GetAccountIDsByGroupIDs returns the deduplicated account IDs for the given group IDs.
+func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
+ if len(groupIDs) == 0 {
+ return nil, nil
+ }
+
+ rows, err := r.sql.QueryContext(
+ ctx,
+ "SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
+ pq.Array(groupIDs),
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ var accountIDs []int64
+ for rows.Next() {
+ var accountID int64
+ if err := rows.Scan(&accountID); err != nil {
+ return nil, err
+ }
+ accountIDs = append(accountIDs, accountID)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return accountIDs, nil
+}
+
+// BindAccountsToGroup binds the given accounts to a group (batch insert, skipping existing bindings).
+func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
+ if len(accountIDs) == 0 {
+ return nil
+ }
+
+ // Use INSERT ... ON CONFLICT DO NOTHING to skip bindings that already exist
+ _, err := r.sql.ExecContext(
+ ctx,
+ `INSERT INTO account_groups (account_id, group_id, priority, created_at)
+ SELECT unnest($1::bigint[]), $2, 50, NOW()
+ ON CONFLICT (account_id, group_id) DO NOTHING`,
+ pq.Array(accountIDs),
+ groupID,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Enqueue a scheduler outbox event
+ if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
+ log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
+ }
+
+ return nil
+}
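The `ON CONFLICT (account_id, group_id) DO NOTHING` clause is what makes `BindAccountsToGroup` safe to retry. A minimal in-memory sketch of that idempotent-bind semantics (a plain map stands in for the `account_groups` table; not the actual repository code):

```go
package main

import "fmt"

type binding struct{ accountID, groupID int64 }

// bindAccounts mirrors the INSERT ... ON CONFLICT DO NOTHING semantics:
// re-binding an already-bound account is a no-op, so the call is idempotent.
// It returns the number of bindings actually inserted.
func bindAccounts(existing map[binding]bool, groupID int64, accountIDs []int64) int {
	inserted := 0
	for _, id := range accountIDs {
		b := binding{accountID: id, groupID: groupID}
		if existing[b] {
			continue // conflict: binding already present, do nothing
		}
		existing[b] = true
		inserted++
	}
	return inserted
}

func main() {
	bindings := map[binding]bool{}
	fmt.Println(bindAccounts(bindings, 1, []int64{10, 11})) // both new
	fmt.Println(bindAccounts(bindings, 1, []int64{11, 12})) // 11 already bound
}
```

Because the insert is a no-op on conflict, the scheduler outbox event can be enqueued unconditionally afterwards without risking duplicate bindings.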
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index b7f3606f..394d3a1a 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -2,11 +2,11 @@ package repository
import (
"context"
- "fmt"
+ "net/http"
"net/url"
- "strings"
"time"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -22,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
- client := createOpenAIReqClient(s.tokenURL, proxyURL)
+ client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -39,23 +39,24 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
+ SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
- return nil, fmt.Errorf("request failed: %w", err)
+ return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
- return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
+ return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
- client := createOpenAIReqClient(s.tokenURL, proxyURL)
+ client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -67,29 +68,25 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
+ SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
- return nil, fmt.Errorf("request failed: %w", err)
+ return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
- return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
+ return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
-func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
- forceHTTP2 := false
- if parsedURL, err := url.Parse(tokenURL); err == nil {
- forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
- }
+func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
- ProxyURL: proxyURL,
- Timeout: 120 * time.Second,
- ForceHTTP2: forceHTTP2,
+ ProxyURL: proxyURL,
+ Timeout: 120 * time.Second,
})
}
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index fb6f405e..513e929c 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
}
return &proxyProbeService{
- ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
const (
- defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
+// probeURLs lists the probe URLs in priority order.
+// Some AI-API-only proxies allow access to just a few domains, so fallbacks are needed.
+var probeURLs = []struct {
+ url string
+ parser string // "ip-api" or "httpbin"
+}{
+ {"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
+ {"http://httpbin.org/ip", "httpbin"},
+}
+
type proxyProbeService struct {
- ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
+ var lastErr error
+ for _, probe := range probeURLs {
+ exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
+ if err == nil {
+ return exitInfo, latencyMs, nil
+ }
+ lastErr = err
+ }
+
+ return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
+}
+
+func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
startTime := time.Now()
- req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ switch parser {
+ case "ip-api":
+ return s.parseIPAPI(body, latencyMs)
+ case "httpbin":
+ return s.parseHTTPBin(body, latencyMs)
+ default:
+ return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
+ }
+}
+
+func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
var ipInfo struct {
Status string `json:"status"`
Message string `json:"message"`
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode string `json:"countryCode"`
}
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
- }
-
if err := json.Unmarshal(body, &ipInfo); err != nil {
- return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
+ preview := string(body)
+ if len(preview) > 200 {
+ preview = preview[:200] + "..."
+ }
+ return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}
+
+func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
+ // httpbin.org/ip response format: {"origin": "1.2.3.4"}
+ var result struct {
+ Origin string `json:"origin"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
+ }
+ if result.Origin == "" {
+ return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
+ }
+ return &service.ProxyExitInfo{
+ IP: result.Origin,
+ }, latencyMs, nil
+}
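The new `ProbeProxy` loop is a first-success fallback chain: each candidate URL gets its own parser, and only the last error surfaces if every candidate fails. A self-contained sketch of that control flow (simplified to return just the IP; the real code returns a full `ProxyExitInfo`):

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// parser turns a probe response body into an exit IP.
type parser func(body []byte) (string, error)

func parseHTTPBin(body []byte) (string, error) {
	var r struct {
		Origin string `json:"origin"`
	}
	if err := json.Unmarshal(body, &r); err != nil {
		return "", err
	}
	if r.Origin == "" {
		return "", errors.New("no IP in response")
	}
	return r.Origin, nil
}

// firstSuccess mirrors the ProbeProxy loop: try each candidate in priority
// order, return on the first success, and keep the last error otherwise.
func firstSuccess(bodies [][]byte, p parser) (string, error) {
	var lastErr error
	for _, b := range bodies {
		ip, err := p(b)
		if err == nil {
			return ip, nil
		}
		lastErr = err
	}
	return "", fmt.Errorf("all probes failed, last error: %w", lastErr)
}

func main() {
	// First candidate is garbage, second parses: fallback kicks in.
	bodies := [][]byte{[]byte("not-json"), []byte(`{"origin":"5.6.7.8"}`)}
	ip, err := firstSuccess(bodies, parseHTTPBin)
	fmt.Println(ip, err)
}
```

One consequence of this shape, visible in the updated tests: per-URL failures (503, bad JSON) are no longer reported directly; callers only see the aggregated "all probe URLs failed" error wrapping the last cause.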
diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go
index f1cd5721..7450653b 100644
--- a/backend/internal/repository/proxy_probe_service_test.go
+++ b/backend/internal/repository/proxy_probe_service_test.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
- ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
-func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
- seen := make(chan string, 1)
+func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- seen <- r.RequestURI
- w.Header().Set("Content-Type", "application/json")
- _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
+ // Check whether this is the ip-api request
+ if strings.Contains(r.RequestURI, "ip-api.com") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
+ return
+ }
+ // All other requests fail
+ w.WriteHeader(http.StatusServiceUnavailable)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
-
- // Verify proxy received the request
- select {
- case uri := <-seen:
- require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
- default:
- require.Fail(s.T(), "expected proxy to receive request")
- }
}
-func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
+func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
+ s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // ip-api fails
+ if strings.Contains(r.RequestURI, "ip-api.com") {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ return
+ }
+ // httpbin succeeds
+ if strings.Contains(r.RequestURI, "httpbin.org") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
+ return
+ }
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }))
+
+ info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
+ require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
+ require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
+ require.Equal(s.T(), "5.6.7.8", info.IP)
+}
+
+func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
- require.ErrorContains(s.T(), err, "status: 503")
+ require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- _, _ = io.WriteString(w, "not-json")
+ if strings.Contains(r.RequestURI, "ip-api.com") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, "not-json")
+ return
+ }
+ // httpbin also returns an invalid body
+ if strings.Contains(r.RequestURI, "httpbin.org") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, "not-json")
+ return
+ }
+ w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
- require.ErrorContains(s.T(), err, "failed to parse response")
-}
-
-func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
- s.prober.ipInfoURL = "://invalid-url"
- s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }))
-
- _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
- require.Error(s.T(), err, "expected error for invalid ipInfoURL")
+ require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
require.Error(s.T(), err, "expected error when proxy server is closed")
}
+func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
+ body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
+ info, latencyMs, err := s.prober.parseIPAPI(body, 100)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), int64(100), latencyMs)
+ require.Equal(s.T(), "1.2.3.4", info.IP)
+ require.Equal(s.T(), "Beijing", info.City)
+ require.Equal(s.T(), "Beijing", info.Region)
+ require.Equal(s.T(), "China", info.Country)
+ require.Equal(s.T(), "CN", info.CountryCode)
+}
+
+func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
+ body := []byte(`{"status":"fail","message":"rate limited"}`)
+ _, _, err := s.prober.parseIPAPI(body, 100)
+ require.Error(s.T(), err)
+ require.ErrorContains(s.T(), err, "rate limited")
+}
+
+func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
+ body := []byte(`{"origin": "9.8.7.6"}`)
+ info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), int64(50), latencyMs)
+ require.Equal(s.T(), "9.8.7.6", info.IP)
+}
+
+func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
+ body := []byte(`{"origin": ""}`)
+ _, _, err := s.prober.parseHTTPBin(body, 50)
+ require.Error(s.T(), err)
+ require.ErrorContains(s.T(), err, "no IP found")
+}
+
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}
diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go
index ee8a01b5..a3a048c3 100644
--- a/backend/internal/repository/redeem_code_repo.go
+++ b/backend/internal/repository/redeem_code_repo.go
@@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
return redeemCodeEntitiesToService(codes), nil
}
+// ListByUserPaginated returns paginated balance/concurrency history for a user.
+// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
+func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ q := r.client.RedeemCode.Query().
+ Where(redeemcode.UsedByEQ(userID))
+
+ // Optional type filter
+ if codeType != "" {
+ q = q.Where(redeemcode.TypeEQ(codeType))
+ }
+
+ total, err := q.Count(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ codes, err := q.
+ WithGroup().
+ Offset(params.Offset()).
+ Limit(params.Limit()).
+ Order(dbent.Desc(redeemcode.FieldUsedAt)).
+ All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
+}
+
+// SumPositiveBalanceByUser returns the total recharged amount (sum of value > 0 where type is balance/admin_balance).
+func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
+ var result []struct {
+ Sum float64 `json:"sum"`
+ }
+ err := r.client.RedeemCode.Query().
+ Where(
+ redeemcode.UsedByEQ(userID),
+ redeemcode.ValueGT(0),
+ redeemcode.TypeIn("balance", "admin_balance"),
+ ).
+ Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
+ Scan(ctx, &result)
+ if err != nil {
+ return 0, err
+ }
+ if len(result) == 0 {
+ return 0, nil
+ }
+ return result[0].Sum, nil
+}
+
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
if m == nil {
return nil
diff --git a/backend/internal/repository/redis.go b/backend/internal/repository/redis.go
index f3606ad9..2b4ee4e6 100644
--- a/backend/internal/repository/redis.go
+++ b/backend/internal/repository/redis.go
@@ -1,6 +1,7 @@
package repository
import (
+ "crypto/tls"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
// buildRedisOptions builds the Redis connection options.
// Pool and timeout parameters come from the config file, allowing production tuning.
func buildRedisOptions(cfg *config.Config) *redis.Options {
- return &redis.Options{
+ opts := &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
PoolSize: cfg.Redis.PoolSize, // connection pool size
MinIdleConns: cfg.Redis.MinIdleConns, // minimum idle connections
}
+
+ if cfg.Redis.EnableTLS {
+ opts.TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ ServerName: cfg.Redis.Host,
+ }
+ }
+
+ return opts
}
diff --git a/backend/internal/repository/redis_test.go b/backend/internal/repository/redis_test.go
index 756a63dc..7cb31002 100644
--- a/backend/internal/repository/redis_test.go
+++ b/backend/internal/repository/redis_test.go
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
+ require.Nil(t, opts.TLSConfig)
+
+ // Test case with TLS enabled
+ cfgTLS := &config.Config{
+ Redis: config.RedisConfig{
+ Host: "localhost",
+ EnableTLS: true,
+ },
+ }
+ optsTLS := buildRedisOptions(cfgTLS)
+ require.NotNil(t, optsTLS.TLSConfig)
+ require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
}
diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go
index cf7e8bd0..904ed4d6 100644
--- a/backend/internal/repository/req_client_pool_test.go
+++ b/backend/internal/repository/req_client_pool_test.go
@@ -77,21 +77,9 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
-func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
- sharedReqClients = sync.Map{}
- client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
- require.Equal(t, "2", forceHTTPVersion(t, client))
-}
-
-func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
- sharedReqClients = sync.Map{}
- client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
- require.Equal(t, "", forceHTTPVersion(t, client))
-}
-
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
- client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
+ client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index 13b22107..4f447e4f 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return nil, false, err
}
if len(ids) == 0 {
- return []*service.Account{}, true, nil
+ // Treat an empty snapshot as a cache miss so the caller falls back to the database.
+ // This fixes the race where accounts are bound right after a new group is created.
+ return nil, false, nil
}
keys := make([]string, 0, len(ids))
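The fix above changes the meaning of the `(accounts, ok, err)` triple: an empty ID list now reports `ok=false` (miss) instead of `ok=true` with an empty slice (cached empty). A sketch of the caller-side contract this implies (hypothetical names; the real scheduler code differs):

```go
package main

import "fmt"

// getSnapshot models the cache lookup: ok=false means a miss, and — per the
// fix — an empty snapshot also counts as a miss, so a freshly created group
// whose bindings have not been cached yet is re-queried from the database.
func getSnapshot(ids []int64) (accounts []int64, ok bool) {
	if len(ids) == 0 {
		return nil, false // empty snapshot treated as a miss
	}
	return ids, true
}

func loadAccounts(cachedIDs []int64, dbFallback func() []int64) []int64 {
	if accounts, ok := getSnapshot(cachedIDs); ok {
		return accounts
	}
	return dbFallback() // cache miss: query the database
}

func main() {
	db := func() []int64 { return []int64{42} }
	fmt.Println(loadAccounts(nil, db))        // empty cache falls back to DB
	fmt.Println(loadAccounts([]int64{7}, db)) // populated cache is used as-is
}
```

The trade-off is that a group which is genuinely empty now hits the database on every lookup instead of being served from cache; the race fix was judged worth that cost.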
diff --git a/backend/internal/repository/totp_cache.go b/backend/internal/repository/totp_cache.go
new file mode 100644
index 00000000..2f4a8ab2
--- /dev/null
+++ b/backend/internal/repository/totp_cache.go
@@ -0,0 +1,149 @@
+package repository
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+const (
+ totpSetupKeyPrefix = "totp:setup:"
+ totpLoginKeyPrefix = "totp:login:"
+ totpAttemptsKeyPrefix = "totp:attempts:"
+ totpAttemptsTTL = 15 * time.Minute
+)
+
+// TotpCache implements service.TotpCache using Redis
+type TotpCache struct {
+ rdb *redis.Client
+}
+
+// NewTotpCache creates a new TOTP cache
+func NewTotpCache(rdb *redis.Client) service.TotpCache {
+ return &TotpCache{rdb: rdb}
+}
+
+// GetSetupSession retrieves a TOTP setup session
+func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ data, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ if err == redis.Nil {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("get setup session: %w", err)
+ }
+
+ var session service.TotpSetupSession
+ if err := json.Unmarshal(data, &session); err != nil {
+ return nil, fmt.Errorf("unmarshal setup session: %w", err)
+ }
+
+ return &session, nil
+}
+
+// SetSetupSession stores a TOTP setup session
+func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ data, err := json.Marshal(session)
+ if err != nil {
+ return fmt.Errorf("marshal setup session: %w", err)
+ }
+
+ if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
+ return fmt.Errorf("set setup session: %w", err)
+ }
+
+ return nil
+}
+
+// DeleteSetupSession deletes a TOTP setup session
+func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// GetLoginSession retrieves a TOTP login session
+func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
+ key := totpLoginKeyPrefix + tempToken
+ data, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ if err == redis.Nil {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("get login session: %w", err)
+ }
+
+ var session service.TotpLoginSession
+ if err := json.Unmarshal(data, &session); err != nil {
+ return nil, fmt.Errorf("unmarshal login session: %w", err)
+ }
+
+ return &session, nil
+}
+
+// SetLoginSession stores a TOTP login session
+func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
+ key := totpLoginKeyPrefix + tempToken
+ data, err := json.Marshal(session)
+ if err != nil {
+ return fmt.Errorf("marshal login session: %w", err)
+ }
+
+ if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
+ return fmt.Errorf("set login session: %w", err)
+ }
+
+ return nil
+}
+
+// DeleteLoginSession deletes a TOTP login session
+func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
+ key := totpLoginKeyPrefix + tempToken
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// IncrementVerifyAttempts increments the verify attempt counter
+func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+
+ // Use pipeline for atomic increment and set TTL
+ pipe := c.rdb.Pipeline()
+ incrCmd := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, totpAttemptsTTL)
+
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("increment verify attempts: %w", err)
+ }
+
+ count, err := incrCmd.Result()
+ if err != nil {
+ return 0, fmt.Errorf("get increment result: %w", err)
+ }
+
+ return int(count), nil
+}
+
+// GetVerifyAttempts gets the current verify attempt count
+func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+ count, err := c.rdb.Get(ctx, key).Int()
+ if err != nil {
+ if err == redis.Nil {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("get verify attempts: %w", err)
+ }
+ return count, nil
+}
+
+// ClearVerifyAttempts clears the verify attempt counter
+func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+ return c.rdb.Del(ctx, key).Err()
+}
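The attempts counter above is the storage half of a lockout scheme: the service layer is expected to call `IncrementVerifyAttempts` on each failed code and refuse further verification once a limit is crossed. A minimal in-memory sketch of that consumer logic (the limit of 5 is illustrative, not taken from the source; TTL expiry is elided):

```go
package main

import "fmt"

const maxVerifyAttempts = 5 // illustrative limit; the real one lives in the service layer

// attemptCounter mimics the Redis INCR + EXPIRE pattern: each failed code
// bumps a per-user counter, and crossing the limit locks verification.
type attemptCounter struct{ counts map[int64]int }

func (a *attemptCounter) Increment(userID int64) int {
	a.counts[userID]++
	return a.counts[userID]
}

func (a *attemptCounter) Locked(userID int64) bool {
	return a.counts[userID] >= maxVerifyAttempts
}

// Clear resets the counter, e.g. after a successful verification.
func (a *attemptCounter) Clear(userID int64) { delete(a.counts, userID) }

func main() {
	c := &attemptCounter{counts: map[int64]int{}}
	for i := 0; i < maxVerifyAttempts; i++ {
		c.Increment(1)
	}
	fmt.Println(c.Locked(1)) // limit reached
	c.Clear(1)               // successful verification resets the counter
	fmt.Println(c.Locked(1))
}
```

One behavioral detail of the Redis implementation worth noting: because the pipeline runs `EXPIRE` on every `INCR`, the 15-minute window slides forward with each failed attempt rather than being anchored to the first failure.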
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 0696c958..c53b7bad 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -115,6 +115,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_count,
image_size,
media_type,
+ reasoning_effort,
created_at
) VALUES (
$1, $2, $3, $4, $5,
@@ -122,7 +123,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
- $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
+ $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -136,6 +137,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
+ reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any
if requestID != "" {
@@ -173,6 +175,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.ImageCount,
imageSize,
mediaType,
+ reasoningEffort,
createdAt,
}
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2094,6 +2097,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageCount int
imageSize sql.NullString
mediaType sql.NullString
+ reasoningEffort sql.NullString
createdAt time.Time
)
@@ -2129,6 +2133,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageCount,
&imageSize,
&mediaType,
+ &reasoningEffort,
&createdAt,
); err != nil {
return nil, err
@@ -2191,6 +2196,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if mediaType.Valid {
log.MediaType = &mediaType.String
}
+ if reasoningEffort.Valid {
+ log.ReasoningEffort = &reasoningEffort.String
+ }
return log, nil
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 006a5464..654bd16b 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
+ "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser.Or(
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
+ dbuser.NotesContainsFold(filters.Search),
),
)
}
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+
+// UpdateTotpSecret updates the user's encrypted TOTP secret
+func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ client := clientFromContext(ctx, r.client)
+ update := client.User.UpdateOneID(userID)
+ if encryptedSecret == nil {
+ update = update.ClearTotpSecretEncrypted()
+ } else {
+ update = update.SetTotpSecretEncrypted(*encryptedSecret)
+ }
+ _, err := update.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
+
+// EnableTotp enables TOTP two-factor authentication for the user
+func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
+ client := clientFromContext(ctx, r.client)
+ _, err := client.User.UpdateOneID(userID).
+ SetTotpEnabled(true).
+ SetTotpEnabledAt(time.Now()).
+ Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
+
+// DisableTotp disables TOTP two-factor authentication for the user
+func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
+ client := clientFromContext(ctx, r.client)
+ _, err := client.User.UpdateOneID(userID).
+ SetTotpEnabled(false).
+ ClearTotpEnabledAt().
+ ClearTotpSecretEncrypted().
+ Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go
index cd3b9db6..5a649846 100644
--- a/backend/internal/repository/user_subscription_repo.go
+++ b/backend/internal/repository/user_subscription_repo.go
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
-func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
- if status != "" {
+
+ // Status filtering with real-time expiration check
+ now := time.Now()
+ switch status {
+ case service.SubscriptionStatusActive:
+ // Active: status is active AND not yet expired
+ q = q.Where(
+ usersubscription.StatusEQ(service.SubscriptionStatusActive),
+ usersubscription.ExpiresAtGT(now),
+ )
+ case service.SubscriptionStatusExpired:
+ // Expired: status is expired OR (status is active but already expired)
+ q = q.Where(
+ usersubscription.Or(
+ usersubscription.StatusEQ(service.SubscriptionStatusExpired),
+ usersubscription.And(
+ usersubscription.StatusEQ(service.SubscriptionStatusActive),
+ usersubscription.ExpiresAtLTE(now),
+ ),
+ ),
+ )
+ case "":
+ // No filter
+ default:
+ // Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status))
}
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
+ // Apply sorting
+ q = q.WithUser().WithGroup().WithAssignedByUser()
+
+ // Determine sort field
+ var field string
+ switch sortBy {
+ case "expires_at":
+ field = usersubscription.FieldExpiresAt
+ case "status":
+ field = usersubscription.FieldStatus
+ default:
+ field = usersubscription.FieldCreatedAt
+ }
+
+ // Determine sort order (default: desc)
+ if sortOrder == "asc" && sortBy != "" {
+ q = q.Order(dbent.Asc(field))
+ } else {
+ q = q.Order(dbent.Desc(field))
+ }
+
subs, err := q.
- WithUser().
- WithGroup().
- WithAssignedByUser().
- Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go
index 2099e5d8..60a5a378 100644
--- a/backend/internal/repository/user_subscription_repo_integration_test.go
+++ b/backend/internal/repository/user_subscription_repo_integration_test.go
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
- subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
+ subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 929eb22b..8d76f014 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -57,6 +57,8 @@ var ProviderSet = wire.NewSet(
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
+ NewAnnouncementRepository,
+ NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
@@ -83,6 +85,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
+ NewTotpCache,
+
+ // Encryptors
+ NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 409a7625..73809ee1 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -201,7 +201,7 @@ func TestAPIContracts(t *testing.T) {
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
- ExpiresAt: deps.now.Add(24 * time.Hour),
+ ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // use a far-future date so normalizeSubscriptionStatus does not mark it expired
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
@@ -226,7 +226,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
- "expires_at": "2025-01-03T03:04:05Z",
+ "expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
@@ -457,6 +457,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
+ "password_reset_enabled": false,
+ "totp_enabled": false,
+ "totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -490,8 +493,11 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
+ "invitation_code_enabled": false,
"home_content": "",
- "hide_ccs_import_button": false
+ "hide_ccs_import_button": false,
+ "purchase_subscription_enabled": false,
+ "purchase_subscription_url": ""
}
}`,
},
@@ -600,7 +606,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
- authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
+ authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
@@ -759,6 +765,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
+func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
@@ -868,6 +886,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return 0, errors.New("not implemented")
}
+func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
+ return errors.New("not implemented")
+}
+
+func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
+ return nil, errors.New("not implemented")
+}
+
type stubAccountRepo struct {
bulkUpdateIDs []int64
}
@@ -1133,6 +1159,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit
return append([]service.RedeemCode(nil), codes...), nil
}
+func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ return nil, nil, errors.New("not implemented")
+}
+
+func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
+ return 0, errors.New("not implemented")
+}
+
type stubUserSubscriptionRepo struct {
byUser map[int64][]service.UserSubscription
activeByUser map[int64][]service.UserSubscription
@@ -1185,7 +1219,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
-func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go
index 84398093..920ff93f 100644
--- a/backend/internal/server/middleware/api_key_auth_test.go
+++ b/backend/internal/server/middleware/api_key_auth_test.go
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
-func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 050e724d..ca9d627e 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// Account management
registerAccountRoutes(admin, h)
+ // Announcement management
+ registerAnnouncementRoutes(admin, h)
+
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
@@ -172,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
+ users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
@@ -229,6 +233,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerAnnouncementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ announcements := admin.Group("/announcements")
+ {
+ announcements.GET("", h.Admin.Announcement.List)
+ announcements.POST("", h.Admin.Announcement.Create)
+ announcements.GET("/:id", h.Admin.Announcement.GetByID)
+ announcements.PUT("/:id", h.Admin.Announcement.Update)
+ announcements.DELETE("/:id", h.Admin.Announcement.Delete)
+ announcements.GET("/:id/read-status", h.Admin.Announcement.ListReadStatus)
+ }
+}
+
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai")
{
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index aa691eba..24f6d549 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -26,11 +26,24 @@ func RegisterAuthRoutes(
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
+ auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// Promo code validation is rate limited: at most 10 requests per minute (fail-close on Redis failure)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
+ // Invitation code validation is rate limited: at most 10 requests per minute (fail-close on Redis failure)
+ auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }), h.Auth.ValidateInvitationCode)
+ // Forgot-password endpoint is rate limited: at most 5 requests per minute (fail-close on Redis failure)
+ auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }), h.Auth.ForgotPassword)
+ // Reset-password endpoint is rate limited: at most 10 requests per minute (fail-close on Redis failure)
+ auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index ad2166fe..5581e1e1 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+
+ // TOTP two-factor authentication
+ totp := user.Group("/totp")
+ {
+ totp.GET("/status", h.Totp.GetStatus)
+ totp.GET("/verification-method", h.Totp.GetVerificationMethod)
+ totp.POST("/send-code", h.Totp.SendVerifyCode)
+ totp.POST("/setup", h.Totp.InitiateSetup)
+ totp.POST("/enable", h.Totp.Enable)
+ totp.POST("/disable", h.Totp.Disable)
+ }
}
// API key management
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
}
+ // Announcements (visible to users)
+ announcements := authenticated.Group("/announcements")
+ {
+ announcements.GET("", h.Announcement.List)
+ announcements.POST("/:id/read", h.Announcement.MarkRead)
+ }
+
// Redeem code exchange
redeem := authenticated.Group("/redeem")
{
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 182e0161..7b958838 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return ""
}
+func (a *Account) GetClaudeUserID() string {
+ if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
+ return v
+ }
+ return ""
+}
+
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
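GetClaudeUserID above is a fallback chain: try each source in order and return the first trimmed, non-empty value. The pattern generalizes to a small helper; this is a sketch, and `firstNonEmpty` is not a function in the repository:

```go
package main

import (
	"fmt"
	"strings"
)

// firstNonEmpty evaluates each lookup in order and returns the first
// trimmed, non-empty result, mirroring the fallback chain in GetClaudeUserID.
func firstNonEmpty(lookups ...func() string) string {
	for _, fn := range lookups {
		if v := strings.TrimSpace(fn()); v != "" {
			return v
		}
	}
	return ""
}

func main() {
	extras := map[string]string{"anthropic_user_id": "  u-123  "}
	creds := map[string]string{"claude_user_id": "u-456"}
	id := firstNonEmpty(
		func() string { return extras["claude_user_id"] },
		func() string { return extras["anthropic_user_id"] },
		func() string { return creds["claude_user_id"] },
		func() string { return creds["anthropic_user_id"] },
	)
	fmt.Println(id) // u-123
}
```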
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index a76c4d20..acb6eb69 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -124,7 +124,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{
{
"type": "text",
- "text": "You are Claude Code, Anthropic's official CLI for Claude.",
+ "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{
"type": "ephemeral",
},
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 94b18322..b1b37e11 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -22,6 +22,10 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
+ // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
+ // codeType is optional - pass empty string to return all types.
+ // Also returns totalRecharged (sum of all positive balance top-ups).
+ GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
@@ -115,6 +119,8 @@ type CreateGroupInput struct {
// Model routing configuration (anthropic platform only)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // whether model routing is enabled
+ // Copy accounts from the specified groups (bound in the same transaction after the group is created)
+ CopyAccountsFromGroupIDs []int64
}
type UpdateGroupInput struct {
@@ -142,6 +148,8 @@ type UpdateGroupInput struct {
// Model routing configuration (anthropic platform only)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // whether model routing is enabled
+ // Copy accounts from the specified groups (synchronous: clear the group's current account bindings first, then bind the source groups' accounts)
+ CopyAccountsFromGroupIDs []int64
}
type CreateAccountInput struct {
@@ -535,6 +543,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}, nil
}
+// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
+func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+ codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ // Aggregate total recharged amount (only once, regardless of type filter)
+ totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ return codes, result.Total, totalRecharged, nil
+}
+
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
@@ -589,6 +612,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
}
+ // If source groups to copy accounts from were specified, fetch the account ID list first
+ var accountIDsToCopy []int64
+ if len(input.CopyAccountsFromGroupIDs) > 0 {
+ // Deduplicate source group IDs
+ seen := make(map[int64]struct{})
+ uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
+ for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
+ if _, exists := seen[srcGroupID]; !exists {
+ seen[srcGroupID] = struct{}{}
+ uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
+ }
+ }
+
+ // Verify each source group's platform matches the new group's
+ for _, srcGroupID := range uniqueSourceGroupIDs {
+ srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
+ if err != nil {
+ return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
+ }
+ if srcGroup.Platform != platform {
+ return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
+ }
+ }
+
+ // Fetch the accounts of all source groups (deduplicated)
+ var err error
+ accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
+ }
+ }
+
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -614,6 +669,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
+
+ // Bind any copied accounts to the new group
+ if len(accountIDsToCopy) > 0 {
+ if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
+ return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
+ }
+ group.AccountCount = int64(len(accountIDsToCopy))
+ }
+
return group, nil
}
@@ -761,6 +825,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
+
+ // If source groups were specified, sync the bindings (replace the group's current accounts)
+ if len(input.CopyAccountsFromGroupIDs) > 0 {
+ // Deduplicate source group IDs
+ seen := make(map[int64]struct{})
+ uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
+ for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
+ // Validation: a source group cannot be the group itself
+ if srcGroupID == id {
+ return nil, fmt.Errorf("cannot copy accounts from self")
+ }
+ // Deduplicate
+ if _, exists := seen[srcGroupID]; !exists {
+ seen[srcGroupID] = struct{}{}
+ uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
+ }
+ }
+
+ // Verify each source group's platform matches the current group's
+ for _, srcGroupID := range uniqueSourceGroupIDs {
+ srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
+ if err != nil {
+ return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
+ }
+ if srcGroup.Platform != group.Platform {
+ return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
+ }
+ }
+
+ // Fetch the accounts of all source groups (deduplicated)
+ accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
+ }
+
+ // First clear all account bindings for the current group
+ if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
+ return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
+ }
+
+ // Then bind the source groups' accounts
+ if len(accountIDsToCopy) > 0 {
+ if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
+ return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
+ }
+ }
+ }
+
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
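CreateGroup and UpdateGroup above both inline the same order-preserving deduplication of `CopyAccountsFromGroupIDs`. Factored out, the pattern looks like this; `dedupeIDs` is a sketch, not an existing helper in the codebase:

```go
package main

import "fmt"

// dedupeIDs removes duplicate IDs while preserving first-seen order,
// matching the seen-map loop used for the source group IDs above.
func dedupeIDs(ids []int64) []int64 {
	seen := make(map[int64]struct{}, len(ids))
	out := make([]int64, 0, len(ids))
	for _, id := range ids {
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}
		out = append(out, id)
	}
	return out
}

func main() {
	fmt.Println(dedupeIDs([]int64{3, 1, 3, 2, 1})) // [3 1 2]
}
```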
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index afa433af..e2aa83d9 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ panic("unexpected UpdateTotpSecret call")
+}
+
+func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
+ panic("unexpected EnableTotp call")
+}
+
+func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
+ panic("unexpected DisableTotp call")
+}
+
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error
@@ -152,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic("unexpected DeleteAccountGroupsByGroupID call")
}
+func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
+ panic("unexpected BindAccountsToGroup call")
+}
+
+func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
+ panic("unexpected GetAccountIDsByGroupIDs call")
+}
+
type proxyRepoStub struct {
deleteErr error
countErr error
@@ -262,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int
panic("unexpected ListByUser call")
}
+func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
type subscriptionInvalidateCall struct {
userID int64
groupID int64
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index e0574e2e..1daee89f 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic("unexpected DeleteAccountGroupsByGroupID call")
}
+func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
+ panic("unexpected BindAccountsToGroup call")
+}
+
+func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
+ panic("unexpected GetAccountIDsByGroupIDs call")
+}
+
// TestAdminService_CreateGroup_WithImagePricing verifies the ImagePrice field is passed through correctly when creating a group
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
@@ -378,3 +386,11 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
+
+func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
+ panic("unexpected BindAccountsToGroup call")
+}
+
+func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
+ panic("unexpected GetAccountIDsByGroupIDs call")
+}
diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go
index 7506c6db..d661b710 100644
--- a/backend/internal/service/admin_service_search_test.go
+++ b/backend/internal/service/admin_service_search_test.go
@@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p
return s.listWithFiltersCodes, result, nil
}
+func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search parameter is passed through to the repository layer", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go
new file mode 100644
index 00000000..2ba5af5d
--- /dev/null
+++ b/backend/internal/service/announcement.go
@@ -0,0 +1,64 @@
+package service
+
+import (
+ "context"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+)
+
+const (
+ AnnouncementStatusDraft = domain.AnnouncementStatusDraft
+ AnnouncementStatusActive = domain.AnnouncementStatusActive
+ AnnouncementStatusArchived = domain.AnnouncementStatusArchived
+)
+
+const (
+ AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
+ AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
+)
+
+const (
+ AnnouncementOperatorIn = domain.AnnouncementOperatorIn
+ AnnouncementOperatorGT = domain.AnnouncementOperatorGT
+ AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
+ AnnouncementOperatorLT = domain.AnnouncementOperatorLT
+ AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
+ AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
+)
+
+var (
+ ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
+ ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+)
+
+type AnnouncementTargeting = domain.AnnouncementTargeting
+
+type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
+
+type AnnouncementCondition = domain.AnnouncementCondition
+
+type Announcement = domain.Announcement
+
+type AnnouncementListFilters struct {
+ Status string
+ Search string
+}
+
+type AnnouncementRepository interface {
+ Create(ctx context.Context, a *Announcement) error
+ GetByID(ctx context.Context, id int64) (*Announcement, error)
+ Update(ctx context.Context, a *Announcement) error
+ Delete(ctx context.Context, id int64) error
+
+ List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
+ ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
+}
+
+type AnnouncementReadRepository interface {
+ MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
+ GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
+ GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
+ CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
+}
diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go
new file mode 100644
index 00000000..c2588e6c
--- /dev/null
+++ b/backend/internal/service/announcement_service.go
@@ -0,0 +1,378 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+)
+
+type AnnouncementService struct {
+ announcementRepo AnnouncementRepository
+ readRepo AnnouncementReadRepository
+ userRepo UserRepository
+ userSubRepo UserSubscriptionRepository
+}
+
+func NewAnnouncementService(
+ announcementRepo AnnouncementRepository,
+ readRepo AnnouncementReadRepository,
+ userRepo UserRepository,
+ userSubRepo UserSubscriptionRepository,
+) *AnnouncementService {
+ return &AnnouncementService{
+ announcementRepo: announcementRepo,
+ readRepo: readRepo,
+ userRepo: userRepo,
+ userSubRepo: userSubRepo,
+ }
+}
+
+type CreateAnnouncementInput struct {
+ Title string
+ Content string
+ Status string
+ Targeting AnnouncementTargeting
+ StartsAt *time.Time
+ EndsAt *time.Time
+ ActorID *int64 // admin user ID
+}
+
+type UpdateAnnouncementInput struct {
+ Title *string
+ Content *string
+ Status *string
+ Targeting *AnnouncementTargeting
+ StartsAt **time.Time
+ EndsAt **time.Time
+ ActorID *int64 // admin user ID
+}
+
+type UserAnnouncement struct {
+ Announcement Announcement
+ ReadAt *time.Time
+}
+
+type AnnouncementUserReadStatus struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ Balance float64 `json:"balance"`
+ Eligible bool `json:"eligible"`
+ ReadAt *time.Time `json:"read_at,omitempty"`
+}
+
+func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
+ if input == nil {
+ return nil, fmt.Errorf("create announcement: nil input")
+ }
+
+ title := strings.TrimSpace(input.Title)
+ content := strings.TrimSpace(input.Content)
+ if title == "" || len(title) > 200 {
+ return nil, fmt.Errorf("create announcement: invalid title")
+ }
+ if content == "" {
+ return nil, fmt.Errorf("create announcement: content is required")
+ }
+
+ status := strings.TrimSpace(input.Status)
+ if status == "" {
+ status = AnnouncementStatusDraft
+ }
+ if !isValidAnnouncementStatus(status) {
+ return nil, fmt.Errorf("create announcement: invalid status")
+ }
+
+ targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
+ if err != nil {
+ return nil, err
+ }
+
+ if input.StartsAt != nil && input.EndsAt != nil {
+ if !input.StartsAt.Before(*input.EndsAt) {
+ return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
+ }
+ }
+
+ a := &Announcement{
+ Title: title,
+ Content: content,
+ Status: status,
+ Targeting: targeting,
+ StartsAt: input.StartsAt,
+ EndsAt: input.EndsAt,
+ }
+ if input.ActorID != nil && *input.ActorID > 0 {
+ a.CreatedBy = input.ActorID
+ a.UpdatedBy = input.ActorID
+ }
+
+ if err := s.announcementRepo.Create(ctx, a); err != nil {
+ return nil, fmt.Errorf("create announcement: %w", err)
+ }
+ return a, nil
+}
+
+func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
+ if input == nil {
+ return nil, fmt.Errorf("update announcement: nil input")
+ }
+
+ a, err := s.announcementRepo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ if input.Title != nil {
+ title := strings.TrimSpace(*input.Title)
+ if title == "" || len(title) > 200 {
+ return nil, fmt.Errorf("update announcement: invalid title")
+ }
+ a.Title = title
+ }
+ if input.Content != nil {
+ content := strings.TrimSpace(*input.Content)
+ if content == "" {
+ return nil, fmt.Errorf("update announcement: content is required")
+ }
+ a.Content = content
+ }
+ if input.Status != nil {
+ status := strings.TrimSpace(*input.Status)
+ if !isValidAnnouncementStatus(status) {
+ return nil, fmt.Errorf("update announcement: invalid status")
+ }
+ a.Status = status
+ }
+
+ if input.Targeting != nil {
+ targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
+ if err != nil {
+ return nil, err
+ }
+ a.Targeting = targeting
+ }
+
+ if input.StartsAt != nil {
+ a.StartsAt = *input.StartsAt
+ }
+ if input.EndsAt != nil {
+ a.EndsAt = *input.EndsAt
+ }
+
+ if a.StartsAt != nil && a.EndsAt != nil {
+ if !a.StartsAt.Before(*a.EndsAt) {
+ return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
+ }
+ }
+
+ if input.ActorID != nil && *input.ActorID > 0 {
+ a.UpdatedBy = input.ActorID
+ }
+
+ if err := s.announcementRepo.Update(ctx, a); err != nil {
+ return nil, fmt.Errorf("update announcement: %w", err)
+ }
+ return a, nil
+}
+
+func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
+ if err := s.announcementRepo.Delete(ctx, id); err != nil {
+ return fmt.Errorf("delete announcement: %w", err)
+ }
+ return nil
+}
+
+func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
+ return s.announcementRepo.GetByID(ctx, id)
+}
+
+func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
+ return s.announcementRepo.List(ctx, params, filters)
+}
+
+func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, fmt.Errorf("get user: %w", err)
+ }
+
+ activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
+ if err != nil {
+ return nil, fmt.Errorf("list active subscriptions: %w", err)
+ }
+ activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
+ for i := range activeSubs {
+ activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
+ }
+
+ now := time.Now()
+ anns, err := s.announcementRepo.ListActive(ctx, now)
+ if err != nil {
+ return nil, fmt.Errorf("list active announcements: %w", err)
+ }
+
+ visible := make([]Announcement, 0, len(anns))
+ ids := make([]int64, 0, len(anns))
+ for i := range anns {
+ a := anns[i]
+ if !a.IsActiveAt(now) {
+ continue
+ }
+ if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
+ continue
+ }
+ visible = append(visible, a)
+ ids = append(ids, a.ID)
+ }
+
+ if len(visible) == 0 {
+ return []UserAnnouncement{}, nil
+ }
+
+ readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
+ if err != nil {
+ return nil, fmt.Errorf("get read map: %w", err)
+ }
+
+ out := make([]UserAnnouncement, 0, len(visible))
+ for i := range visible {
+ a := visible[i]
+ readAt, ok := readMap[a.ID]
+ if unreadOnly && ok {
+ continue
+ }
+ var ptr *time.Time
+ if ok {
+ t := readAt
+ ptr = &t
+ }
+ out = append(out, UserAnnouncement{
+ Announcement: a,
+ ReadAt: ptr,
+ })
+ }
+
+ // unread first; within the same read state, newest first (higher ID as a creation-time proxy)
+ sort.Slice(out, func(i, j int) bool {
+ ai, aj := out[i], out[j]
+ if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
+ return ai.ReadAt == nil
+ }
+ return ai.Announcement.ID > aj.Announcement.ID
+ })
+
+ return out, nil
+}
+
+func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
+ // Security: only announcements currently visible to the user may be marked read
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user: %w", err)
+ }
+
+ a, err := s.announcementRepo.GetByID(ctx, announcementID)
+ if err != nil {
+ return err
+ }
+
+ now := time.Now()
+ if !a.IsActiveAt(now) {
+ return ErrAnnouncementNotFound
+ }
+
+ activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("list active subscriptions: %w", err)
+ }
+ activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
+ for i := range activeSubs {
+ activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
+ }
+
+ if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
+ return ErrAnnouncementNotFound
+ }
+
+ if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
+ return fmt.Errorf("mark read: %w", err)
+ }
+ return nil
+}
+
+func (s *AnnouncementService) ListUserReadStatus(
+ ctx context.Context,
+ announcementID int64,
+ params pagination.PaginationParams,
+ search string,
+) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
+ ann, err := s.announcementRepo.GetByID(ctx, announcementID)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filters := UserListFilters{
+ Search: strings.TrimSpace(search),
+ }
+
+ users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
+ if err != nil {
+ return nil, nil, fmt.Errorf("list users: %w", err)
+ }
+
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+
+ readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
+ if err != nil {
+ return nil, nil, fmt.Errorf("get read map: %w", err)
+ }
+
+ out := make([]AnnouncementUserReadStatus, 0, len(users))
+ for i := range users {
+ u := users[i]
+ subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
+ if err != nil {
+ return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
+ }
+ activeGroupIDs := make(map[int64]struct{}, len(subs))
+ for j := range subs {
+ activeGroupIDs[subs[j].GroupID] = struct{}{}
+ }
+
+ readAt, ok := readMap[u.ID]
+ var ptr *time.Time
+ if ok {
+ t := readAt
+ ptr = &t
+ }
+
+ out = append(out, AnnouncementUserReadStatus{
+ UserID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Balance: u.Balance,
+ Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
+ ReadAt: ptr,
+ })
+ }
+
+ return out, page, nil
+}
+
+func isValidAnnouncementStatus(status string) bool {
+ switch status {
+ case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
+ return true
+ default:
+ return false
+ }
+}
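The `sort.Slice` comparator in `ListForUser` above orders announcements by two keys: unread entries first, then higher ID first as a proxy for newest creation time. A self-contained sketch of the same two-key comparator, using a simplified stand-in struct rather than the real `UserAnnouncement` type:

```go
package main

import "sort"

// item is a stand-in for UserAnnouncement: just an ID and a read flag.
type item struct {
	id   int64
	read bool
}

// sortAnnouncements mirrors the ordering in ListForUser: unread first,
// then newest first (higher ID = created later).
func sortAnnouncements(items []item) {
	sort.Slice(items, func(i, j int) bool {
		if items[i].read != items[j].read {
			return !items[i].read // unread sorts before read
		}
		return items[i].id > items[j].id
	})
}

func main() {
	items := []item{{1, false}, {3, true}, {2, false}, {4, true}}
	sortAnnouncements(items)
	// items is now [{2 false} {1 false} {4 true} {3 true}]
}
```

The first comparison handles the primary key and only falls through to the secondary key when the read flags are equal, which is the standard shape for multi-key comparators.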
diff --git a/backend/internal/service/announcement_targeting_test.go b/backend/internal/service/announcement_targeting_test.go
new file mode 100644
index 00000000..4d904c7d
--- /dev/null
+++ b/backend/internal/service/announcement_targeting_test.go
@@ -0,0 +1,66 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
+ var targeting AnnouncementTargeting
+ require.True(t, targeting.Matches(0, nil))
+ require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
+}
+
+func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
+ targeting := AnnouncementTargeting{
+ AnyOf: []AnnouncementConditionGroup{
+ {AllOf: nil},
+ },
+ }
+ _, err := targeting.NormalizeAndValidate()
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
+}
+
+func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
+ targeting := AnnouncementTargeting{
+ AnyOf: []AnnouncementConditionGroup{
+ {
+ AllOf: []AnnouncementCondition{
+ {Type: "balance", Operator: "between", Value: 10},
+ },
+ },
+ },
+ }
+ _, err := targeting.NormalizeAndValidate()
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
+}
+
+func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
+ targeting := AnnouncementTargeting{
+ AnyOf: []AnnouncementConditionGroup{
+ {
+ AllOf: []AnnouncementCondition{
+ {Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
+ {Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
+ },
+ },
+ {
+ AllOf: []AnnouncementCondition{
+ {Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
+ },
+ },
+ },
+ }
+
+ // matches group 2 (balance < 5)
+ require.True(t, targeting.Matches(4.99, nil))
+ require.False(t, targeting.Matches(5, nil))
+
+ // matches group 1 (balance >= 100 AND subscription in [10])
+ require.False(t, targeting.Matches(100, map[int64]struct{}{}))
+ require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
+ require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
+}
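The assertions above pin down the OR-of-ANDs semantics: targeting matches when any condition group matches, a group matches only when all of its conditions hold, and empty targeting matches everyone. A minimal sketch of that evaluation, using hypothetical closure-based conditions in place of the real `domain.AnnouncementTargeting` types:

```go
package main

// condition is a stand-in for AnnouncementCondition: a predicate over the
// user's balance and active subscription group IDs.
type condition func(balance float64, groups map[int64]struct{}) bool

func balanceGTE(v float64) condition {
	return func(b float64, _ map[int64]struct{}) bool { return b >= v }
}

func balanceLT(v float64) condition {
	return func(b float64, _ map[int64]struct{}) bool { return b < v }
}

func subscribedTo(ids ...int64) condition {
	return func(_ float64, groups map[int64]struct{}) bool {
		for _, id := range ids {
			if _, ok := groups[id]; ok {
				return true
			}
		}
		return false
	}
}

// matches implements AnyOf/AllOf: OR across groups, AND within a group,
// and empty targeting matches all users.
func matches(anyOf [][]condition, balance float64, groups map[int64]struct{}) bool {
	if len(anyOf) == 0 {
		return true
	}
	for _, allOf := range anyOf {
		ok := true
		for _, c := range allOf {
			if !c(balance, groups) {
				ok = false
				break
			}
		}
		if ok {
			return true
		}
	}
	return false
}

func main() {
	rules := [][]condition{
		{balanceGTE(100), subscribedTo(10)}, // group 1: balance >= 100 AND sub in [10]
		{balanceLT(5)},                      // group 2: balance < 5
	}
	_ = matches(rules, 4.99, nil)
}
```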
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 3b847bcb..9b8156e6 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -273,13 +273,11 @@ func logPrefix(sessionID, accountName string) string {
}
// Models natively supported by Antigravity (exact matches are passed through)
+// Note: the gemini-2.5 series has been removed; it now maps uniformly to the gemini-3 series
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
- "gemini-2.5-flash": true,
- "gemini-2.5-flash-lite": true,
- "gemini-2.5-flash-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
@@ -288,23 +286,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity prefix mapping table (ordered by descending prefix length so the longest match wins)
// Handles model version drift (suffixes such as -20251111, -thinking, -preview)
+// The gemini-2.5 series maps uniformly to gemini-3 (the Antigravity upstream no longer supports 2.5)
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
- // longest prefixes first
- {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
- {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview etc.
- {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview etc. → gemini-3-flash
- {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // legacy claude-3-5-sonnet-xxx
- {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
- {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
+ // gemini-2.5 → gemini-3 mappings (longest prefixes first)
+ {"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
+ {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
+ {"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
+ {"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
+ {"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
+ {"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
+ {"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
+ // gemini-3 prefix mappings
+ {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview etc.
+ {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview etc. → gemini-3-flash
+ {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview etc.
+ // Claude mappings
+ {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // legacy claude-3-5-sonnet-xxx
+ {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
+ {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "claude-sonnet-4-5"}, // legacy claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"},
- {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview etc.
}
// AntigravityGatewayService handles API forwarding for the Antigravity platform
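The mapping table above encodes longest-match-first purely through ordering: the first entry whose prefix matches wins, so longer prefixes must be listed before shorter ones that subsume them. A runnable sketch of that lookup over an excerpt of the table (`mapModel` is an illustrative name, not the service's actual method):

```go
package main

import "strings"

// Excerpt of the prefix table from the patch; longer prefixes come first so
// they are tried before shorter prefixes that would also match.
var prefixMapping = []struct {
	prefix string
	target string
}{
	{"gemini-2.5-flash-thinking", "gemini-3-flash"},
	{"gemini-2.5-flash-image", "gemini-3-pro-image"},
	{"gemini-2.5-flash-lite", "gemini-3-flash"},
	{"gemini-2.5-flash", "gemini-3-flash"},
	{"gemini-2.5-pro", "gemini-3-pro-high"},
	{"gemini-3-pro-image", "gemini-3-pro-image"},
	{"gemini-3-flash", "gemini-3-flash"},
	{"gemini-3-pro", "gemini-3-pro-high"},
}

// mapModel returns the target of the first matching prefix, or the
// input unchanged when nothing matches (passthrough).
func mapModel(model string) string {
	for _, m := range prefixMapping {
		if strings.HasPrefix(model, m.prefix) {
			return m.target
		}
	}
	return model
}

func main() {
	println(mapModel("gemini-2.5-flash-thinking-20251111")) // gemini-3-flash
	println(mapModel("gemini-2.5-pro-preview"))             // gemini-3-pro-high
}
```

If the `gemini-2.5-flash` entry were listed before `gemini-2.5-flash-thinking`, the thinking variant would wrongly map to the non-thinking target; the ordering is the whole correctness argument.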
@@ -1530,7 +1537,11 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
func antigravityUseScopeRateLimit() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
- return v == "1" || v == "true" || v == "yes" || v == "on"
+ // Scope-based rate limiting defaults to on; it is disabled only when explicitly set to a disable value
+ if v == "0" || v == "false" || v == "no" || v == "off" {
+ return false
+ }
+ return true
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go
index 179a3520..e269103a 100644
--- a/backend/internal/service/antigravity_model_mapping_test.go
+++ b/backend/internal/service/antigravity_model_mapping_test.go
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-sonnet-4-5",
},
- // 3. Gemini passthrough
+ // 3. Gemini 2.5 → 3 mapping
{
- name: "Gemini passthrough - gemini-2.5-flash",
+ name: "Gemini mapping - gemini-2.5-flash → gemini-3-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
- expected: "gemini-2.5-flash",
+ expected: "gemini-3-flash",
},
{
- name: "Gemini passthrough - gemini-2.5-pro",
+ name: "Gemini mapping - gemini-2.5-pro → gemini-3-pro-high",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
- expected: "gemini-2.5-pro",
+ expected: "gemini-3-pro-high",
},
{
name: "Gemini passthrough - gemini-future-model",
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
index 52293cd5..fa8379ed 100644
--- a/backend/internal/service/antigravity_oauth_service.go
+++ b/backend/internal/service/antigravity_oauth_service.go
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
- // fetch the project_id (some account types may not have one)
- loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
- if err != nil {
- fmt.Printf("[AntigravityOAuth] warning: failed to fetch project_id: %v\n", err)
- } else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
- result.ProjectID = loadResp.CloudAICompanionProject
+ // fetch the project_id (some account types may not have one), retrying on failure
+ projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
+ if loadErr != nil {
+ fmt.Printf("[AntigravityOAuth] warning: failed to fetch project_id (after retries): %v\n", loadErr)
+ result.ProjectIDMissing = true
+ } else {
+ result.ProjectID = projectID
}
return result, nil
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
- // call LoadCodeAssist on every refresh to fetch the project_id
- client := antigravity.NewClient(proxyURL)
- loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
- if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
- // LoadCodeAssist failed or returned empty; keep the existing project_id and mark it missing
- existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
+ // call LoadCodeAssist on every refresh to fetch the project_id, retrying on failure
+ existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
+ projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
+
+ if loadErr != nil {
+ // LoadCodeAssist failed; keep the existing project_id
tokenInfo.ProjectID = existingProjectID
- tokenInfo.ProjectIDMissing = true
+ // mark it as truly missing only if a project_id was never obtained and this attempt also failed;
+ // if one existed before, this is likely a transient failure and should not be flagged as an error
+ if existingProjectID == "" {
+ tokenInfo.ProjectIDMissing = true
+ }
} else {
- tokenInfo.ProjectID = loadResp.CloudAICompanionProject
+ tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
+// loadProjectIDWithRetry fetches the project_id with a retry mechanism.
+// It returns the project_id and an error; on failure it retries up to maxRetries times.
+func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
+ var lastErr error
+
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ if attempt > 0 {
+ // exponential backoff: 1s, 2s, 4s
+ backoff := time.Duration(1<<(attempt-1)) * time.Second
+ select {
+ case <-ctx.Done():
+ return "", ctx.Err()
+ case <-time.After(backoff):
+ }
+ }
+
+ client := antigravity.NewClient(proxyURL)
+ loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
+ if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
+ return loadResp.CloudAICompanionProject, nil
+ }
+ if err != nil {
+ lastErr = err
+ } else {
+ lastErr = fmt.Errorf("LoadCodeAssist returned empty project_id")
+ }
+ }
+
+ return "", lastErr
+}
[frontend Vue template hunks garbled in extraction — they touched the TOTP login hint, the announcements panel and read states, admin group account-copy tooltips/hints, the redeem invitation hint, the password error display, and the setup Redis TLS toggle]