From 1181b332f7800ad80814e22fb91e3120517d6e6e Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 6 Jan 2026 15:46:36 +0800 Subject: [PATCH 01/46] =?UTF-8?q?fix(=E5=89=8D=E7=AB=AF):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E7=BC=96=E8=BE=91=E8=B4=A6=E5=8F=B7=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E6=97=A0=E6=B3=95=E6=98=BE=E7=A4=BA=E5=85=B7?= =?UTF-8?q?=E4=BD=93=E5=8E=9F=E5=9B=A0=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 后端 API 返回 message 字段,但前端读取 detail 字段,导致无法显示具体错误信息。 现在优先读取 message 字段,兼容 detail 字段。 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- frontend/src/components/account/EditAccountModal.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 4ac149f2..3f47ee31 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1140,7 +1140,7 @@ const handleSubmit = async () => { emit('updated') handleClose() } catch (error: any) { - appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToUpdate')) + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) } finally { submitting.value = false } From 5a52cb608cc0de8a21066670e210ce39caac85dd Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 6 Jan 2026 19:20:05 +0800 Subject: [PATCH 02/46] =?UTF-8?q?fix(=E5=89=8D=E7=AB=AF):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=B4=A6=E5=8F=B7=E7=AE=A1=E7=90=86=E9=A1=B5=E9=9D=A2?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E8=BF=87=E6=BB=A4=E4=B8=8D=E7=94=9F=E6=95=88?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加 @update:filters 事件监听,使过滤器参数能正确同步到数据请求中。 修复了平台、类型、状态三个过滤器全部失效的问题。 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 --- frontend/src/views/admin/AccountsView.vue | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index eb73a5ca..c95b89f3 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -6,6 +6,7 @@ From 66fe484f0df3905c14907e3e1bb6c2e7bf91dad8 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 6 Jan 2026 20:26:32 +0800 Subject: [PATCH 03/46] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E5=AE=89=E5=85=A8=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/dependency-security.md | 58 ------------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 docs/dependency-security.md diff --git a/docs/dependency-security.md b/docs/dependency-security.md deleted file mode 100644 index 66545011..00000000 --- a/docs/dependency-security.md +++ /dev/null @@ -1,58 +0,0 @@ -# Dependency Security - -This document describes how dependency and toolchain security is managed in this repo. - -## Go Toolchain Policy (Pinned to 1.25.5) - -The Go toolchain is pinned to 1.25.5 to address known security issues. - -Locations that MUST stay aligned: -- `backend/go.mod`: `go 1.25.5` and `toolchain go1.25.5` -- `Dockerfile`: `GOLANG_IMAGE=golang:1.25.5-alpine` -- Workflows: use `go-version-file: backend/go.mod` and verify `go1.25.5` - -Update process: -1. Change `backend/go.mod` (go + toolchain) to the new patch version. -2. Update `Dockerfile` GOLANG_IMAGE to the same patch version. -3. Update workflows if needed and keep the `go version` check in place. -4. Run `govulncheck` and the CI security scan workflow. - -## Security Scans - -Automated scans run via `.github/workflows/security-scan.yml`: -- `govulncheck` for Go dependencies -- `gosec` for static security issues -- `pnpm audit` for frontend production dependencies - -Policy: -- High/Critical findings fail the build unless explicitly exempted. -- Exemptions must include mitigation and an expiry date. - -## Audit Exceptions - -Exception list location: `.github/audit-exceptions.yml` - -Required fields: -- `package` -- `advisory` (GHSA ID or advisory URL from pnpm audit) -- `severity` -- `mitigation` -- `expires_on` (recommended <= 90 days) - -Process: -1. Add an exception with mitigation details and an expiry date. -2. Ensure the exception is reviewed before expiry. -3. Remove the exception when the dependency is upgraded or replaced. - -## Frontend xlsx Mitigation (Plan A) - -Current mitigation: -- Use dynamic import so `xlsx` only loads during export. -- Keep export access restricted and data scope limited. - -## Rollback Guidance - -If a change causes issues: -- Go: revert `backend/go.mod` and `Dockerfile` to the previous version. -- Frontend: revert the dynamic import change if needed. -- CI: remove exception entries and re-run scans to confirm status. From 823497a2afe7ebe1a1418b00657145b483d5682a Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 6 Jan 2026 20:31:40 +0800 Subject: [PATCH 04/46] =?UTF-8?q?fix(=E5=B9=B6=E5=8F=91):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20wrapReleaseOnDone=20goroutine=20=E6=B3=84=E9=9C=B2?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题描述: - wrapReleaseOnDone 函数创建的 goroutine 会持续等待 ctx.Done() - 即使 release() 已被调用,goroutine 仍不会退出 - 高并发场景下(1000 req/s)会产生 3000+ 个泄露 goroutine 修复方案: - 添加 quit channel 作为退出信号 - 正常释放时 close(quit) 通知 goroutine 立即退出 - 使用 select 监听 ctx.Done() 和 quit 两个信号 - 确保 goroutine 在正常流程中及时退出 测试覆盖: - 新增 5 个单元测试验证修复效果 - 验证 goroutine 不泄露 - 验证并发安全性和多次调用保护 - 性能影响:471.9 ns/op, 208 B/op 影响范围: - gateway_handler.go: 每请求调用 2-4 次 - openai_gateway_handler.go: 每请求调用 2-3 次 - 修复后 goroutine 泄露数量从 3/req 降至 0 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 --- backend/internal/handler/gateway_helper.go | 24 ++- .../internal/handler/gateway_helper_test.go | 141 ++++++++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) create mode 100644 backend/internal/handler/gateway_helper_test.go diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 2eb3ac72..5de519c7 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -83,19 +83,33 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // 用于避免客户端断开或上游超时导致的并发槽位泄漏。 +// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { if releaseFunc == nil { return nil } var once sync.Once - wrapped := func() { - once.Do(releaseFunc) + quit := make(chan struct{}) + + release := func() { + once.Do(func() { + releaseFunc() + close(quit) // 通知监听 goroutine 退出 + }) } + go func() { - <-ctx.Done() - wrapped() + select { + case <-ctx.Done(): + // Context 取消时释放资源 + release() + case <-quit: + // 正常释放已完成,goroutine 退出 + return + } }() - return wrapped + + return release } // IncrementWaitCount increments the wait count for a user diff --git a/backend/internal/handler/gateway_helper_test.go b/backend/internal/handler/gateway_helper_test.go new file mode 100644 index 00000000..664258f8 --- /dev/null +++ b/backend/internal/handler/gateway_helper_test.go @@ -0,0 +1,141 @@ +package handler + +import ( + "context" + "runtime" + "sync/atomic" + "testing" + "time" +) + +// TestWrapReleaseOnDone_NoGoroutineLeak 验证 wrapReleaseOnDone 修复后不会泄露 goroutine +func TestWrapReleaseOnDone_NoGoroutineLeak(t *testing.T) { + // 记录测试开始时的 goroutine 数量 + runtime.GC() + time.Sleep(100 * time.Millisecond) + initialGoroutines := runtime.NumGoroutine() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var releaseCount int32 + release := wrapReleaseOnDone(ctx, func() { + atomic.AddInt32(&releaseCount, 1) + }) + + // 正常释放 + release() + + // 等待足够时间确保 goroutine 退出 + time.Sleep(200 * time.Millisecond) + + // 验证只释放一次 + if count := atomic.LoadInt32(&releaseCount); count != 1 { + t.Errorf("expected release count to be 1, got %d", count) + } + + // 强制 GC,清理已退出的 goroutine + runtime.GC() + time.Sleep(100 * time.Millisecond) + + // 验证 goroutine 数量没有增加(允许±2的误差,考虑到测试框架本身可能创建的 goroutine) + finalGoroutines := runtime.NumGoroutine() + if finalGoroutines > initialGoroutines+2 { + t.Errorf("goroutine leak detected: initial=%d, final=%d, leaked=%d", + initialGoroutines, finalGoroutines, finalGoroutines-initialGoroutines) + } +} + +// TestWrapReleaseOnDone_ContextCancellation 验证 context 取消时也能正确释放 +func TestWrapReleaseOnDone_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var releaseCount int32 + _ = wrapReleaseOnDone(ctx, func() { + atomic.AddInt32(&releaseCount, 1) + }) + + // 取消 context,应该触发释放 + cancel() + + // 等待释放完成 + time.Sleep(100 * time.Millisecond) + + // 验证释放被调用 + if count := atomic.LoadInt32(&releaseCount); count != 1 { + t.Errorf("expected release count to be 1, got %d", count) + } +} + +// TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce 验证多次调用 release 只释放一次 +func TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var releaseCount int32 + release := wrapReleaseOnDone(ctx, func() { + atomic.AddInt32(&releaseCount, 1) + }) + + // 调用多次 + release() + release() + release() + + // 等待执行完成 + time.Sleep(100 * time.Millisecond) + + // 验证只释放一次 + if count := atomic.LoadInt32(&releaseCount); count != 1 { + t.Errorf("expected release count to be 1, got %d", count) + } +} + +// TestWrapReleaseOnDone_NilReleaseFunc 验证 nil releaseFunc 不会 panic +func TestWrapReleaseOnDone_NilReleaseFunc(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + release := wrapReleaseOnDone(ctx, nil) + + if release != nil { + t.Error("expected nil release function when releaseFunc is nil") + } +} + +// TestWrapReleaseOnDone_ConcurrentCalls 验证并发调用的安全性 +func TestWrapReleaseOnDone_ConcurrentCalls(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var releaseCount int32 + release := wrapReleaseOnDone(ctx, func() { + atomic.AddInt32(&releaseCount, 1) + }) + + // 并发调用 release + const numGoroutines = 10 + for i := 0; i < numGoroutines; i++ { + go release() + } + + // 等待所有 goroutine 完成 + time.Sleep(200 * time.Millisecond) + + // 验证只释放一次 + if count := atomic.LoadInt32(&releaseCount); count != 1 { + t.Errorf("expected release count to be 1, got %d", count) + } +} + +// BenchmarkWrapReleaseOnDone 性能基准测试 +func BenchmarkWrapReleaseOnDone(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + release := wrapReleaseOnDone(ctx, func() {}) + release() + } +} From 015974a27e76240abef8bbf5ea17b272703ab1ee Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 6 Jan 2026 22:19:07 +0800 Subject: [PATCH 05/46] =?UTF-8?q?feat(admin/usage):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=91=98=E7=94=A8=E9=87=8F=E9=A1=B5=E9=9D=A2?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=92=8C=E4=BD=93=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 后端改进: - 新增 GetStatsWithFilters 方法支持完整筛选条件 - Stats 端点支持 account_id, group_id, model, stream, billing_type 参数 - 统一使用 filters 结构体,移除冗余的分支逻辑 前端改进: - 统计卡片添加"所选范围内"文字提示 - 优化总消费显示格式,清晰展示实际费用和标准计费 - Token 和费用列添加问号图标 tooltip 显示详细信息 - API Key 搜索框体验优化:点击即显示下拉选项 - 选择用户后自动加载该用户的所有 API Key --- .../internal/handler/admin/usage_handler.go | 79 +++++--- backend/internal/repository/usage_log_repo.go | 75 ++++++++ .../internal/service/account_usage_service.go | 1 + backend/internal/service/usage_service.go | 9 + frontend/src/api/admin/usage.ts | 8 +- .../components/admin/usage/UsageFilters.vue | 28 ++- .../admin/usage/UsageStatsCards.vue | 23 ++- .../src/components/admin/usage/UsageTable.vue | 170 ++++++++++++++++-- 8 files changed, 341 insertions(+), 52 deletions(-) diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 9d14afd2..ad336b3e 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -152,8 +152,8 @@ func (h *UsageHandler) List(c *gin.Context) { // Stats handles getting usage statistics with filters // GET /api/v1/admin/usage/stats func (h *UsageHandler) Stats(c *gin.Context) { - // Parse filters - var userID, apiKeyID int64 + // Parse filters - same as List endpoint + var userID, apiKeyID, accountID, groupID int64 if userIDStr := c.Query("user_id"); userIDStr != "" { id, err := strconv.ParseInt(userIDStr, 10, 64) if err != nil { @@ -172,8 +172,49 @@ func (h *UsageHandler) Stats(c *gin.Context) { apiKeyID = id } + if accountIDStr := c.Query("account_id"); accountIDStr != "" { + id, err := strconv.ParseInt(accountIDStr, 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account_id") + return + } + accountID = id + } + + if groupIDStr := c.Query("group_id"); groupIDStr != "" { + id, err := strconv.ParseInt(groupIDStr, 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group_id") + return + } + groupID = id + } + + model := c.Query("model") + + var stream *bool + if streamStr := c.Query("stream"); streamStr != "" { + val, err := strconv.ParseBool(streamStr) + if err != nil { + response.BadRequest(c, "Invalid stream value, use true or false") + return + } + stream = &val + } + + var billingType *int8 + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + val, err := strconv.ParseInt(billingTypeStr, 10, 8) + if err != nil { + response.BadRequest(c, "Invalid billing_type") + return + } + bt := int8(val) + billingType = &bt + } + // Parse date range - userTZ := c.Query("timezone") // Get user's timezone from request + userTZ := c.Query("timezone") now := timezone.NowInUserLocation(userTZ) var startTime, endTime time.Time @@ -208,28 +249,20 @@ func (h *UsageHandler) Stats(c *gin.Context) { endTime = now } - if apiKeyID > 0 { - stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, stats) - return + // Build filters and call GetStatsWithFilters + filters := usagestats.UsageLogFilters{ + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + Model: model, + Stream: stream, + BillingType: billingType, + StartTime: &startTime, + EndTime: &endTime, } - if userID > 0 { - stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, stats) - return - } - - // Get global stats - stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime) + stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 82d5e833..4df10b23 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1388,6 +1388,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT return stats, nil } +// GetStatsWithFilters gets usage statistics with optional filters +func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) { + conditions := make([]string, 0, 9) + args := make([]any, 0, 9) + + if filters.UserID > 0 { + conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) + args = append(args, filters.UserID) + } + if filters.APIKeyID > 0 { + conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) + args = append(args, filters.APIKeyID) + } + if filters.AccountID > 0 { + conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) + args = append(args, filters.AccountID) + } + if filters.GroupID > 0 { + conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) + args = append(args, filters.GroupID) + } + if filters.Model != "" { + conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) + args = append(args, filters.Model) + } + if filters.Stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) + args = append(args, *filters.Stream) + } + if filters.BillingType != nil { + conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) + args = append(args, int16(*filters.BillingType)) + } + if filters.StartTime != nil { + conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) + args = append(args, *filters.StartTime) + } + if filters.EndTime != nil { + conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1)) + args = append(args, *filters.EndTime) + } + + query := fmt.Sprintf(` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms + FROM usage_logs + %s + `, buildWhere(conditions)) + + stats := &UsageStats{} + if err := scanSingleRow( + ctx, + r.sql, + query, + args, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return stats, nil +} + // AccountUsageHistory represents daily usage history for an account type AccountUsageHistory = usagestats.AccountUsageHistory diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6971fafa..f1ee43d2 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -47,6 +47,7 @@ type UsageLogRepository interface { // Admin usage listing/stats ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) // Account stats GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 29362cc6..10a294ae 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -319,3 +319,12 @@ func (s *UsageService) GetGlobalStats(ctx context.Context, startTime, endTime ti } return stats, nil } + +// GetStatsWithFilters returns usage stats with optional filters. +func (s *UsageService) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + stats, err := s.usageRepo.GetStatsWithFilters(ctx, filters) + if err != nil { + return nil, fmt.Errorf("get usage stats with filters: %w", err) + } + return stats, nil +} diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 42c23a87..4712dafd 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -54,15 +54,21 @@ export async function list( /** * Get usage statistics with optional filters (admin only) - * @param params - Query parameters (user_id, api_key_id, period/date range) + * @param params - Query parameters for filtering * @returns Usage statistics */ export async function getStats(params: { user_id?: number api_key_id?: number + account_id?: number + group_id?: number + model?: string + stream?: boolean + billing_type?: number period?: string start_date?: string end_date?: string + timezone?: string }): Promise { const { data } = await apiClient.get('/admin/usage/stats', { params diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue index d6077ec5..822f41a8 100644 --- a/frontend/src/components/admin/usage/UsageFilters.vue +++ b/frontend/src/components/admin/usage/UsageFilters.vue @@ -50,7 +50,7 @@ class="input pr-8" :placeholder="t('admin.usage.searchApiKeyPlaceholder')" @input="debounceApiKeySearch" - @focus="showApiKeyDropdown = true" + @focus="onApiKeyFocus" /> +
+ +
@@ -166,6 +197,7 @@ const filters = toRef(props, 'modelValue') const userSearchRef = ref(null) const apiKeySearchRef = ref(null) +const accountSearchRef = ref(null) const userKeyword = ref('') const userResults = ref([]) @@ -177,9 +209,17 @@ const apiKeyResults = ref([]) const showApiKeyDropdown = ref(false) let apiKeySearchTimeout: ReturnType | null = null +interface SimpleAccount { + id: number + name: string +} +const accountKeyword = ref('') +const accountResults = ref([]) +const showAccountDropdown = ref(false) +let accountSearchTimeout: ReturnType | null = null + const modelOptions = ref([{ value: null, label: t('admin.usage.allModels') }]) const groupOptions = ref([{ value: null, label: t('admin.usage.allGroups') }]) -const accountOptions = ref([{ value: null, label: t('admin.usage.allAccounts') }]) const streamTypeOptions = ref([ { value: null, label: t('admin.usage.allTypes') }, @@ -278,6 +318,37 @@ const onClearApiKey = () => { emitChange() } +const debounceAccountSearch = () => { + if (accountSearchTimeout) clearTimeout(accountSearchTimeout) + accountSearchTimeout = setTimeout(async () => { + if (!accountKeyword.value) { + accountResults.value = [] + return + } + try { + const res = await adminAPI.accounts.list(1, 20, { search: accountKeyword.value }) + accountResults.value = res.items.map((a) => ({ id: a.id, name: a.name })) + } catch { + accountResults.value = [] + } + }, 300) +} + +const selectAccount = (a: SimpleAccount) => { + accountKeyword.value = a.name + showAccountDropdown.value = false + filters.value.account_id = a.id + emitChange() +} + +const clearAccount = () => { + accountKeyword.value = '' + accountResults.value = [] + showAccountDropdown.value = false + filters.value.account_id = undefined + emitChange() +} + const onApiKeyFocus = () => { showApiKeyDropdown.value = true // Trigger search if no results yet @@ -292,9 +363,11 @@ const onDocumentClick = (e: MouseEvent) => { const clickedInsideUser = userSearchRef.value?.contains(target) ?? false const clickedInsideApiKey = apiKeySearchRef.value?.contains(target) ?? false + const clickedInsideAccount = accountSearchRef.value?.contains(target) ?? false if (!clickedInsideUser) showUserDropdown.value = false if (!clickedInsideApiKey) showApiKeyDropdown.value = false + if (!clickedInsideAccount) showAccountDropdown.value = false } watch( @@ -333,20 +406,27 @@ watch( } ) +watch( + () => filters.value.account_id, + (accountId) => { + if (!accountId) { + accountKeyword.value = '' + accountResults.value = [] + } + } +) + onMounted(async () => { document.addEventListener('click', onDocumentClick) try { - const [gs, ms, as] = await Promise.all([ + const [gs, ms] = await Promise.all([ adminAPI.groups.list(1, 1000), - adminAPI.dashboard.getModelStats({ start_date: props.startDate, end_date: props.endDate }), - adminAPI.accounts.list(1, 1000) + adminAPI.dashboard.getModelStats({ start_date: props.startDate, end_date: props.endDate }) ]) groupOptions.value.push(...gs.items.map((g: any) => ({ value: g.id, label: g.name }))) - accountOptions.value.push(...as.items.map((a: any) => ({ value: a.id, label: a.name }))) - const uniqueModels = new Set() ms.models?.forEach((s: any) => s.model && uniqueModels.add(s.model)) modelOptions.value.push( diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index fd5768a9..79465bb7 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -143,8 +143,8 @@ >
-
-
Token {{ t('usage.details') }}
+
+
{{ t('usage.tokenDetails') }}
{{ t('admin.usage.inputTokens') }} {{ tokenTooltipData.input_tokens.toLocaleString() }} @@ -184,6 +184,27 @@ >
+ +
+
{{ t('usage.costDetails') }}
+
+ {{ t('admin.usage.inputCost') }} + ${{ tooltipData.input_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.outputCost') }} + ${{ tooltipData.output_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.cacheCreationCost') }} + ${{ tooltipData.cache_creation_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.cacheReadCost') }} + ${{ tooltipData.cache_read_cost.toFixed(6) }} +
+
+
{{ t('usage.rate') }} {{ (tooltipData?.rate_multiplier || 1).toFixed(2) }}x diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 393641a7..4634d8b6 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -376,6 +376,8 @@ export default { usage: { title: 'Usage Records', description: 'View and analyze your API usage history', + costDetails: 'Cost Breakdown', + tokenDetails: 'Token Breakdown', totalRequests: 'Total Requests', totalTokens: 'Total Tokens', totalCost: 'Total Cost', @@ -1691,6 +1693,7 @@ export default { userFilter: 'User', searchUserPlaceholder: 'Search user by email...', searchApiKeyPlaceholder: 'Search API key by name...', + searchAccountPlaceholder: 'Search account by name...', selectedUser: 'Selected', user: 'User', account: 'Account', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index fb46bbbe..7e326bab 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -373,6 +373,8 @@ export default { usage: { title: '使用记录', description: '查看和分析您的 API 使用历史', + costDetails: '成本明细', + tokenDetails: 'Token 明细', totalRequests: '总请求数', totalTokens: '总 Token', totalCost: '总消费', @@ -1836,6 +1838,7 @@ export default { userFilter: '用户', searchUserPlaceholder: '按邮箱搜索用户...', searchApiKeyPlaceholder: '按名称搜索 API 密钥...', + searchAccountPlaceholder: '按名称搜索账号...', selectedUser: '已选择', user: '用户', account: '账户', diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index d5e94145..522f1b00 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -85,11 +85,48 @@ const exportToExcel = async () => { if (all.length >= total || res.items.length < 100) break; p++ } if(!c.signal.aborted) { - // 动态加载 xlsx,降低首屏包体并减少高危依赖的常驻暴露面。 const XLSX = await import('xlsx') - const ws = XLSX.utils.json_to_sheet(all); const wb = XLSX.utils.book_new(); XLSX.utils.book_append_sheet(wb, ws, 'Usage') - saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${Date.now()}.xlsx`) - appStore.showSuccess('Export Success') + const headers = [ + t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'), + t('admin.usage.account'), t('usage.model'), t('admin.usage.group'), + t('usage.type'), + t('admin.usage.inputTokens'), t('admin.usage.outputTokens'), + t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'), + t('admin.usage.inputCost'), t('admin.usage.outputCost'), + t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'), + t('usage.rate'), t('usage.original'), t('usage.billed'), + t('usage.billingType'), t('usage.firstToken'), t('usage.duration'), + t('admin.usage.requestId') + ] + const rows = all.map(log => [ + log.created_at, + log.user?.email || '', + log.api_key?.name || '', + log.account?.name || '', + log.model, + log.group?.name || '', + log.stream ? t('usage.stream') : t('usage.sync'), + log.input_tokens, + log.output_tokens, + log.cache_read_tokens, + log.cache_creation_tokens, + log.input_cost?.toFixed(6) || '0.000000', + log.output_cost?.toFixed(6) || '0.000000', + log.cache_read_cost?.toFixed(6) || '0.000000', + log.cache_creation_cost?.toFixed(6) || '0.000000', + log.rate_multiplier?.toFixed(2) || '1.00', + log.total_cost?.toFixed(6) || '0.000000', + log.actual_cost?.toFixed(6) || '0.000000', + log.billing_type === 1 ? t('usage.subscription') : t('usage.balance'), + log.first_token_ms ?? '', + log.duration_ms, + log.request_id || '' + ]) + const ws = XLSX.utils.aoa_to_sheet([headers, ...rows]) + const wb = XLSX.utils.book_new() + XLSX.utils.book_append_sheet(wb, ws, 'Usage') + saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`) + appStore.showSuccess(t('usage.exportSuccess')) } } catch (error) { console.error('Failed to export:', error); appStore.showError('Export Failed') } finally { if(exportAbortController === c) { exportAbortController = null; exporting.value = false; exportProgress.show = false } } diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index 567d4061..489e2726 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -342,8 +342,8 @@ >
-
-
Token 明细
+
+
{{ t('usage.tokenDetails') }}
{{ t('admin.usage.inputTokens') }} {{ tokenTooltipData.input_tokens.toLocaleString() }} @@ -389,6 +389,27 @@ class="whitespace-nowrap rounded-lg border border-gray-700 bg-gray-900 px-3 py-2.5 text-xs text-white shadow-xl dark:border-gray-600 dark:bg-gray-800" >
+ +
+
{{ t('usage.costDetails') }}
+
+ {{ t('admin.usage.inputCost') }} + ${{ tooltipData.input_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.outputCost') }} + ${{ tooltipData.output_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.cacheCreationCost') }} + ${{ tooltipData.cache_creation_cost.toFixed(6) }} +
+
+ {{ t('admin.usage.cacheReadCost') }} + ${{ tooltipData.cache_read_cost.toFixed(6) }} +
+
+
{{ t('usage.rate') }} Date: Wed, 7 Jan 2026 10:17:09 +0800 Subject: [PATCH 08/46] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D=20cach?= =?UTF-8?q?e=5Fcontrol=20=E5=9D=97=E8=B6=85=E9=99=90=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=20Claude=20Code=20=E6=A3=80?= =?UTF-8?q?=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题: - OAuth/SetupToken 账号注入 system prompt 后可能导致 cache_control 块超过 Anthropic API 的 4 个限制 - Claude Code 检测使用精确匹配,无法识别 Agent SDK 等变体 修复: - 新增 enforceCacheControlLimit 函数,强制执行 4 个块限制 - 优先从 messages 移除,再从 system 尾部移除(保护注入的 prompt) - 改用前缀匹配检测 Claude Code 系统提示词,支持多种变体: - 标准版、Agent SDK 版、Explore Agent 版、Compact 版 --- backend/internal/service/gateway_service.go | 148 +++++++++++++++++++- 1 file changed, 145 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a83e7d05..8fd0b918 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -35,6 +35,7 @@ const ( stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 10 * 1024 * 1024 claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." + maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 ) // sseDataRe matches SSE data lines with optional whitespace after colon. @@ -43,6 +44,16 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + + // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 + // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 + // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 + claudeCodePromptPrefixes = []string{ + "You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...) + "You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体 + "You are a file search specialist for Claude Code", // Explore Agent 版 + "You are a helpful AI assistant tasked with summarizing conversations", // Compact 版 + } ) // allowedHeaders 白名单headers(参考CRS项目) @@ -1013,15 +1024,15 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { } // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 -// 支持 string 和 []any 两种格式 +// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) func systemIncludesClaudeCodePrompt(system any) bool { switch v := system.(type) { case string: - return v == claudeCodeSystemPrompt + return hasClaudeCodePrefix(v) case []any: for _, item := range v { if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) { return true } } @@ -1030,6 +1041,16 @@ func systemIncludesClaudeCodePrompt(system any) bool { return false } +// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头 +func hasClaudeCodePrefix(text string) bool { + for _, prefix := range claudeCodePromptPrefixes { + if strings.HasPrefix(text, prefix) { + return true + } + } + return false +} + // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { @@ -1073,6 +1094,124 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } +// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) +// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制 +func enforceCacheControlLimit(body []byte) []byte { + var data map[string]any + if err := json.Unmarshal(body, &data); err != nil { + return body + } + + // 计算当前 cache_control 块数量 + count := countCacheControlBlocks(data) + if count <= maxCacheControlBlocks { + return body + } + + // 超限:优先从 messages 中移除,再从 system 中移除 + for count > maxCacheControlBlocks { + if removeCacheControlFromMessages(data) { + count-- + continue + } + if removeCacheControlFromSystem(data) { + count-- + continue + } + break + } + + result, err := json.Marshal(data) + if err != nil { + return body + } + return result +} + +// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量 +func countCacheControlBlocks(data map[string]any) int { + count := 0 + + // 统计 system 中的块 + if system, ok := data["system"].([]any); ok { + for _, item := range system { + if m, ok := item.(map[string]any); ok { + if _, has := m["cache_control"]; has { + count++ + } + } + } + } + + // 统计 messages 中的块 + if messages, ok := data["messages"].([]any); ok { + for _, msg := range messages { + if msgMap, ok := msg.(map[string]any); ok { + if content, ok := msgMap["content"].([]any); ok { + for _, item := range content { + if m, ok := item.(map[string]any); ok { + if _, has := m["cache_control"]; has { + count++ + } + } + } + } + } + } + } + + return count +} + +// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始) +// 返回 true 表示成功移除,false 表示没有可移除的 +func removeCacheControlFromMessages(data map[string]any) bool { + messages, ok := data["messages"].([]any) + if !ok { + return false + } + + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := msgMap["content"].([]any) + if !ok { + continue + } + for _, item := range content { + if m, ok := item.(map[string]any); ok { + if _, has := m["cache_control"]; has { + delete(m, "cache_control") + return true + } + } + } + } + return false +} + +// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt) +// 返回 true 表示成功移除,false 表示没有可移除的 +func removeCacheControlFromSystem(data map[string]any) bool { + system, ok := data["system"].([]any) + if !ok { + return false + } + + // 从尾部开始移除,保护开头注入的 Claude Code prompt + for i := len(system) - 1; i >= 0; i-- { + if m, ok := system[i].(map[string]any); ok { + if _, has := m["cache_control"]; has { + delete(m, "cache_control") + return true + } + } + } + return false +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() @@ -1093,6 +1232,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body = injectClaudeCodePrompt(body, parsed.System) } + // 强制执行 cache_control 块数量限制(最多 4 个) + body = enforceCacheControlLimit(body) + // 应用模型映射(仅对apikey类型账号) originalModel := reqModel if account.Type == AccountTypeAPIKey { From fc8fa83fcc03cbeea471f7d0e547031f6187f3a0 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 7 Jan 2026 10:26:24 +0800 Subject: [PATCH 09/46] =?UTF-8?q?fix(keys):=20=E4=BF=AE=E5=A4=8D=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=A1=86=E7=AC=AC=E4=B8=80=E8=A1=8C=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=A9=BA=E6=A0=BC=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pre 标签会原样保留内部空白字符,导致 code 标签前的模板缩进 被渲染为实际空格。将 pre/code 标签写在同一行消除此问题。 --- frontend/src/components/keys/UseKeyModal.vue | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 16c39bf8..546a53ab 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -105,10 +105,7 @@
-
-                
-                
-              
+
From d99a3ef14b0c55351c572c7fdca9519d00bf99bd Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 7 Jan 2026 10:56:52 +0800 Subject: [PATCH 10/46] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E8=B7=A8=E5=88=86=E7=BB=84=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题:账号可能被调度到未分配的分组(如 simon 账号被调度到 claude_default) 根因: - 强制平台模式下分组查询失败时回退到全平台查询 - listSchedulableAccounts 中分组为空时回退到无分组查询 - 粘性会话只检查平台匹配,未校验账号分组归属 修复: - 移除强制平台模式的回退逻辑,分组内无账号时返回错误 - 移除 listSchedulableAccounts 的回退逻辑 - 新增 isAccountInGroup 方法用于分组校验 - 在三处粘性会话检查中增加分组归属验证 --- backend/internal/service/gateway_service.go | 43 ++++++++++++--------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8fd0b918..120637d5 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -366,17 +366,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } - // 强制平台模式:优先按分组查找,找不到再查全部该平台账户 - if hasForcePlatform && groupID != nil { - account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err == nil { - return account, nil - } - // 分组中找不到,回退查询全部该平台账户 - groupID = nil - } - // antigravity 分组、强制平台模式或无分组使用单平台选择 + // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } @@ -454,7 +445,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && + if err == nil && s.isAccountInGroup(account, groupID) && + s.isAccountAllowedForPlatform(account, platform, useMixed) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) @@ -671,9 +663,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - if err == nil && len(accounts) == 0 && hasForcePlatform { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } + // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) } @@ -696,6 +686,23 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +// isAccountInGroup checks if the account belongs to the specified group. +// Returns true if groupID is nil (no group restriction) or account belongs to the group. +func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { + if groupID == nil { + return true // 无分组限制 + } + if account == nil { + return false + } + for _, ag := range account.AccountGroups { + if ag.GroupID == *groupID { + return true + } + } + return false +} + func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { if s.concurrencyService == nil { return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil @@ -734,8 +741,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) - // 检查账号平台是否匹配(确保粘性会话不会跨平台) - if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } @@ -823,8 +830,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) - // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 + if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) From b19c7875a41a697c73d413206369a05ba45909a1 Mon Sep 17 00:00:00 2001 From: Xu Kang <7836246@qq.com> Date: Wed, 7 Jan 2026 15:01:07 +0800 Subject: [PATCH 11/46] fix(i18n): use correct translation key for dashboard redeem code description (#194) Changed dashboard.addBalance to dashboard.addBalanceWithCode to match the existing translation key in locale files. --- .../src/components/user/dashboard/UserDashboardQuickActions.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/components/user/dashboard/UserDashboardQuickActions.vue b/frontend/src/components/user/dashboard/UserDashboardQuickActions.vue index 9d884aed..44ab98d9 100644 --- a/frontend/src/components/user/dashboard/UserDashboardQuickActions.vue +++ b/frontend/src/components/user/dashboard/UserDashboardQuickActions.vue @@ -40,7 +40,7 @@

{{ t('dashboard.redeemCode') }}

-

{{ t('dashboard.addBalance') }}

+

{{ t('dashboard.addBalanceWithCode') }}

Date: Wed, 7 Jan 2026 16:35:51 +0800 Subject: [PATCH 12/46] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20Go=20?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E3=80=81=E5=8C=85=E7=AE=A1=E7=90=86=E5=99=A8?= =?UTF-8?q?=E5=92=8C=E6=8A=80=E6=9C=AF=E6=A0=88=E6=96=87=E6=A1=A3=20(#195)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - backend/Dockerfile: Go 版本从 1.21 更新到 1.25.5(与 go.mod 一致) - Makefile: 使用 pnpm 替代 npm(与 pnpm-lock.yaml 和 CI 一致) - README.md/README_CN.md: 技术栈从 GORM 修正为 Ent --- Makefile | 6 +++--- README.md | 2 +- README_CN.md | 2 +- backend/Dockerfile | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 4a08c23b..a5e18a37 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ build-backend: # 编译前端(需要已安装依赖) build-frontend: - @npm --prefix frontend run build + @pnpm --dir frontend run build # 运行测试(后端 + 前端) test: test-backend test-frontend @@ -18,5 +18,5 @@ test-backend: @$(MAKE) -C backend test test-frontend: - @npm --prefix frontend run lint:check - @npm --prefix frontend run typecheck + @pnpm --dir frontend run lint:check + @pnpm --dir frontend run typecheck diff --git a/README.md b/README.md index 684ad0f2..fa965e6f 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot | Component | Technology | |-----------|------------| -| Backend | Go 1.25.5, Gin, GORM | +| Backend | Go 1.25.5, Gin, Ent | | Frontend | Vue 3.4+, Vite 5+, TailwindCSS | | Database | PostgreSQL 15+ | | Cache/Queue | Redis 7+ | diff --git a/README_CN.md b/README_CN.md index 22a601bc..b8a818b3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -44,7 +44,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( | 组件 | 技术 | |------|------| -| 后端 | Go 1.25.5, Gin, GORM | +| 后端 | Go 1.25.5, Gin, Ent | | 前端 | Vue 3.4+, Vite 5+, TailwindCSS | | 数据库 | PostgreSQL 15+ | | 缓存/队列 | Redis 7+ | diff --git a/backend/Dockerfile b/backend/Dockerfile index 3bc4e50f..770fdedf 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.21-alpine +FROM golang:1.25.5-alpine WORKDIR /app From 2b528c5f813b64b6a7ed38f6c75eda0d8fc725f8 Mon Sep 17 00:00:00 2001 From: LLLLLLiulei <1065070665@qq.com> Date: Wed, 7 Jan 2026 16:59:35 +0800 Subject: [PATCH 13/46] feat: auto-pause expired accounts --- backend/cmd/server/wire.go | 5 + backend/cmd/server/wire_gen.go | 12 +- backend/ent/account.go | 29 +++- backend/ent/account/account.go | 18 +++ backend/ent/account/where.go | 70 +++++++++ backend/ent/account_create.go | 143 ++++++++++++++++++ backend/ent/account_update.go | 86 +++++++++++ backend/ent/migrate/schema.go | 14 +- backend/ent/mutation.go | 129 +++++++++++++++- backend/ent/runtime/runtime.go | 8 +- backend/ent/schema/account.go | 10 ++ .../internal/handler/admin/account_handler.go | 8 + backend/internal/handler/dto/mappers.go | 16 +- backend/internal/handler/dto/types.go | 32 ++-- backend/internal/repository/account_repo.go | 49 +++++- backend/internal/service/account.go | 35 +++-- .../service/account_expiry_service.go | 71 +++++++++ backend/internal/service/account_service.go | 55 ++++--- .../service/account_service_delete_test.go | 4 + backend/internal/service/admin_service.go | 44 ++++-- .../service/gateway_multiplatform_test.go | 3 + .../service/gemini_multiplatform_test.go | 3 + backend/internal/service/wire.go | 8 + .../migrations/030_add_account_expires_at.sql | 10 ++ .../components/account/CreateAccountModal.vue | 135 ++++++++++++----- .../components/account/EditAccountModal.vue | 118 +++++++++++---- frontend/src/i18n/locales/en.ts | 6 + frontend/src/i18n/locales/zh.ts | 6 + frontend/src/types/index.ts | 6 + frontend/src/utils/format.ts | 41 ++++- frontend/src/views/admin/AccountsView.vue | 41 ++++- frontend/vite.config.ts | 3 +- 32 files changed, 1062 insertions(+), 156 deletions(-) create mode 100644 backend/internal/service/account_expiry_service.go create mode 100644 backend/migrations/030_add_account_expires_at.sql diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index ff6ab4e6..9447de45 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -63,6 +63,7 @@ func provideCleanup( entClient *ent.Client, rdb *redis.Client, tokenRefresh *service.TokenRefreshService, + accountExpiry *service.AccountExpiryService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -84,6 +85,10 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"AccountExpiryService", func() error { + accountExpiry.Stop() + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 768254f9..e952b298 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,6 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) + antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache) @@ -97,13 +98,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) - antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.NewConcurrencyService(concurrencyCache) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) @@ -148,7 +148,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) - v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + accountExpiryService := service.ProvideAccountExpiryService(accountRepository) + v := provideCleanup(client, redisClient, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -174,6 +175,7 @@ func provideCleanup( entClient *ent.Client, rdb *redis.Client, tokenRefresh *service.TokenRefreshService, + accountExpiry *service.AccountExpiryService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -194,6 +196,10 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"AccountExpiryService", func() error { + accountExpiry.Stop() + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/ent/account.go b/backend/ent/account.go index e4823366..e960d324 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -49,6 +49,10 @@ type Account struct { ErrorMessage *string `json:"error_message,omitempty"` // LastUsedAt holds the value of the "last_used_at" field. LastUsedAt *time.Time `json:"last_used_at,omitempty"` + // Account expiration time (NULL means no expiration). + ExpiresAt *time.Time `json:"expires_at,omitempty"` + // Auto pause scheduling when account expires. + AutoPauseOnExpired bool `json:"auto_pause_on_expired,omitempty"` // Schedulable holds the value of the "schedulable" field. Schedulable bool `json:"schedulable,omitempty"` // RateLimitedAt holds the value of the "rate_limited_at" field. @@ -129,13 +133,13 @@ func (*Account) scanValues(columns []string) ([]any, error) { switch columns[i] { case account.FieldCredentials, account.FieldExtra: values[i] = new([]byte) - case account.FieldSchedulable: + case account.FieldAutoPauseOnExpired, account.FieldSchedulable: values[i] = new(sql.NullBool) case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: values[i] = new(sql.NullInt64) case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) - case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: + case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -257,6 +261,19 @@ func (_m *Account) assignValues(columns []string, values []any) error { _m.LastUsedAt = new(time.Time) *_m.LastUsedAt = value.Time } + case account.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = new(time.Time) + *_m.ExpiresAt = value.Time + } + case account.FieldAutoPauseOnExpired: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field auto_pause_on_expired", values[i]) + } else if value.Valid { + _m.AutoPauseOnExpired = value.Bool + } case account.FieldSchedulable: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field schedulable", values[i]) @@ -416,6 +433,14 @@ func (_m *Account) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := _m.ExpiresAt; v != nil { + builder.WriteString("expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("auto_pause_on_expired=") + builder.WriteString(fmt.Sprintf("%v", _m.AutoPauseOnExpired)) + builder.WriteString(", ") builder.WriteString("schedulable=") builder.WriteString(fmt.Sprintf("%v", _m.Schedulable)) builder.WriteString(", ") diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 26f72018..402e16ee 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -45,6 +45,10 @@ const ( FieldErrorMessage = "error_message" // FieldLastUsedAt holds the string denoting the last_used_at field in the database. FieldLastUsedAt = "last_used_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldAutoPauseOnExpired holds the string denoting the auto_pause_on_expired field in the database. + FieldAutoPauseOnExpired = "auto_pause_on_expired" // FieldSchedulable holds the string denoting the schedulable field in the database. FieldSchedulable = "schedulable" // FieldRateLimitedAt holds the string denoting the rate_limited_at field in the database. @@ -115,6 +119,8 @@ var Columns = []string{ FieldStatus, FieldErrorMessage, FieldLastUsedAt, + FieldExpiresAt, + FieldAutoPauseOnExpired, FieldSchedulable, FieldRateLimitedAt, FieldRateLimitResetAt, @@ -172,6 +178,8 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultAutoPauseOnExpired holds the default value on creation for the "auto_pause_on_expired" field. + DefaultAutoPauseOnExpired bool // DefaultSchedulable holds the default value on creation for the "schedulable" field. DefaultSchedulable bool // SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. @@ -251,6 +259,16 @@ func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc() } +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByAutoPauseOnExpired orders the results by the auto_pause_on_expired field. +func ByAutoPauseOnExpired(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAutoPauseOnExpired, opts...).ToFunc() +} + // BySchedulable orders the results by the schedulable field. func BySchedulable(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSchedulable, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index 1ab75a13..6c639fd1 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -120,6 +120,16 @@ func LastUsedAt(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldLastUsedAt, v)) } +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldExpiresAt, v)) +} + +// AutoPauseOnExpired applies equality check predicate on the "auto_pause_on_expired" field. It's identical to AutoPauseOnExpiredEQ. +func AutoPauseOnExpired(v bool) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldAutoPauseOnExpired, v)) +} + // Schedulable applies equality check predicate on the "schedulable" field. It's identical to SchedulableEQ. func Schedulable(v bool) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSchedulable, v)) @@ -855,6 +865,66 @@ func LastUsedAtNotNil() predicate.Account { return predicate.Account(sql.FieldNotNull(FieldLastUsedAt)) } +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldExpiresAt)) +} + +// AutoPauseOnExpiredEQ applies the EQ predicate on the "auto_pause_on_expired" field. +func AutoPauseOnExpiredEQ(v bool) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldAutoPauseOnExpired, v)) +} + +// AutoPauseOnExpiredNEQ applies the NEQ predicate on the "auto_pause_on_expired" field. +func AutoPauseOnExpiredNEQ(v bool) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldAutoPauseOnExpired, v)) +} + // SchedulableEQ applies the EQ predicate on the "schedulable" field. func SchedulableEQ(v bool) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSchedulable, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 2d7debc0..0725d43d 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -195,6 +195,34 @@ func (_c *AccountCreate) SetNillableLastUsedAt(v *time.Time) *AccountCreate { return _c } +// SetExpiresAt sets the "expires_at" field. +func (_c *AccountCreate) SetExpiresAt(v time.Time) *AccountCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_c *AccountCreate) SetNillableExpiresAt(v *time.Time) *AccountCreate { + if v != nil { + _c.SetExpiresAt(*v) + } + return _c +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (_c *AccountCreate) SetAutoPauseOnExpired(v bool) *AccountCreate { + _c.mutation.SetAutoPauseOnExpired(v) + return _c +} + +// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil. +func (_c *AccountCreate) SetNillableAutoPauseOnExpired(v *bool) *AccountCreate { + if v != nil { + _c.SetAutoPauseOnExpired(*v) + } + return _c +} + // SetSchedulable sets the "schedulable" field. func (_c *AccountCreate) SetSchedulable(v bool) *AccountCreate { _c.mutation.SetSchedulable(v) @@ -405,6 +433,10 @@ func (_c *AccountCreate) defaults() error { v := account.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.AutoPauseOnExpired(); !ok { + v := account.DefaultAutoPauseOnExpired + _c.mutation.SetAutoPauseOnExpired(v) + } if _, ok := _c.mutation.Schedulable(); !ok { v := account.DefaultSchedulable _c.mutation.SetSchedulable(v) @@ -464,6 +496,9 @@ func (_c *AccountCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Account.status": %w`, err)} } } + if _, ok := _c.mutation.AutoPauseOnExpired(); !ok { + return &ValidationError{Name: "auto_pause_on_expired", err: errors.New(`ent: missing required field "Account.auto_pause_on_expired"`)} + } if _, ok := _c.mutation.Schedulable(); !ok { return &ValidationError{Name: "schedulable", err: errors.New(`ent: missing required field "Account.schedulable"`)} } @@ -555,6 +590,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldLastUsedAt, field.TypeTime, value) _node.LastUsedAt = &value } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(account.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = &value + } + if value, ok := _c.mutation.AutoPauseOnExpired(); ok { + _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) + _node.AutoPauseOnExpired = value + } if value, ok := _c.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) _node.Schedulable = value @@ -898,6 +941,36 @@ func (u *AccountUpsert) ClearLastUsedAt() *AccountUpsert { return u } +// SetExpiresAt sets the "expires_at" field. +func (u *AccountUpsert) SetExpiresAt(v time.Time) *AccountUpsert { + u.Set(account.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *AccountUpsert) UpdateExpiresAt() *AccountUpsert { + u.SetExcluded(account.FieldExpiresAt) + return u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *AccountUpsert) ClearExpiresAt() *AccountUpsert { + u.SetNull(account.FieldExpiresAt) + return u +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (u *AccountUpsert) SetAutoPauseOnExpired(v bool) *AccountUpsert { + u.Set(account.FieldAutoPauseOnExpired, v) + return u +} + +// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create. +func (u *AccountUpsert) UpdateAutoPauseOnExpired() *AccountUpsert { + u.SetExcluded(account.FieldAutoPauseOnExpired) + return u +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsert) SetSchedulable(v bool) *AccountUpsert { u.Set(account.FieldSchedulable, v) @@ -1308,6 +1381,41 @@ func (u *AccountUpsertOne) ClearLastUsedAt() *AccountUpsertOne { }) } +// SetExpiresAt sets the "expires_at" field. +func (u *AccountUpsertOne) SetExpiresAt(v time.Time) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateExpiresAt() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *AccountUpsertOne) ClearExpiresAt() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearExpiresAt() + }) +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (u *AccountUpsertOne) SetAutoPauseOnExpired(v bool) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetAutoPauseOnExpired(v) + }) +} + +// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateAutoPauseOnExpired() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateAutoPauseOnExpired() + }) +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsertOne) SetSchedulable(v bool) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -1904,6 +2012,41 @@ func (u *AccountUpsertBulk) ClearLastUsedAt() *AccountUpsertBulk { }) } +// SetExpiresAt sets the "expires_at" field. +func (u *AccountUpsertBulk) SetExpiresAt(v time.Time) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateExpiresAt() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *AccountUpsertBulk) ClearExpiresAt() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearExpiresAt() + }) +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (u *AccountUpsertBulk) SetAutoPauseOnExpired(v bool) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetAutoPauseOnExpired(v) + }) +} + +// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateAutoPauseOnExpired() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateAutoPauseOnExpired() + }) +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsertBulk) SetSchedulable(v bool) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index e329abcd..dcc3212d 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -247,6 +247,40 @@ func (_u *AccountUpdate) ClearLastUsedAt() *AccountUpdate { return _u } +// SetExpiresAt sets the "expires_at" field. +func (_u *AccountUpdate) SetExpiresAt(v time.Time) *AccountUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableExpiresAt(v *time.Time) *AccountUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *AccountUpdate) ClearExpiresAt() *AccountUpdate { + _u.mutation.ClearExpiresAt() + return _u +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (_u *AccountUpdate) SetAutoPauseOnExpired(v bool) *AccountUpdate { + _u.mutation.SetAutoPauseOnExpired(v) + return _u +} + +// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdate { + if v != nil { + _u.SetAutoPauseOnExpired(*v) + } + return _u +} + // SetSchedulable sets the "schedulable" field. func (_u *AccountUpdate) SetSchedulable(v bool) *AccountUpdate { _u.mutation.SetSchedulable(v) @@ -610,6 +644,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.LastUsedAtCleared() { _spec.ClearField(account.FieldLastUsedAt, field.TypeTime) } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(account.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(account.FieldExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.AutoPauseOnExpired(); ok { + _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) + } if value, ok := _u.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) } @@ -1016,6 +1059,40 @@ func (_u *AccountUpdateOne) ClearLastUsedAt() *AccountUpdateOne { return _u } +// SetExpiresAt sets the "expires_at" field. +func (_u *AccountUpdateOne) SetExpiresAt(v time.Time) *AccountUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableExpiresAt(v *time.Time) *AccountUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *AccountUpdateOne) ClearExpiresAt() *AccountUpdateOne { + _u.mutation.ClearExpiresAt() + return _u +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (_u *AccountUpdateOne) SetAutoPauseOnExpired(v bool) *AccountUpdateOne { + _u.mutation.SetAutoPauseOnExpired(v) + return _u +} + +// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdateOne { + if v != nil { + _u.SetAutoPauseOnExpired(*v) + } + return _u +} + // SetSchedulable sets the "schedulable" field. func (_u *AccountUpdateOne) SetSchedulable(v bool) *AccountUpdateOne { _u.mutation.SetSchedulable(v) @@ -1409,6 +1486,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if _u.mutation.LastUsedAtCleared() { _spec.ClearField(account.FieldLastUsedAt, field.TypeTime) } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(account.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(account.FieldExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.AutoPauseOnExpired(); ok { + _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) + } if value, ok := _u.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d0e43bf3..4fd96f87 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -80,6 +80,8 @@ var ( {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "auto_pause_on_expired", Type: field.TypeBool, Default: true}, {Name: "schedulable", Type: field.TypeBool, Default: true}, {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -97,7 +99,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[22]}, + Columns: []*schema.Column{AccountsColumns[24]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -121,7 +123,7 @@ var ( { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[22]}, + Columns: []*schema.Column{AccountsColumns[24]}, }, { Name: "account_priority", @@ -136,22 +138,22 @@ var ( { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[17]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[16]}, + Columns: []*schema.Column{AccountsColumns[18]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[17]}, + Columns: []*schema.Column{AccountsColumns[19]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[20]}, }, { Name: "account_deleted_at", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 91883413..ccda9b17 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -1006,6 +1006,8 @@ type AccountMutation struct { status *string error_message *string last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool schedulable *bool rate_limited_at *time.Time rate_limit_reset_at *time.Time @@ -1770,6 +1772,91 @@ func (m *AccountMutation) ResetLastUsedAt() { delete(m.clearedFields, account.FieldLastUsedAt) } +// SetExpiresAt sets the "expires_at" field. +func (m *AccountMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *AccountMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *AccountMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[account.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *AccountMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[account.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *AccountMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, account.FieldExpiresAt) +} + +// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field. +func (m *AccountMutation) SetAutoPauseOnExpired(b bool) { + m.auto_pause_on_expired = &b +} + +// AutoPauseOnExpired returns the value of the "auto_pause_on_expired" field in the mutation. +func (m *AccountMutation) AutoPauseOnExpired() (r bool, exists bool) { + v := m.auto_pause_on_expired + if v == nil { + return + } + return *v, true +} + +// OldAutoPauseOnExpired returns the old "auto_pause_on_expired" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldAutoPauseOnExpired(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAutoPauseOnExpired is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAutoPauseOnExpired requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAutoPauseOnExpired: %w", err) + } + return oldValue.AutoPauseOnExpired, nil +} + +// ResetAutoPauseOnExpired resets all changes to the "auto_pause_on_expired" field. +func (m *AccountMutation) ResetAutoPauseOnExpired() { + m.auto_pause_on_expired = nil +} + // SetSchedulable sets the "schedulable" field. func (m *AccountMutation) SetSchedulable(b bool) { m.schedulable = &b @@ -2269,7 +2356,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 22) + fields := make([]string, 0, 24) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -2315,6 +2402,12 @@ func (m *AccountMutation) Fields() []string { if m.last_used_at != nil { fields = append(fields, account.FieldLastUsedAt) } + if m.expires_at != nil { + fields = append(fields, account.FieldExpiresAt) + } + if m.auto_pause_on_expired != nil { + fields = append(fields, account.FieldAutoPauseOnExpired) + } if m.schedulable != nil { fields = append(fields, account.FieldSchedulable) } @@ -2374,6 +2467,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.ErrorMessage() case account.FieldLastUsedAt: return m.LastUsedAt() + case account.FieldExpiresAt: + return m.ExpiresAt() + case account.FieldAutoPauseOnExpired: + return m.AutoPauseOnExpired() case account.FieldSchedulable: return m.Schedulable() case account.FieldRateLimitedAt: @@ -2427,6 +2524,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldErrorMessage(ctx) case account.FieldLastUsedAt: return m.OldLastUsedAt(ctx) + case account.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case account.FieldAutoPauseOnExpired: + return m.OldAutoPauseOnExpired(ctx) case account.FieldSchedulable: return m.OldSchedulable(ctx) case account.FieldRateLimitedAt: @@ -2555,6 +2656,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetLastUsedAt(v) return nil + case account.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case account.FieldAutoPauseOnExpired: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAutoPauseOnExpired(v) + return nil case account.FieldSchedulable: v, ok := value.(bool) if !ok { @@ -2676,6 +2791,9 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldLastUsedAt) { fields = append(fields, account.FieldLastUsedAt) } + if m.FieldCleared(account.FieldExpiresAt) { + fields = append(fields, account.FieldExpiresAt) + } if m.FieldCleared(account.FieldRateLimitedAt) { fields = append(fields, account.FieldRateLimitedAt) } @@ -2723,6 +2841,9 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldLastUsedAt: m.ClearLastUsedAt() return nil + case account.FieldExpiresAt: + m.ClearExpiresAt() + return nil case account.FieldRateLimitedAt: m.ClearRateLimitedAt() return nil @@ -2794,6 +2915,12 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldLastUsedAt: m.ResetLastUsedAt() return nil + case account.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case account.FieldAutoPauseOnExpired: + m.ResetAutoPauseOnExpired() + return nil case account.FieldSchedulable: m.ResetSchedulable() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index e2cb6a3c..5fe8d905 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -181,12 +181,16 @@ func init() { account.DefaultStatus = accountDescStatus.Default.(string) // account.StatusValidator is a validator for the "status" field. It is called by the builders before save. account.StatusValidator = accountDescStatus.Validators[0].(func(string) error) + // accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field. + accountDescAutoPauseOnExpired := accountFields[13].Descriptor() + // account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field. + account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool) // accountDescSchedulable is the schema descriptor for schedulable field. - accountDescSchedulable := accountFields[12].Descriptor() + accountDescSchedulable := accountFields[14].Descriptor() // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[18].Descriptor() + accountDescSessionWindowStatus := accountFields[20].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 55c75f28..ec192a97 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -118,6 +118,16 @@ func (Account) Fields() []ent.Field { Optional(). Nillable(). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // expires_at: 账户过期时间(可为空) + field.Time("expires_at"). + Optional(). + Nillable(). + Comment("Account expiration time (NULL means no expiration)."). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // auto_pause_on_expired: 过期后自动暂停调度 + field.Bool("auto_pause_on_expired"). + Default(true). + Comment("Auto pause scheduling when account expires."), // ========== 调度和速率限制相关字段 ========== // 这些字段在 migrations/005_schema_parity.sql 中添加 diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 4303e020..da9f6990 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -85,6 +85,8 @@ type CreateAccountRequest struct { Concurrency int `json:"concurrency"` Priority int `json:"priority"` GroupIDs []int64 `json:"group_ids"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } @@ -101,6 +103,8 @@ type UpdateAccountRequest struct { Priority *int `json:"priority"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` GroupIDs *[]int64 `json:"group_ids"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } @@ -204,6 +208,8 @@ func (h *AccountHandler) Create(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, SkipMixedChannelCheck: skipCheck, }) if err != nil { @@ -261,6 +267,8 @@ func (h *AccountHandler) Update(c *gin.Context) { Priority: req.Priority, // 指针类型,nil 表示未提供 Status: req.Status, GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, SkipMixedChannelCheck: skipCheck, }) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d937ed77..764a4132 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -1,7 +1,11 @@ // Package dto provides data transfer objects for HTTP handlers. package dto -import "github.com/Wei-Shaw/sub2api/internal/service" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) func UserFromServiceShallow(u *service.User) *User { if u == nil { @@ -120,6 +124,8 @@ func AccountFromServiceShallow(a *service.Account) *Account { Status: a.Status, ErrorMessage: a.ErrorMessage, LastUsedAt: a.LastUsedAt, + ExpiresAt: timeToUnixSeconds(a.ExpiresAt), + AutoPauseOnExpired: a.AutoPauseOnExpired, CreatedAt: a.CreatedAt, UpdatedAt: a.UpdatedAt, Schedulable: a.Schedulable, @@ -157,6 +163,14 @@ func AccountFromService(a *service.Account) *Account { return out } +func timeToUnixSeconds(value *time.Time) *int64 { + if value == nil { + return nil + } + ts := value.Unix() + return &ts +} + func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup { if ag == nil { return nil diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index a8761f81..a11662fe 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -60,21 +60,23 @@ type Group struct { } type Account struct { - ID int64 `json:"id"` - Name string `json:"name"` - Notes *string `json:"notes"` - Platform string `json:"platform"` - Type string `json:"type"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency int `json:"concurrency"` - Priority int `json:"priority"` - Status string `json:"status"` - ErrorMessage string `json:"error_message"` - LastUsedAt *time.Time `json:"last_used_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Notes *string `json:"notes"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + Status string `json:"status"` + ErrorMessage string `json:"error_message"` + LastUsedAt *time.Time `json:"last_used_at"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired bool `json:"auto_pause_on_expired"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` Schedulable bool `json:"schedulable"` diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 1073ae0d..83f02608 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account SetPriority(account.Priority). SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). - SetSchedulable(account.Schedulable) + SetSchedulable(account.Schedulable). + SetAutoPauseOnExpired(account.AutoPauseOnExpired) if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account if account.LastUsedAt != nil { builder.SetLastUsedAt(*account.LastUsedAt) } + if account.ExpiresAt != nil { + builder.SetExpiresAt(*account.ExpiresAt) + } if account.RateLimitedAt != nil { builder.SetRateLimitedAt(*account.RateLimitedAt) } @@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account SetPriority(account.Priority). SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). - SetSchedulable(account.Schedulable) + SetSchedulable(account.Schedulable). + SetAutoPauseOnExpired(account.AutoPauseOnExpired) if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } else { builder.ClearLastUsedAt() } + if account.ExpiresAt != nil { + builder.SetExpiresAt(*account.ExpiresAt) + } else { + builder.ClearExpiresAt() + } if account.RateLimitedAt != nil { builder.SetRateLimitedAt(*account.RateLimitedAt) } else { @@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), + notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). @@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), + notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). @@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), + notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). @@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu return err } +func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE accounts + SET schedulable = FALSE, + updated_at = NOW() + WHERE deleted_at IS NULL + AND schedulable = TRUE + AND auto_pause_on_expired = TRUE + AND expires_at IS NOT NULL + AND expires_at <= $1 + `, now) + if err != nil { + return 0, err + } + rows, err := result.RowsAffected() + if err != nil { + return 0, err + } + return rows, nil +} + func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { if len(updates) == 0 { return nil @@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in preds = append(preds, dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), + notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ) @@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account { }) } +func notExpiredPredicate(now time.Time) dbpredicate.Account { + return dbaccount.Or( + dbaccount.ExpiresAtIsNil(), + dbaccount.ExpiresAtGT(now), + dbaccount.AutoPauseOnExpiredEQ(false), + ) +} + func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { out := make(map[int64]tempUnschedSnapshot) if len(accountIDs) == 0 { @@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account { Status: m.Status, ErrorMessage: derefString(m.ErrorMessage), LastUsedAt: m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, Schedulable: m.Schedulable, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index eb765988..cfce9bfa 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -9,21 +9,23 @@ import ( ) type Account struct { - ID int64 - Name string - Notes *string - Platform string - Type string - Credentials map[string]any - Extra map[string]any - ProxyID *int64 - Concurrency int - Priority int - Status string - ErrorMessage string - LastUsedAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Name string + Notes *string + Platform string + Type string + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency int + Priority int + Status string + ErrorMessage string + LastUsedAt *time.Time + ExpiresAt *time.Time + AutoPauseOnExpired bool + CreatedAt time.Time + UpdatedAt time.Time Schedulable bool @@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool { return false } now := time.Now() + if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) { + return false + } if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) { return false } diff --git a/backend/internal/service/account_expiry_service.go b/backend/internal/service/account_expiry_service.go new file mode 100644 index 00000000..eaada11c --- /dev/null +++ b/backend/internal/service/account_expiry_service.go @@ -0,0 +1,71 @@ +package service + +import ( + "context" + "log" + "sync" + "time" +) + +// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled. +type AccountExpiryService struct { + accountRepo AccountRepository + interval time.Duration + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup +} + +func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService { + return &AccountExpiryService{ + accountRepo: accountRepo, + interval: interval, + stopCh: make(chan struct{}), + } +} + +func (s *AccountExpiryService) Start() { + if s == nil || s.accountRepo == nil || s.interval <= 0 { + return + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + s.runOnce() + for { + select { + case <-ticker.C: + s.runOnce() + case <-s.stopCh: + return + } + } + }() +} + +func (s *AccountExpiryService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) + s.wg.Wait() +} + +func (s *AccountExpiryService) runOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now()) + if err != nil { + log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err) + return + } + if updated > 0 { + log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index c84cb5e9..e1b93fcb 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -38,6 +38,7 @@ type AccountRepository interface { BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error SetError(ctx context.Context, id int64, errorMsg string) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error + AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error ListSchedulable(ctx context.Context) ([]Account, error) @@ -71,29 +72,33 @@ type AccountBulkUpdate struct { // CreateAccountRequest 创建账号请求 type CreateAccountRequest struct { - Name string `json:"name"` - Notes *string `json:"notes"` - Platform string `json:"platform"` - Type string `json:"type"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency int `json:"concurrency"` - Priority int `json:"priority"` - GroupIDs []int64 `json:"group_ids"` + Name string `json:"name"` + Notes *string `json:"notes"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + GroupIDs []int64 `json:"group_ids"` + ExpiresAt *time.Time `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` } // UpdateAccountRequest 更新账号请求 type UpdateAccountRequest struct { - Name *string `json:"name"` - Notes *string `json:"notes"` - Credentials *map[string]any `json:"credentials"` - Extra *map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency *int `json:"concurrency"` - Priority *int `json:"priority"` - Status *string `json:"status"` - GroupIDs *[]int64 `json:"group_ids"` + Name *string `json:"name"` + Notes *string `json:"notes"` + Credentials *map[string]any `json:"credentials"` + Extra *map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency *int `json:"concurrency"` + Priority *int `json:"priority"` + Status *string `json:"status"` + GroupIDs *[]int64 `json:"group_ids"` + ExpiresAt *time.Time `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` } // AccountService 账号管理服务 @@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( Concurrency: req.Concurrency, Priority: req.Priority, Status: StatusActive, + ExpiresAt: req.ExpiresAt, + } + if req.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *req.AutoPauseOnExpired + } else { + account.AutoPauseOnExpired = true } if err := s.accountRepo.Create(ctx, account); err != nil { @@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount if req.Status != nil { account.Status = *req.Status } + if req.ExpiresAt != nil { + account.ExpiresAt = req.ExpiresAt + } + if req.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *req.AutoPauseOnExpired + } // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 974a515c..edad8672 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula panic("unexpected SetSchedulable call") } +func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + panic("unexpected AutoPauseExpiredAccounts call") +} + func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { panic("unexpected BindGroups call") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 0eacfd16..80acd440 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -122,16 +122,18 @@ type UpdateGroupInput struct { } type CreateAccountInput struct { - Name string - Notes *string - Platform string - Type string - Credentials map[string]any - Extra map[string]any - ProxyID *int64 - Concurrency int - Priority int - GroupIDs []int64 + Name string + Notes *string + Platform string + Type string + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency int + Priority int + GroupIDs []int64 + ExpiresAt *int64 + AutoPauseOnExpired *bool // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. SkipMixedChannelCheck bool @@ -148,6 +150,8 @@ type UpdateAccountInput struct { Priority *int // 使用指针区分"未提供"和"设置为0" Status string GroupIDs *[]int64 + ExpiresAt *int64 + AutoPauseOnExpired *bool SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) } @@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Status: StatusActive, Schedulable: true, } + if input.ExpiresAt != nil && *input.ExpiresAt > 0 { + expiresAt := time.Unix(*input.ExpiresAt, 0) + account.ExpiresAt = &expiresAt + } + if input.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *input.AutoPauseOnExpired + } else { + account.AutoPauseOnExpired = true + } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } @@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if input.Status != "" { account.Status = input.Status } + if input.ExpiresAt != nil { + if *input.ExpiresAt <= 0 { + account.ExpiresAt = nil + } else { + expiresAt := time.Unix(*input.ExpiresAt, 0) + account.ExpiresAt = &expiresAt + } + } + if input.AutoPauseOnExpired != nil { + account.AutoPauseOnExpired = *input.AutoPauseOnExpired + } // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 6c8198b2..47279581 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, err func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } +func (m *mockAccountRepoForPlatform) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, nil +} func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { return nil } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 0a434835..5070b510 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, error func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } +func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, nil +} func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { return nil } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d4b984d6..cb73409b 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -47,6 +47,13 @@ func ProvideTokenRefreshService( return svc } +// ProvideAccountExpiryService creates and starts AccountExpiryService. +func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { + svc := NewAccountExpiryService(accountRepo, time.Minute) + svc.Start() + return svc +} + // ProvideTimingWheelService creates and starts TimingWheelService func ProvideTimingWheelService() *TimingWheelService { svc := NewTimingWheelService() @@ -110,6 +117,7 @@ var ProviderSet = wire.NewSet( NewCRSSyncService, ProvideUpdateService, ProvideTokenRefreshService, + ProvideAccountExpiryService, ProvideTimingWheelService, ProvideDeferredService, NewAntigravityQuotaFetcher, diff --git a/backend/migrations/030_add_account_expires_at.sql b/backend/migrations/030_add_account_expires_at.sql new file mode 100644 index 00000000..905220e9 --- /dev/null +++ b/backend/migrations/030_add_account_expires_at.sql @@ -0,0 +1,10 @@ +-- Add expires_at for account expiration configuration +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS expires_at timestamptz; +-- Document expires_at meaning +COMMENT ON COLUMN accounts.expires_at IS 'Account expiration time (NULL means no expiration).'; +-- Add auto_pause_on_expired for account expiration scheduling control +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS auto_pause_on_expired boolean NOT NULL DEFAULT true; +-- Document auto_pause_on_expired meaning +COMMENT ON COLUMN accounts.auto_pause_on_expired IS 'Auto pause scheduling when account expires.'; +-- Ensure existing accounts are enabled by default +UPDATE accounts SET auto_pause_on_expired = true; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 0091873c..e90bec6c 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1012,7 +1012,7 @@
-
+
@@ -1213,46 +1213,81 @@

{{ t('admin.accounts.priorityHint') }}

+
+ + +

{{ t('admin.accounts.expiresAtHint') }}

+
- -
- -
- - ? - - -
- {{ t('admin.accounts.mixedSchedulingTooltip') }} -
+
+
+
+ +

+ {{ t('admin.accounts.autoPauseOnExpiredDesc') }} +

+
- - +
+ +
+ +
+ + ? + + +
+ {{ t('admin.accounts.mixedSchedulingTooltip') }} +
+
+
+
+ + + +
@@ -1598,6 +1633,7 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' // Type for exposed OAuthAuthorizationFlow component @@ -1713,6 +1749,7 @@ const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) +const autoPauseOnExpired = ref(true) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) @@ -1795,7 +1832,8 @@ const form = reactive({ proxy_id: null as number | null, concurrency: 10, priority: 1, - group_ids: [] as number[] + group_ids: [] as number[], + expires_at: null as number | null }) // Helper to check if current type needs OAuth flow @@ -1805,6 +1843,13 @@ const isManualInputMethod = computed(() => { return oauthFlowRef.value?.inputMethod === 'manual' }) +const expiresAtInput = computed({ + get: () => formatDateTimeLocal(form.expires_at), + set: (value: string) => { + form.expires_at = parseDateTimeLocal(value) + } +}) + const canExchangeCode = computed(() => { const authCode = oauthFlowRef.value?.authCode || '' if (form.platform === 'openai') { @@ -2055,6 +2100,7 @@ const resetForm = () => { form.concurrency = 10 form.priority = 1 form.group_ids = [] + form.expires_at = null accountCategory.value = 'oauth-based' addMethod.value = 'oauth' apiKeyBaseUrl.value = 'https://api.anthropic.com' @@ -2066,6 +2112,7 @@ const resetForm = () => { selectedErrorCodes.value = [] customErrorCodeInput.value = null interceptWarmupRequests.value = false + autoPauseOnExpired.value = true tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -2133,7 +2180,6 @@ const handleSubmit = async () => { if (interceptWarmupRequests.value) { credentials.intercept_warmup_requests = true } - if (!applyTempUnschedConfig(credentials)) { return } @@ -2144,7 +2190,8 @@ const handleSubmit = async () => { try { await adminAPI.accounts.create({ ...form, - group_ids: form.group_ids + group_ids: form.group_ids, + auto_pause_on_expired: autoPauseOnExpired.value }) appStore.showSuccess(t('admin.accounts.accountCreated')) emit('created') @@ -2182,6 +2229,9 @@ const handleGenerateUrl = async () => { } } +const formatDateTimeLocal = formatDateTimeLocalInput +const parseDateTimeLocal = parseDateTimeLocalInput + // Create account and handle success/failure const createAccountAndFinish = async ( platform: AccountPlatform, @@ -2202,7 +2252,9 @@ const createAccountAndFinish = async ( proxy_id: form.proxy_id, concurrency: form.concurrency, priority: form.priority, - group_ids: form.group_ids + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value }) appStore.showSuccess(t('admin.accounts.accountCreated')) emit('created') @@ -2416,7 +2468,8 @@ const handleCookieAuth = async (sessionKey: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, - priority: form.priority + priority: form.priority, + auto_pause_on_expired: autoPauseOnExpired.value }) successCount++ diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 3f47ee31..3b36cfbf 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -365,7 +365,7 @@
-
+
@@ -565,39 +565,74 @@ />
- -
- - +

{{ t('admin.accounts.expiresAtHint') }}

- -
- -
- +
+
+ +

+ {{ t('admin.accounts.autoPauseOnExpiredDesc') }} +

+
+ +
+
+ +
+
+ + + + {{ t('admin.accounts.mixedScheduling') }} + + +
+ + ? + +
+ class="pointer-events-none absolute left-0 top-full z-[100] mt-1.5 w-72 rounded bg-gray-900 px-3 py-2 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700" + > + {{ t('admin.accounts.mixedSchedulingTooltip') }} +
+
@@ -666,6 +701,7 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { getPresetMappingsByPlatform, commonErrorCodes, @@ -721,6 +757,7 @@ const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) +const autoPauseOnExpired = ref(false) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) @@ -771,7 +808,8 @@ const form = reactive({ concurrency: 1, priority: 1, status: 'active' as 'active' | 'inactive', - group_ids: [] as number[] + group_ids: [] as number[], + expires_at: null as number | null }) const statusOptions = computed(() => [ @@ -779,6 +817,13 @@ const statusOptions = computed(() => [ { value: 'inactive', label: t('common.inactive') } ]) +const expiresAtInput = computed({ + get: () => formatDateTimeLocal(form.expires_at), + set: (value: string) => { + form.expires_at = parseDateTimeLocal(value) + } +}) + // Watchers watch( () => props.account, @@ -791,10 +836,12 @@ watch( form.priority = newAccount.priority form.status = newAccount.status as 'active' | 'inactive' form.group_ids = newAccount.group_ids || [] + form.expires_at = newAccount.expires_at ?? null // Load intercept warmup requests setting (applies to all account types) const credentials = newAccount.credentials as Record | undefined interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true + autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true // Load mixed scheduling setting (only for antigravity accounts) const extra = newAccount.extra as Record | undefined @@ -1042,6 +1089,9 @@ function toPositiveNumber(value: unknown) { return Math.trunc(num) } +const formatDateTimeLocal = formatDateTimeLocalInput +const parseDateTimeLocal = parseDateTimeLocalInput + // Methods const handleClose = () => { emit('close') @@ -1057,6 +1107,10 @@ const handleSubmit = async () => { if (updatePayload.proxy_id === null) { updatePayload.proxy_id = 0 } + if (form.expires_at === null) { + updatePayload.expires_at = 0 + } + updatePayload.auto_pause_on_expired = autoPauseOnExpired.value // For apikey type, handle credentials update if (props.account.type === 'apikey') { @@ -1097,7 +1151,6 @@ const handleSubmit = async () => { if (interceptWarmupRequests.value) { newCredentials.intercept_warmup_requests = true } - if (!applyTempUnschedConfig(newCredentials)) { submitting.value = false return @@ -1114,7 +1167,6 @@ const handleSubmit = async () => { } else { delete newCredentials.intercept_warmup_requests } - if (!applyTempUnschedConfig(newCredentials)) { submitting.value = false return diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 4634d8b6..97321ca6 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1011,6 +1011,7 @@ export default { groups: 'Groups', usageWindows: 'Usage Windows', lastUsed: 'Last Used', + expiresAt: 'Expires At', actions: 'Actions' }, tempUnschedulable: { @@ -1152,11 +1153,16 @@ export default { interceptWarmupRequests: 'Intercept Warmup Requests', interceptWarmupRequestsDesc: 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens', + autoPauseOnExpired: 'Auto Pause On Expired', + autoPauseOnExpiredDesc: 'When enabled, the account will auto pause scheduling after it expires', + expired: 'Expired', proxy: 'Proxy', noProxy: 'No Proxy', concurrency: 'Concurrency', priority: 'Priority', priorityHint: 'Higher priority accounts are used first', + expiresAt: 'Expires At', + expiresAtHint: 'Leave empty for no expiration', higherPriorityFirst: 'Higher value means higher priority', mixedScheduling: 'Use in /v1/messages', mixedSchedulingHint: 'Enable to participate in Anthropic/Gemini group scheduling', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 7e326bab..3f0e2c4f 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1061,6 +1061,7 @@ export default { groups: '分组', usageWindows: '用量窗口', lastUsed: '最近使用', + expiresAt: '过期时间', actions: '操作' }, clearRateLimit: '清除速率限制', @@ -1286,11 +1287,16 @@ export default { errorCodeExists: '该错误码已被选中', interceptWarmupRequests: '拦截预热请求', interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token', + autoPauseOnExpired: '过期自动暂停调度', + autoPauseOnExpiredDesc: '启用后,账号过期将自动暂停调度', + expired: '已过期', proxy: '代理', noProxy: '无代理', concurrency: '并发数', priority: '优先级', priorityHint: '优先级越高的账号优先使用', + expiresAt: '过期时间', + expiresAtHint: '留空表示不过期', higherPriorityFirst: '数值越高优先级越高', mixedScheduling: '在 /v1/messages 中使用', mixedSchedulingHint: '启用后可参与 Anthropic/Gemini 分组的调度', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 98368b0e..b16c66ef 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -401,6 +401,8 @@ export interface Account { status: 'active' | 'inactive' | 'error' error_message: string | null last_used_at: string | null + expires_at: number | null + auto_pause_on_expired: boolean created_at: string updated_at: string proxy?: Proxy @@ -491,6 +493,8 @@ export interface CreateAccountRequest { concurrency?: number priority?: number group_ids?: number[] + expires_at?: number | null + auto_pause_on_expired?: boolean confirm_mixed_channel_risk?: boolean } @@ -506,6 +510,8 @@ export interface UpdateAccountRequest { schedulable?: boolean status?: 'active' | 'inactive' group_ids?: number[] + expires_at?: number | null + auto_pause_on_expired?: boolean confirm_mixed_channel_risk?: boolean } diff --git a/frontend/src/utils/format.ts b/frontend/src/utils/format.ts index 2dc8da4e..bdc68660 100644 --- a/frontend/src/utils/format.ts +++ b/frontend/src/utils/format.ts @@ -96,6 +96,7 @@ export function formatBytes(bytes: number, decimals: number = 2): string { * 格式化日期 * @param date 日期字符串或 Date 对象 * @param options Intl.DateTimeFormatOptions + * @param localeOverride 可选 locale 覆盖 * @returns 格式化后的日期字符串 */ export function formatDate( @@ -108,14 +109,15 @@ export function formatDate( minute: '2-digit', second: '2-digit', hour12: false - } + }, + localeOverride?: string ): string { if (!date) return '' const d = new Date(date) if (isNaN(d.getTime())) return '' - const locale = getLocale() + const locale = localeOverride ?? getLocale() return new Intl.DateTimeFormat(locale, options).format(d) } @@ -135,10 +137,41 @@ export function formatDateOnly(date: string | Date | null | undefined): string { /** * 格式化日期时间(完整格式) * @param date 日期字符串或 Date 对象 + * @param options Intl.DateTimeFormatOptions + * @param localeOverride 可选 locale 覆盖 * @returns 格式化后的日期时间字符串 */ -export function formatDateTime(date: string | Date | null | undefined): string { - return formatDate(date) +export function formatDateTime( + date: string | Date | null | undefined, + options?: Intl.DateTimeFormatOptions, + localeOverride?: string +): string { + return formatDate(date, options, localeOverride) +} + +/** + * 格式化为 datetime-local 控件值(YYYY-MM-DDTHH:mm,使用本地时间) + */ +export function formatDateTimeLocalInput(timestampSeconds: number | null): string { + if (!timestampSeconds) return '' + const date = new Date(timestampSeconds * 1000) + if (isNaN(date.getTime())) return '' + const year = date.getFullYear() + const month = String(date.getMonth() + 1).padStart(2, '0') + const day = String(date.getDate()).padStart(2, '0') + const hours = String(date.getHours()).padStart(2, '0') + const minutes = String(date.getMinutes()).padStart(2, '0') + return `${year}-${month}-${day}T${hours}:${minutes}` +} + +/** + * 解析 datetime-local 控件值为时间戳(秒,使用本地时间) + */ +export function parseDateTimeLocalInput(value: string): number | null { + if (!value) return null + const date = new Date(value) + if (isNaN(date.getTime())) return null + return Math.floor(date.getTime() / 1000) } /** diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index c95b89f3..0ca22a76 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -70,6 +70,25 @@ + - + + @@ -480,7 +485,8 @@ const columns = computed(() => [ { key: 'billing_type', label: t('usage.billingType'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'duration', label: t('usage.duration'), sortable: false }, - { key: 'created_at', label: t('usage.time'), sortable: true } + { key: 'created_at', label: t('usage.time'), sortable: true }, + { key: 'user_agent', label: t('usage.userAgent'), sortable: false } ]) const usageLogs = ref([]) @@ -545,6 +551,19 @@ const formatDuration = (ms: number): string => { return `${(ms / 1000).toFixed(2)}s` } +const formatUserAgent = (ua: string): string => { + // 提取主要客户端标识 + if (ua.includes('claude-cli')) return ua.match(/claude-cli\/[\d.]+/)?.[0] || 'Claude CLI' + if (ua.includes('Cursor')) return 'Cursor' + if (ua.includes('VSCode') || ua.includes('vscode')) return 'VS Code' + if (ua.includes('Continue')) return 'Continue' + if (ua.includes('Cline')) return 'Cline' + if (ua.includes('OpenAI')) return 'OpenAI SDK' + if (ua.includes('anthropic')) return 'Anthropic SDK' + // 截断过长的 UA + return ua.length > 30 ? ua.substring(0, 30) + '...' : ua +} + const formatTokens = (value: number): string => { if (value >= 1_000_000_000) { return `${(value / 1_000_000_000).toFixed(2)}B` From eb198e5969a02958fb3b2a4ee761c122190d8103 Mon Sep 17 00:00:00 2001 From: Edric Li Date: Thu, 8 Jan 2026 21:20:12 +0800 Subject: [PATCH 22/46] feat(proxies): add account count column to proxy list Display the number of accounts bound to each proxy in the admin proxy management page, similar to the groups list view. --- .../internal/handler/admin/proxy_handler.go | 6 +-- backend/internal/repository/proxy_repo.go | 49 +++++++++++++++++++ backend/internal/service/admin_service.go | 10 ++++ backend/internal/service/proxy_service.go | 1 + frontend/src/i18n/locales/en.ts | 1 + frontend/src/i18n/locales/zh.ts | 1 + frontend/src/views/admin/ProxiesView.vue | 9 ++++ 7 files changed, 74 insertions(+), 3 deletions(-) diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 99557f9a..4fabd8ec 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -52,15 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) { status := c.Query("status") search := c.Query("search") - proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) + proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) if err != nil { response.ErrorFrom(c, err) return } - out := make([]dto.Proxy, 0, len(proxies)) + out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) } response.Paginated(c, out, total, page, pageSize) } diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index c24b2e2c..622b0aeb 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination return outProxies, paginationResultFromTotal(int64(total), params), nil } +// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy +func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { + q := r.client.Proxy.Query() + if protocol != "" { + q = q.Where(proxy.ProtocolEQ(protocol)) + } + if status != "" { + q = q.Where(proxy.StatusEQ(status)) + } + if search != "" { + q = q.Where(proxy.NameContainsFold(search)) + } + + total, err := q.Count(ctx) + if err != nil { + return nil, nil, err + } + + proxies, err := q. + Offset(params.Offset()). + Limit(params.Limit()). + Order(dbent.Desc(proxy.FieldID)). + All(ctx) + if err != nil { + return nil, nil, err + } + + // Get account counts + counts, err := r.GetAccountCountsForProxies(ctx) + if err != nil { + return nil, nil, err + } + + // Build result with account counts + result := make([]service.ProxyWithAccountCount, 0, len(proxies)) + for i := range proxies { + proxyOut := proxyEntityToService(proxies[i]) + if proxyOut == nil { + continue + } + result = append(result, service.ProxyWithAccountCount{ + Proxy: *proxyOut, + AccountCount: counts[proxyOut.ID], + }) + } + + return result, paginationResultFromTotal(int64(total), params), nil +} + func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { proxies, err := r.client.Proxy.Query(). Where(proxy.StatusEQ(service.StatusActive)). diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 80acd440..0f2cf998 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -47,6 +47,7 @@ type AdminService interface { // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) + ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetProxy(ctx context.Context, id int64) (*Proxy, error) @@ -950,6 +951,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, return proxies, result.Total, nil } +func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) + if err != nil { + return nil, 0, err + } + return proxies, result.Total, nil +} + func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { return s.proxyRepo.ListActive(ctx) } diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index 044f9ffc..58408d04 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -20,6 +20,7 @@ type ProxyRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) + ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Proxy, error) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1a24ea95..f28048e2 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1561,6 +1561,7 @@ export default { protocol: 'Protocol', address: 'Address', status: 'Status', + accounts: 'Accounts', actions: 'Actions' }, testConnection: 'Test Connection', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2a58e74b..a042c1dc 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1647,6 +1647,7 @@ export default { protocol: '协议', address: '地址', status: '状态', + accounts: '账号数', actions: '操作', nameLabel: '名称', namePlaceholder: '请输入代理名称', diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index 6b173033..00b43ba2 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -85,6 +85,14 @@ + + diff --git a/frontend/src/views/admin/RedeemView.vue b/frontend/src/views/admin/RedeemView.vue index 3503cd16..50c55ba3 100644 --- a/frontend/src/views/admin/RedeemView.vue +++ b/frontend/src/views/admin/RedeemView.vue @@ -364,7 +364,7 @@ diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index d4bf555c..f1fc9e1f 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -943,6 +943,7 @@ const toggleBuiltInFilter = (key: string) => { visibleFilters.add(key) } saveFiltersToStorage() + pagination.page = 1 loadUsers() } @@ -957,6 +958,7 @@ const toggleAttributeFilter = (attr: UserAttributeDefinition) => { activeAttributeFilters[attr.id] = '' } saveFiltersToStorage() + pagination.page = 1 loadUsers() } @@ -1059,5 +1061,7 @@ onMounted(async () => { onUnmounted(() => { document.removeEventListener('click', handleClickOutside) + clearTimeout(searchTimeout) + abortController?.abort() }) From 514f5802b54b18aa0211d5d3ba713bd5a3febb6e Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:58:21 +0800 Subject: [PATCH 36/46] =?UTF-8?q?fix(fe):=20=E4=BF=AE=E5=A4=8D=E4=B8=AD?= =?UTF-8?q?=E4=BC=98=E5=85=88=E7=BA=A7=E8=A1=A8=E6=A0=BC=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复的问题: 1. **搜索和筛选防抖不同步**(AccountsView.vue) - 问题:筛选器使用 reload(立即),搜索使用 debouncedReload(300ms延迟) - 修复:统一使用 debouncedReload,避免多余的API调用 2. **useTableLoader 竞态条件**(useTableLoader.ts) - 问题:finally 块检查 signal.aborted 而不是 controller 实例 - 修复:检查 abortController === currentController 3. **改进错误处理**(UsersView.vue) - 添加详细错误消息:error.response?.data?.detail || error.message - 用户可以看到具体的错误原因而不是通用消息 4. **分页边界检查**(useTableLoader.ts, UsersView.vue) - 添加页码有效性检查:Math.max(1, Math.min(page, pagination.pages || 1)) - 防止分页越界导致显示空表 影响范围: - frontend/src/composables/useTableLoader.ts - frontend/src/views/admin/AccountsView.vue - frontend/src/views/admin/UsersView.vue 测试:✓ 前端构建测试通过 --- frontend/src/composables/useTableLoader.ts | 13 ++++++++----- frontend/src/views/admin/AccountsView.vue | 2 +- frontend/src/views/admin/UsersView.vue | 9 ++++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/frontend/src/composables/useTableLoader.ts b/frontend/src/composables/useTableLoader.ts index 01703ee1..5fb6c5e0 100644 --- a/frontend/src/composables/useTableLoader.ts +++ b/frontend/src/composables/useTableLoader.ts @@ -43,7 +43,8 @@ export function useTableLoader>(options: TableL if (abortController) { abortController.abort() } - abortController = new AbortController() + const currentController = new AbortController() + abortController = currentController loading.value = true try { @@ -51,9 +52,9 @@ export function useTableLoader>(options: TableL pagination.page, pagination.page_size, toRaw(params) as P, - { signal: abortController.signal } + { signal: currentController.signal } ) - + items.value = response.items || [] pagination.total = response.total || 0 pagination.pages = response.pages || 0 @@ -63,7 +64,7 @@ export function useTableLoader>(options: TableL throw error } } finally { - if (abortController && !abortController.signal.aborted) { + if (abortController === currentController) { loading.value = false } } @@ -77,7 +78,9 @@ export function useTableLoader>(options: TableL const debouncedReload = useDebounceFn(reload, debounceMs) const handlePageChange = (page: number) => { - pagination.page = page + // 确保页码在有效范围内 + const validPage = Math.max(1, Math.min(page, pagination.pages || 1)) + pagination.page = validPage load() } diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 560ae12b..27cd9c19 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -7,7 +7,7 @@ v-model:searchQuery="params.search" :filters="params" @update:filters="(newFilters) => Object.assign(params, newFilters)" - @change="reload" + @change="debouncedReload" @update:searchQuery="debouncedReload" /> { } } } - } catch (error) { + } catch (error: any) { const errorInfo = error as { name?: string; code?: string } if (errorInfo?.name === 'AbortError' || errorInfo?.name === 'CanceledError' || errorInfo?.code === 'ERR_CANCELED') { return } - appStore.showError(t('admin.users.failedToLoad')) + const message = error.response?.data?.detail || error.message || t('admin.users.failedToLoad') + appStore.showError(message) console.error('Error loading users:', error) } finally { if (abortController === currentAbortController) { @@ -917,7 +918,9 @@ const handleSearch = () => { } const handlePageChange = (page: number) => { - pagination.page = page + // 确保页码在有效范围内 + const validPage = Math.max(1, Math.min(page, pagination.pages || 1)) + pagination.page = validPage loadUsers() } From 152d0cdec65d81cf9fcf00e060e77b468044de9d Mon Sep 17 00:00:00 2001 From: admin Date: Fri, 9 Jan 2026 12:05:25 +0800 Subject: [PATCH 37/46] =?UTF-8?q?feat(auth):=20=E6=B7=BB=E5=8A=A0=20Linux?= =?UTF-8?q?=20DO=20Connect=20OAuth=20=E7=99=BB=E5=BD=95=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Linux DO OAuth 配置项和环境变量支持 - 实现 OAuth 授权流程和回调处理 - 前端添加 Linux DO 登录按钮和回调页面 - 支持通过 Linux DO 账号注册/登录 - 添加相关国际化文本 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- Linux DO Connect.md | 368 +++++++++++++ backend/cmd/server/VERSION | 2 +- backend/internal/config/config.go | 215 +++++++- backend/internal/config/config_test.go | 51 ++ .../internal/handler/auth_linuxdo_oauth.go | 517 ++++++++++++++++++ .../handler/auth_linuxdo_oauth_test.go | 74 +++ backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/setting_handler.go | 1 + backend/internal/server/routes/auth.go | 2 + backend/internal/service/auth_service.go | 131 +++++ .../service/auth_service_register_test.go | 8 + backend/internal/service/setting_service.go | 1 + backend/internal/service/settings_view.go | 1 + deploy/config.example.yaml | 25 + frontend/src/i18n/locales/en.ts | 9 + frontend/src/i18n/locales/zh.ts | 9 + frontend/src/router/index.ts | 9 + frontend/src/stores/app.ts | 35 +- frontend/src/stores/auth.ts | 22 + frontend/src/types/index.ts | 1 + .../src/views/auth/LinuxDoCallbackView.vue | 119 ++++ frontend/src/views/auth/LoginView.vue | 55 ++ frontend/src/views/auth/RegisterView.vue | 55 ++ 23 files changed, 1675 insertions(+), 36 deletions(-) create mode 100644 Linux DO Connect.md create mode 100644 backend/internal/handler/auth_linuxdo_oauth.go create mode 100644 backend/internal/handler/auth_linuxdo_oauth_test.go create mode 100644 frontend/src/views/auth/LinuxDoCallbackView.vue diff --git a/Linux DO Connect.md b/Linux DO Connect.md new file mode 100644 index 00000000..7ca1260f --- /dev/null +++ b/Linux DO Connect.md @@ -0,0 +1,368 @@ +# Linux DO Connect + +OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 + +目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 + +## 基本介绍 + +这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 + +- 可获取字段: + +| 参数 | 说明 | +| ----------------- | ------------------------------- | +| `id` | 用户唯一标识(不可变) | +| `username` | 论坛用户名 | +| `name` | 论坛用户昵称(可变) | +| `avatar_template` | 用户头像模板URL(支持多种尺寸) | +| `active` | 账号活跃状态 | +| `trust_level` | 信任等级(0-4) | +| `silenced` | 禁言状态 | +| `external_ids` | 外部ID关联信息 | +| `api_key` | API访问密钥 | + +通过这些信息,公益网站/接口可以实现: + +1. 基于 `id` 的服务频率限制 +2. 基于 `trust_level` 的服务额度分配 +3. 基于用户信息的滥用举报机制 + +## 相关端点 + +- Authorize 端点: `https://connect.linux.do/oauth2/authorize` +- Token 端点:`https://connect.linux.do/oauth2/token` +- 用户信息 端点:`https://connect.linux.do/api/user` + +## 申请使用 + +- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 + +![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75) + +- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 + +![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75) + +- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 + +![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75) + +## 接入 Linux Do + +JavaScript +```JavaScript +// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios +// npm install axios + +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +const axios = require('axios'); +const readline = require('readline'); + +// 配置信息(建议通过环境变量配置,避免使用硬编码) +const CLIENT_ID = '你的 Client ID'; +const CLIENT_SECRET = '你的 Client Secret'; +const REDIRECT_URI = '你的回调地址'; +const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +const USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 第一步:生成授权 URL +function getAuthUrl() { + const params = new URLSearchParams({ + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + response_type: 'code', + scope: 'user' + }); + + return `${AUTH_URL}?${params.toString()}`; +} + +// 第二步:获取 code 参数 +function getCode() { + return new Promise((resolve) => { + // 本例中使用终端输入来模拟流程,仅供本地测试 + // 请在实际应用中替换为真实的处理逻辑 + const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); + rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { + rl.close(); + resolve(answer.trim()); + }); + }); +} + +// 第三步:使用 code 参数获取访问令牌 +async function getAccessToken(code) { + try { + const form = new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: code, + redirect_uri: REDIRECT_URI, + grant_type: 'authorization_code' + }).toString(); + + const response = await axios.post(TOKEN_URL, form, { + // 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + }); + + return response.data; + } catch (error) { + console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 第四步:使用访问令牌获取用户信息 +async function getUserInfo(accessToken) { + try { + const response = await axios.get(USER_INFO_URL, { + headers: { + Authorization: `Bearer ${accessToken}` + } + }); + + return response.data; + } catch (error) { + console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 主流程 +async function main() { + // 1. 生成授权 URL,前端引导用户访问授权页 + const authUrl = getAuthUrl(); + console.log(`请访问此 URL 授权:${authUrl} +`); + + // 2. 用户授权后,从回调 URL 获取 code 参数 + const code = await getCode(); + + try { + // 3. 使用 code 参数获取访问令牌 + const tokenData = await getAccessToken(code); + const accessToken = tokenData.access_token; + + // 4. 使用访问令牌获取用户信息 + if (accessToken) { + const userInfo = await getUserInfo(accessToken); + console.log(` +获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); + } else { + console.log(` +获取访问令牌失败:${JSON.stringify(tokenData)}`); + } + } catch (error) { + console.error('发生错误:', error); + } +} +``` +Python +```python +# 安装第三方请求库,本例中使用 requests +# pip install requests + +# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +import requests +import json + +# 配置信息(建议通过环境变量配置,避免使用硬编码) +CLIENT_ID = '你的 Client ID' +CLIENT_SECRET = '你的 Client Secret' +REDIRECT_URI = '你的回调地址' +AUTH_URL = 'https://connect.linux.do/oauth2/authorize' +TOKEN_URL = 'https://connect.linux.do/oauth2/token' +USER_INFO_URL = 'https://connect.linux.do/api/user' + +# 第一步:生成授权 URL +def get_auth_url(): + params = { + 'client_id': CLIENT_ID, + 'redirect_uri': REDIRECT_URI, + 'response_type': 'code', + 'scope': 'user' + } + auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" + return auth_url + +# 第二步:获取 code 参数 +def get_code(): + # 本例中使用终端输入来模拟流程,仅供本地测试 + # 请在实际应用中替换为真实的处理逻辑 + return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() + +# 第三步:使用 code 参数获取访问令牌 +def get_access_token(code): + try: + data = { + 'client_id': CLIENT_ID, + 'client_secret': CLIENT_SECRET, + 'code': code, + 'redirect_uri': REDIRECT_URI, + 'grant_type': 'authorization_code' + } + # 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + response = requests.post(TOKEN_URL, data=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取访问令牌失败:{e}") + return None + +# 第四步:使用访问令牌获取用户信息 +def get_user_info(access_token): + try: + headers = { + 'Authorization': f'Bearer {access_token}' + } + response = requests.get(USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取用户信息失败:{e}") + return None + +# 主流程 +if __name__ == '__main__': + # 1. 生成授权 URL,前端引导用户访问授权页 + auth_url = get_auth_url() + print(f'请访问此 URL 授权:{auth_url} +') + + # 2. 用户授权后,从回调 URL 获取 code 参数 + code = get_code() + + # 3. 使用 code 参数获取访问令牌 + token_data = get_access_token(code) + if token_data: + access_token = token_data.get('access_token') + + # 4. 使用访问令牌获取用户信息 + if access_token: + user_info = get_user_info(access_token) + if user_info: + print(f" +获取用户信息成功:{json.dumps(user_info, indent=2)}") + else: + print(" +获取用户信息失败") + else: + print(f" +获取访问令牌失败:{json.dumps(token_data, indent=2)}") + else: + print(" +获取访问令牌失败") +``` +PHP +```php +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 + +// 配置信息 +$CLIENT_ID = '你的 Client ID'; +$CLIENT_SECRET = '你的 Client Secret'; +$REDIRECT_URI = '你的回调地址'; +$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +$USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 生成授权 URL +function getAuthUrl($clientId, $redirectUri) { + global $AUTH_URL; + return $AUTH_URL . '?' . http_build_query([ + 'client_id' => $clientId, + 'redirect_uri' => $redirectUri, + 'response_type' => 'code', + 'scope' => 'user' + ]); +} + +// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) +function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { + global $TOKEN_URL, $USER_INFO_URL; + + // 1. 获取访问令牌 + $ch = curl_init($TOKEN_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_POST, true); + curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ + 'client_id' => $clientId, + 'client_secret' => $clientSecret, + 'code' => $code, + 'redirect_uri' => $redirectUri, + 'grant_type' => 'authorization_code' + ])); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Content-Type: application/x-www-form-urlencoded', + 'Accept: application/json' + ]); + + $tokenResponse = curl_exec($ch); + curl_close($ch); + + $tokenData = json_decode($tokenResponse, true); + if (!isset($tokenData['access_token'])) { + return ['error' => '获取访问令牌失败', 'details' => $tokenData]; + } + + // 2. 获取用户信息 + $ch = curl_init($USER_INFO_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Authorization: Bearer ' . $tokenData['access_token'] + ]); + + $userResponse = curl_exec($ch); + curl_close($ch); + + return json_decode($userResponse, true); +} + +// 主流程 +// 1. 生成授权 URL +$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); +echo "使用 Linux Do 登录"; + +// 2. 处理回调并获取用户信息 +if (isset($_GET['code'])) { + $userInfo = getUserInfoWithCode( + $_GET['code'], + $CLIENT_ID, + $CLIENT_SECRET, + $REDIRECT_URI + ); + + if (isset($userInfo['error'])) { + echo '错误: ' . $userInfo['error']; + } else { + echo '欢迎, ' . $userInfo['name'] . '!'; + // 处理用户登录逻辑... + } +} +``` + +## 使用说明 + +### 授权流程 + +1. 用户点击应用中的’使用 Linux Do 登录’按钮 +2. 系统将用户重定向至 Linux Do 的授权页面 +3. 用户完成授权后,系统自动重定向回应用并携带授权码 +4. 应用使用授权码获取访问令牌 +5. 使用访问令牌获取用户信息 + +### 安全建议 + +- 切勿在前端代码中暴露 Client Secret +- 对所有用户输入数据进行严格验证 +- 确保使用 HTTPS 协议传输数据 +- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 17e51c38..79e0dd8a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.1 +0.1.46 diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1e15290..af51c8ed 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "net/url" "os" "strings" "time" @@ -35,24 +36,25 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -322,6 +324,30 @@ type TurnstileConfig struct { Required bool `mapstructure:"required"` } +// LinuxDoConnectConfig controls LinuxDo Connect OAuth login (end-user SSO). +// +// Note: This is NOT the same as upstream account OAuth (e.g. OpenAI/Gemini). +// It is used for logging in to Sub2API itself. +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` // backend callback URL registered at the provider + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // frontend route to receive token (default: /auth/linuxdo/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + + // Optional: gjson paths to extract fields from userinfo JSON. + // When empty, the server tries a set of common keys. + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -388,6 +414,18 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -426,6 +464,77 @@ func Load() (*Config, error) { return &cfg, nil } +func validateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func validateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + if strings.ContainsAny(raw, "\r\n") { + return fmt.Errorf("contains invalid characters") + } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { + return fmt.Errorf("must not start with //") + } + return nil + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute http(s) url or relative path") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return + } + if strings.EqualFold(u.Scheme, "http") { + log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + } +} + func setDefaults() { viper.SetDefault("run_mode", RunModeStandard) @@ -475,6 +584,22 @@ func setDefaults() { // Turnstile viper.SetDefault("turnstile.required", false) + // LinuxDo Connect OAuth login (end-user SSO) + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viper.SetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -586,6 +711,60 @@ func (c *Config) Validate() error { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } + if c.LinuxDo.Enabled { + if strings.TrimSpace(c.LinuxDo.ClientID) == "" { + return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { + return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { + return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { + return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") + } + method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") + } + + if err := validateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { + return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { + return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { + return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) + } + if err := validateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) + warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) + warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) + warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f28680c6..a39d41f9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "time" @@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { t.Fatalf("ResponseHeaders.Enabled = true, want false") } } + +func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "test-secret" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = false + + cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for javascript scheme, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") { + t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err) + } +} + +func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "none" + cfg.LinuxDo.UsePKCE = false + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { + t.Fatalf("Validate() expected use_pkce error, got: %v", err) + } +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go new file mode 100644 index 00000000..07310213 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -0,0 +1,517 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strings" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + + linuxDoOAuthMaxRedirectLen = 2048 + linuxDoOAuthMaxFragmentValueLen = 512 + linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") +) + +type linuxDoTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// LinuxDoOAuthStart starts the LinuxDo Connect OAuth login flow. +// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard +func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { + cfg, err := linuxDoOAuthConfig(h.cfg) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// LinuxDoOAuthCallback handles the OAuth callback, creates/logins the user, then redirects to frontend. +// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... +func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { + cfg, cfgErr := linuxDoOAuthConfig(h.cfg) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = linuxDoOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", "") + return + } + + email, username, _, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + if err != nil { + // Avoid leaking internal details to the client; keep structured reason for frontend. + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", jwtToken) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func linuxDoOAuthConfig(cfg *config.Config) (config.LinuxDoConnectConfig, error) { + if cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !cfg.LinuxDo.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return cfg.LinuxDo, nil +} + +func linuxDoExchangeCode( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*linuxDoTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + var tokenResp linuxDoTokenResponse + resp, err := r.SetFormDataFromValues(form).SetSuccessResult(&tokenResp).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token exchange status=%d", resp.StatusCode) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, errors.New("token response missing access_token") + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + return &tokenResp, nil +} + +func linuxDoFetchUserInfo( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + token *linuxDoTokenResponse, +) (email string, username string, subject string, err error) { + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return "", "", "", fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return linuxDoParseUserInfo(resp.String(), cfg) +} + +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { + email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "username"), + getGJSON(body, "preferred_username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + + subject = strings.TrimSpace(subject) + if subject == "" { + return "", "", "", errors.New("userinfo missing id field") + } + if !isSafeLinuxDoSubject(subject) { + return "", "", "", errors.New("userinfo returned invalid id field") + } + + email = strings.TrimSpace(email) + if email == "" { + // LinuxDo Connect userinfo does not necessarily provide email. To keep compatibility with the + // existing user schema (email is required/unique), use a stable synthetic email. + email = fmt.Sprintf("linuxdo-%s@linuxdo-connect.invalid", subject) + } + + username = strings.TrimSpace(username) + if username == "" { + username = "linuxdo_" + subject + } + + return email, username, subject, nil +} + +func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { + fragment := url.Values{} + fragment.Set("error", truncateFragmentValue(code)) + if strings.TrimSpace(message) != "" { + fragment.Set("error_message", truncateFragmentValue(message)) + } + if strings.TrimSpace(description) != "" { + fragment.Set("error_description", truncateFragmentValue(description)) + } + redirectWithFragment(c, frontendCallback, fragment) +} + +func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { + u, err := url.Parse(frontendCallback) + if err != nil { + // Fallback: best-effort redirect. + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = fragment.Encode() + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "" +} + +func getGJSON(body string, path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + res := gjson.Get(body, path) + if !res.Exists() { + return "" + } + return res.String() +} + +func sanitizeFrontendRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if len(path) > linuxDoOAuthMaxRedirectLen { + return "" + } + // Only allow same-origin relative paths (avoid open redirect). + if !strings.HasPrefix(path, "/") { + return "" + } + if strings.HasPrefix(path, "//") { + return "" + } + if strings.Contains(path, "://") { + return "" + } + if strings.ContainsAny(path, "\r\n") { + return "" + } + return path +} + +func isRequestHTTPS(c *gin.Context) bool { + if c.Request.TLS != nil { + return true + } + proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) + return proto == "https" +} + +func encodeCookieValue(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func decodeCookieValue(value string) (string, error) { + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(raw), nil +} + +func readCookieDecoded(c *gin.Context, name string) (string, error) { + ck, err := c.Request.Cookie(name) + if err != nil { + return "", err + } + return decodeCookieValue(ck.Value) +} + +func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: linuxDoOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: linuxDoOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func truncateFragmentValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) > linuxDoOAuthMaxFragmentValueLen { + value = value[:linuxDoOAuthMaxFragmentValueLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + } + return value +} + +func buildBearerAuthorization(tokenType, accessToken string) (string, error) { + tokenType = strings.TrimSpace(tokenType) + if tokenType == "" { + tokenType = "Bearer" + } + if !strings.EqualFold(tokenType, "Bearer") { + return "", fmt.Errorf("unsupported token_type: %s", tokenType) + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", errors.New("missing access_token") + } + if strings.ContainsAny(accessToken, " \t\r\n") { + return "", errors.New("access_token contains whitespace") + } + return "Bearer " + accessToken, nil +} + +func isSafeLinuxDoSubject(subject string) bool { + subject = strings.TrimSpace(subject) + if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { + return false + } + for _, r := range subject { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r == '_' || r == '-': + default: + return false + } + } + return true +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go new file mode 100644 index 00000000..03db69a8 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -0,0 +1,74 @@ +package handler + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSanitizeFrontendRedirectPath(t *testing.T) { + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard")) + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard ")) + require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard")) + require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo")) + + long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen) + require.Equal(t, "", sanitizeFrontendRedirectPath(long)) +} + +func TestBuildBearerAuthorization(t *testing.T) { + auth, err := buildBearerAuthorization("", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + auth, err = buildBearerAuthorization("bearer", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + _, err = buildBearerAuthorization("MAC", "token123") + require.Error(t, err) + + _, err = buildBearerAuthorization("Bearer", "token 123") + require.Error(t, err) +} + +func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "alice", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "linuxdo_123", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + require.Error(t, err) + + tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) + _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + require.Error(t, err) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..7382a577 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -50,5 +50,6 @@ type PublicSettings struct { APIBaseURL string `json:"api_base_url"` ContactInfo string `json:"contact_info"` DocURL string `json:"doc_url"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 3cae7a7f..e1b20c8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { APIBaseURL: settings.APIBaseURL, ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 196d8bdb..e61d3939 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -19,6 +19,8 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6e685869..e3532b25 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log" + "net/mail" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,6 +22,7 @@ var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") @@ -27,6 +32,8 @@ var ( ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ) +const linuxDoSyntheticEmailDomain = "@linuxdo-connect.invalid" + // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 @@ -80,6 +87,11 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrRegDisabled } + // Prevent users from registering emails reserved for synthetic OAuth accounts. + if isReservedEmail(email) { + return "", nil, ErrEmailReserved + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { // 如果邮件验证已开启但邮件服务未配置,拒绝注册 @@ -161,6 +173,10 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { return ErrRegDisabled } + if isReservedEmail(email) { + return ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -195,6 +211,10 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S return nil, ErrRegDisabled } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -319,6 +339,101 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return token, user, nil } +// LoginOrRegisterOAuth logs a user in by email (trusted from an OAuth provider) or creates a new user. +// +// This is used by end-user OAuth/SSO login flows (e.g. LinuxDo Connect), and intentionally does +// NOT require the local password. A random password hash is generated for new users to satisfy +// the existing database constraint. +func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // Treat OAuth-first login as registration. + if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + return "", nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return "", nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return "", nil, fmt.Errorf("hash password: %w", err) + } + + // Defaults for new users. + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + // Race: user created between GetByEmail and Create. + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return "", nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return "", nil, ErrUserNotActive + } + + // Best-effort: fill username when empty. + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + token, err := s.GenerateToken(user) + if err != nil { + return "", nil, fmt.Errorf("generate token: %w", err) + } + return token, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -361,6 +476,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { return nil, ErrInvalidToken } +func randomHexString(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 16 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func isReservedEmail(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return strings.HasSuffix(normalized, linuxDoSyntheticEmailDomain) +} + // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index bfd504a3..8e99ea29 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -182,6 +182,14 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_ReservedEmail(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") + require.ErrorIs(t, err, ErrEmailReserved) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 965253cf..b3a3bf21 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -82,6 +82,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings APIBaseURL: settings[SettingKeyAPIBaseURL], ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], + LinuxDoOAuthEnabled: s.cfg != nil && s.cfg.LinuxDo.Enabled, }, nil } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index de0331f7..a06723f8 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -51,5 +51,6 @@ type PublicSettings struct { APIBaseURL string ContactInfo string DocURL string + LinuxDoOAuthEnabled bool Version string } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 49bf0afa..936f0ea4 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -234,6 +234,31 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# LinuxDo Connect OAuth Login (SSO) +# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +linuxdo_connect: + enabled: false + client_id: "" + client_secret: "" + authorize_url: "https://connect.linux.do/oauth2/authorize" + token_url: "https://connect.linux.do/oauth2/token" + userinfo_url: "https://connect.linux.do/api/user" + scopes: "user" + # 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/linuxdo/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 2732d84d..745445bf 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -229,6 +229,15 @@ export default { sendingCode: 'Sending...', clickToResend: 'Click to resend code', resendCode: 'Resend verification code', + linuxdo: { + signIn: 'Continue with Linux.do', + orContinue: 'or continue with email', + callbackTitle: 'Signing you in', + callbackProcessing: 'Completing login, please wait...', + callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', + callbackMissingToken: 'Missing login token, please try again.', + backToLogin: 'Back to Login' + }, oauth: { code: 'Code', state: 'State', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 40aa39ab..83df3ddc 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -227,6 +227,15 @@ export default { sendingCode: '发送中...', clickToResend: '点击重新发送验证码', resendCode: '重新发送验证码', + linuxdo: { + signIn: '使用 Linux.do 登录', + orContinue: '或使用邮箱密码继续', + callbackTitle: '正在完成登录', + callbackProcessing: '正在验证登录信息,请稍候...', + callbackHint: '如果页面未自动跳转,请返回登录页重试。', + callbackMissingToken: '登录信息缺失,请返回重试。', + backToLogin: '返回登录' + }, oauth: { code: '授权码', state: '状态', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 48a6f0fd..238982ef 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -67,6 +67,15 @@ const routes: RouteRecordRaw[] = [ title: 'OAuth Callback' } }, + { + path: '/auth/linuxdo/callback', + name: 'LinuxDoOAuthCallback', + component: () => import('@/views/auth/LinuxDoCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'LinuxDo OAuth Callback' + } + }, // ==================== User Routes ==================== { diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index cfc9d677..d91a9b7e 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -282,23 +282,24 @@ export const useAppStore = defineStore('app', () => { * Fetch public settings (uses cache unless force=true) * @param force - Force refresh from API */ - async function fetchPublicSettings(force = false): Promise { - // Return cached data if available and not forcing refresh - if (publicSettingsLoaded.value && !force) { - return { - registration_enabled: false, - email_verify_enabled: false, - turnstile_enabled: false, - turnstile_site_key: '', - site_name: siteName.value, - site_logo: siteLogo.value, - site_subtitle: '', - api_base_url: apiBaseUrl.value, - contact_info: contactInfo.value, - doc_url: docUrl.value, - version: siteVersion.value - } - } + async function fetchPublicSettings(force = false): Promise { + // Return cached data if available and not forcing refresh + if (publicSettingsLoaded.value && !force) { + return { + registration_enabled: false, + email_verify_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + site_name: siteName.value, + site_logo: siteLogo.value, + site_subtitle: '', + api_base_url: apiBaseUrl.value, + contact_info: contactInfo.value, + doc_url: docUrl.value, + linuxdo_oauth_enabled: false, + version: siteVersion.value + } + } // Prevent duplicate requests if (publicSettingsLoading.value) { diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index 27faaf4b..3f22a9d3 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -159,6 +159,27 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * Set token directly (OAuth/SSO callback) and load current user profile. + * @param newToken - JWT access token issued by backend + */ + async function setToken(newToken: string): Promise { + // Clear any previous state first (avoid mixing sessions) + clearAuth() + + token.value = newToken + localStorage.setItem(AUTH_TOKEN_KEY, newToken) + + try { + const userData = await refreshUser() + startAutoRefresh() + return userData + } catch (error) { + clearAuth() + throw error + } + } + /** * User logout * Clears all authentication state and persisted data @@ -233,6 +254,7 @@ export const useAuthStore = defineStore('auth', () => { // Actions login, register, + setToken, logout, checkAuth, refreshUser diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eaea24be..4b8fff09 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -73,6 +73,7 @@ export interface PublicSettings { api_base_url: string contact_info: string doc_url: string + linuxdo_oauth_enabled: boolean version: string } diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue new file mode 100644 index 00000000..c6f93e6b --- /dev/null +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -0,0 +1,119 @@ + + + + + + diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 903db100..a6b5d2b2 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,6 +11,51 @@

+ +
+ + +
+
+ + {{ t('auth.linuxdo.orContinue') }} + +
+
+
+
@@ -179,6 +224,7 @@ const showPassword = ref(false) // Public settings const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') +const linuxdoOAuthEnabled = ref(false) // Turnstile const turnstileRef = ref | null>(null) @@ -210,6 +256,7 @@ onMounted(async () => { const settings = await getPublicSettings() turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled } catch (error) { console.error('Failed to load public settings:', error) } @@ -320,6 +367,14 @@ async function handleLogin(): Promise { isLoading.value = false } } + +function handleLinuxDoLogin(): void { + const redirectTo = (router.currentRoute.value.query.redirect as string) || '/dashboard' + const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1' + const normalized = apiBase.replace(/\/$/, '') + const startURL = `${normalized}/auth/oauth/linuxdo/start?redirect=${encodeURIComponent(redirectTo)}` + window.location.href = startURL +}