From fdcbf7aacf67a343b36b8e7521a2cb27d10bd129 Mon Sep 17 00:00:00 2001 From: QTom Date: Mon, 2 Mar 2026 15:53:26 +0800 Subject: [PATCH 001/286] =?UTF-8?q?feat(proxy):=20=E9=9B=86=E4=B8=AD?= =?UTF-8?q?=E4=BB=A3=E7=90=86=20URL=20=E9=AA=8C=E8=AF=81=E5=B9=B6=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E5=85=A8=E5=B1=80=20fail-fast?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 提取 proxyurl.Parse() 公共包,将分散在 6 处的代理 URL 验证逻辑 统一收敛,确保无效代理配置在创建时立即失败,永不静默回退直连。 主要变更: - 新增 proxyurl 包:统一 TrimSpace → url.Parse → Host 校验 → Scheme 白名单 - socks5:// 自动升级为 socks5h://,防止 DNS 泄漏(大小写不敏感) - antigravity: http.ProxyURL → proxyutil.ConfigureTransportProxy 支持 SOCKS5 - openai_oauth: 删除 newOpenAIOAuthHTTPClient,收编至 httpclient.GetClient - 移除未使用的 ProxyStrict 字段(fail-fast 已是全局默认行为) - 补充 15 个 proxyurl 测试 + pricing/usage fail-fast 测试 --- backend/internal/config/config.go | 15 +- backend/internal/pkg/antigravity/client.go | 21 +- .../internal/pkg/antigravity/client_test.go | 72 +++--- backend/internal/pkg/httpclient/pool.go | 16 +- backend/internal/pkg/proxyurl/parse.go | 66 ++++++ backend/internal/pkg/proxyurl/parse_test.go | 215 ++++++++++++++++++ backend/internal/pkg/proxyutil/dialer.go | 9 +- .../repository/claude_oauth_service.go | 35 ++- .../repository/claude_oauth_service_test.go | 8 +- .../repository/claude_usage_service.go | 2 +- .../repository/claude_usage_service_test.go | 13 +- .../repository/gemini_oauth_client.go | 12 +- .../repository/geminicli_codeassist_client.go | 14 +- .../repository/github_release_service.go | 13 +- backend/internal/repository/http_upstream.go | 37 +-- .../http_upstream_benchmark_test.go | 5 +- .../internal/repository/http_upstream_test.go | 54 +++-- .../repository/openai_oauth_service.go | 12 +- .../internal/repository/pricing_service.go | 26 ++- .../repository/pricing_service_test.go | 18 +- .../repository/proxy_probe_service.go | 1 - .../internal/repository/req_client_pool.go | 18 +- .../repository/req_client_pool_test.go | 46 +++- 
backend/internal/repository/wire.go | 2 +- backend/internal/service/admin_service.go | 1 - .../service/antigravity_oauth_service.go | 20 +- .../service/antigravity_quota_fetcher.go | 6 +- backend/internal/service/crs_sync_service.go | 2 +- .../internal/service/gemini_oauth_service.go | 2 +- .../internal/service/openai_oauth_service.go | 23 +- deploy/config.example.yaml | 6 + 31 files changed, 633 insertions(+), 157 deletions(-) create mode 100644 backend/internal/pkg/proxyurl/parse.go create mode 100644 backend/internal/pkg/proxyurl/parse_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4f6fea37..763ed829 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -265,8 +265,13 @@ type CSPConfig struct { } type ProxyFallbackConfig struct { - // AllowDirectOnError 当代理初始化失败时是否允许回退直连。 - // 默认 false:避免因代理配置错误导致 IP 泄露/关联。 + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响以下非 AI 账号连接的辅助服务: + // - GitHub Release 更新检查 + // - 定价数据拉取 + // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), + // 这些关键路径的代理失败始终返回错误,不会回退直连。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` } @@ -1105,6 +1110,9 @@ func setDefaults() { viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + // Security - disable direct fallback on proxy error + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -1415,9 +1423,6 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") - // Security - proxy fallback - viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) - // Subscription Maintenance (bounded queue + worker pool) 
viper.SetDefault("subscription_maintenance.worker_count", 2) viper.SetDefault("subscription_maintenance.queue_size", 1024) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 1998221a..d46bbc45 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,6 +14,9 @@ import ( "net/url" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) @@ -149,22 +152,26 @@ type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) *Client { +func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ Timeout: 30 * time.Second, } - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(proxyURL); err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURLParsed), - } + _, parsed, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if parsed != nil { + transport := &http.Transport{} + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) } + client.Transport = transport } return &Client{ httpClient: client, - } + }, nil } // isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index 394b6128..20b57833 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -228,8 +228,20 @@ func TestGetTier_两者都为nil(t *testing.T) { // NewClient // --------------------------------------------------------------------------- +func mustNewClient(t *testing.T, proxyURL string) *Client { + t.Helper() + client, err := NewClient(proxyURL) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) + } + 
return client +} + func TestNewClient_无代理(t *testing.T) { - client := NewClient("") + client, err := NewClient("") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -246,7 +258,10 @@ func TestNewClient_无代理(t *testing.T) { } func TestNewClient_有代理(t *testing.T) { - client := NewClient("http://proxy.example.com:8080") + client, err := NewClient("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -256,7 +271,10 @@ func TestNewClient_有代理(t *testing.T) { } func TestNewClient_空格代理(t *testing.T) { - client := NewClient(" ") + client, err := NewClient(" ") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -267,15 +285,13 @@ func TestNewClient_空格代理(t *testing.T) { } func TestNewClient_无效代理URL(t *testing.T) { - // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), - // 但 ://invalid 会导致解析错误 - client := NewClient("://invalid") - if client == nil { - t.Fatal("NewClient 返回 nil") + // 无效 URL 应返回 error + _, err := NewClient("://invalid") + if err == nil { + t.Fatal("无效代理 URL 应返回错误") } - // 无效 URL 解析失败时,Transport 应保持 nil - if client.httpClient.Transport != nil { - t.Error("无效代理 URL 时 Transport 应为 nil") + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) } } @@ -499,7 +515,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -602,7 +618,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := 
client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -1242,7 +1258,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") if err != nil { t.Fatalf("LoadCodeAssist 失败: %v", err) @@ -1277,7 +1293,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1300,7 +1316,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1333,7 +1349,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) @@ -1361,7 +1377,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1377,7 +1393,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1441,7 +1457,7 @@ func 
TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1496,7 +1512,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1516,7 +1532,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1546,7 +1562,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) @@ -1574,7 +1590,7 @@ func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1590,7 +1606,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1610,7 +1626,7 @@ func 
TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1646,7 +1662,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) @@ -1672,7 +1688,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 6ef3d714..32e4bc5b 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" - "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -41,7 +41,6 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -120,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is 
not allowed; install a trusted certificate instead") } - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) + _, parsed, err := proxyurl.Parse(opts.ProxyURL) if err != nil { return nil, err } + if parsed == nil { + return transport, nil + } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -138,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.ValidateResolvedIP, opts.AllowPrivateHosts, opts.MaxIdleConns, diff --git a/backend/internal/pkg/proxyurl/parse.go b/backend/internal/pkg/proxyurl/parse.go new file mode 100644 index 00000000..217556f2 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse.go @@ -0,0 +1,66 @@ +// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连) +// +// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。 +// 直接使用 url.Parse 处理代理 URL 是被禁止的。 +// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败, +// 而不是在运行时静默回退到直连(产生 IP 关联风险)。 +package proxyurl + +import ( + "fmt" + "net/url" + "strings" +) + +// allowedSchemes 代理协议白名单 +var allowedSchemes = map[string]bool{ + "http": true, + "https": true, + "socks5": true, + "socks5h": true, +} + +// Parse 解析并验证代理 URL。 +// +// 语义: +// - 空字符串 → ("", nil, nil),表示直连 +// - 非空且有效 → (trimmed, *url.URL, nil) +// - 非空但无效 → ("", nil, error),fail-fast 不回退 +// +// 验证规则: +// - TrimSpace 后为空视为直连 +// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露) +// - Host 为空返回 error(用 Redacted() 脱敏) +// - Scheme 必须为 http/https/socks5/socks5h +// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏) +func Parse(raw string) (trimmed string, parsed *url.URL, err error) { + trimmed = strings.TrimSpace(raw) + if trimmed == "" { + return "", 
nil, nil + } + + parsed, err = url.Parse(trimmed) + if err != nil { + // 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据) + return "", nil, fmt.Errorf("invalid proxy URL: %v", err) + } + + if parsed.Host == "" || parsed.Hostname() == "" { + return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted()) + } + + scheme := strings.ToLower(parsed.Scheme) + if !allowedSchemes[scheme] { + return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme) + } + + // 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。 + // Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS, + // 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。 + if scheme == "socks5" { + parsed.Scheme = "socks5h" + trimmed = parsed.String() + } + + return trimmed, parsed, nil +} diff --git a/backend/internal/pkg/proxyurl/parse_test.go b/backend/internal/pkg/proxyurl/parse_test.go new file mode 100644 index 00000000..5fb57c16 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse_test.go @@ -0,0 +1,215 @@ +package proxyurl + +import ( + "strings" + "testing" +) + +func TestParse_空字符串直连(t *testing.T) { + trimmed, parsed, err := Parse("") + if err != nil { + t.Fatalf("空字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_空白字符串直连(t *testing.T) { + trimmed, parsed, err := Parse(" ") + if err != nil { + t.Fatalf("空白字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_有效HTTP代理(t *testing.T) { + trimmed, parsed, err := Parse("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("有效 HTTP 代理应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } + if parsed.Host != "proxy.example.com:8080" { + t.Errorf("Host 不匹配: 
got %q", parsed.Host) + } +} + +func TestParse_有效HTTPS代理(t *testing.T) { + _, parsed, err := Parse("https://proxy.example.com:443") + if err != nil { + t.Fatalf("有效 HTTPS 代理应成功: %v", err) + } + if parsed.Scheme != "https" { + t.Errorf("Scheme 不匹配: got %q", parsed.Scheme) + } +} + +func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) { + trimmed, parsed, err := Parse("socks5://127.0.0.1:1080") + if err != nil { + t.Fatalf("有效 SOCKS5 代理应成功: %v", err) + } + // socks5 自动升级为 socks5h,确保 DNS 由代理端解析 + if trimmed != "socks5h://127.0.0.1:1080" { + t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无效URL(t *testing.T) { + _, _, err := Parse("://invalid") + if err == nil { + t.Fatal("无效 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +func TestParse_缺少Host(t *testing.T) { + _, _, err := Parse("http://") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_不支持的Scheme(t *testing.T) { + _, _, err := Parse("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("不支持的 scheme 应返回错误") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error()) + } +} + +func TestParse_含密码URL脱敏(t *testing.T) { + // 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h + trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080") + if err != nil { + t.Fatalf("含密码的有效 URL 应成功: %v", err) + } + if trimmed == "" || parsed == nil { + t.Fatal("应返回非空结果") + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed) + } + if parsed.User == 
nil { + t.Error("升级后应保留 UserInfo") + } + + // 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径) + _, _, err = Parse("http://user:secret_password@:0/") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if strings.Contains(err.Error(), "secret_password") { + t.Error("错误信息不应包含明文密码") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_带空白的有效URL(t *testing.T) { + trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ") + if err != nil { + t.Fatalf("带空白的有效 URL 应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 应去除空白: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } +} + +func TestParse_Scheme大小写不敏感(t *testing.T) { + // 大写 SOCKS5 应被接受并升级为 socks5h + trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080") + if err != nil { + t.Fatalf("大写 SOCKS5 应被接受: %v", err) + } + if parsed.Scheme != "socks5h" { + t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed) + } + + // 大写 HTTP 应被接受(不变) + _, _, err = Parse("HTTP://proxy.example.com:8080") + if err != nil { + t.Fatalf("大写 HTTP 应被接受: %v", err) + } +} + +func TestParse_带认证的有效代理(t *testing.T) { + trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080") + if err != nil { + t.Fatalf("带认证的代理 URL 应成功: %v", err) + } + if parsed.User == nil { + t.Error("应保留 UserInfo") + } + if trimmed != "http://user:pass@proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_IPv6地址(t *testing.T) { + trimmed, parsed, err := Parse("http://[::1]:8080") + if err != nil { + t.Fatalf("IPv6 代理 URL 应成功: %v", err) + } + if parsed.Hostname() != "::1" { + t.Errorf("Hostname 不匹配: got %q", parsed.Hostname()) + } + if trimmed != "http://[::1]:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_SOCKS5H保持不变(t 
*testing.T) { + trimmed, parsed, err := Parse("socks5h://proxy.local:1080") + if err != nil { + t.Fatalf("有效 SOCKS5H 代理应成功: %v", err) + } + // socks5h 不需要升级,应保持原样 + if trimmed != "socks5h://proxy.local:1080" { + t.Errorf("trimmed 不应变化: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无Scheme裸地址(t *testing.T) { + // 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空 + _, _, err := Parse("proxy.example.com:8080") + if err == nil { + t.Fatal("无 scheme 的裸地址应返回错误") + } +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index 91b224a2..e437cae3 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,7 +2,11 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) +// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) +// +// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, +// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 package proxyutil import ( @@ -20,7 +24,8 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) +// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 77764881..b754bd55 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -11,6 +11,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -28,11 +29,14 @@ 
func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string tokenURL string - clientFactory func(proxyURL string) *req.Client + clientFactory func(proxyURL string) (*req.Client, error) } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } var orgs []struct { UUID string `json:"uuid"` @@ -88,7 +92,10 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -165,7 +172,10 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Parse code which may contain state in format "authCode#state" authCode := code @@ -223,7 +233,10 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } reqBody := map[string]any{ "grant_type": "refresh_token", @@ 
-253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) *req.Client { +func createReqClient(proxyURL string) (*req.Client, error) { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). SetCookieJar(nil) // 禁用 CookieJar - if strings.TrimSpace(proxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(proxyURL)) + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } - return client + return client, nil } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 7395c6d8..c6383033 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") 
s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f472..f6054828 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } resp, err = client.Do(req) diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 2e10f3e5..cbd0b6d3 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, 
"FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } +func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() { + s.fetcher = &claudeUsageService{ + usageURL: "http://example.com", + allowPrivateHosts: true, + } + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "create http client failed") +} + func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 8b7fe625..eb14f313 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } oauthCfgInput := geminicli.OAuthConfig{ 
ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) *req.Client { +func createGeminiReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 4f63280d..b5bc6497 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). 
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) *req.Client { +func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 28efe914..ad1f22e3 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "os" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -24,13 +26,19 @@ type githubReleaseClientError struct { // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { - if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } sharedClient = &http.Client{Timeout: 30 * time.Second} @@ -42,7 +50,8 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi ProxyURL: proxyURL, }) if err != nil { - 
if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } downloadClient = &http.Client{Timeout: 10 * time.Minute} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19..a4674c1a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" @@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID 
int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) } // getClientEntry 获取或创建客户端条目 @@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // 构建缓存键(根据隔离策略不同) cacheKey := buildCacheKey(isolation, proxyKey, accountID) // 构建连接池配置键(用于检测配置变更) @@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) +// - string: 标准化的代理键(空返回 "direct") +// - *url.URL: 解析后的 URL(空返回 nil) +// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) +func normalizeProxyURL(raw string) (string, *url.URL, error) { + _, parsed, err := proxyurl.Parse(raw) if err != nil { - return directProxyKey, nil + return "", nil, err } + if parsed == nil { + return directProxyKey, nil, nil + } + // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) { parsed.Host = hostname } } - return parsed.String(), parsed + return parsed.String(), parsed, nil } // defaultPoolSettings 获取默认连接池配置 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go 
b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a3..89892b3b 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry := svc.getOrCreateClient(proxyURL, 1, 1) + entry, err := svc.getOrCreateClient(proxyURL, 1, 1) + if err != nil { + b.Fatalf("getOrCreateClient: %v", err) + } client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e..b3268463 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 -// 
验证解析失败时回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { +// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 +// 验证解析失败时拒绝回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) - require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + require.Error(s.T(), err, "expected error for invalid proxy URL") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _ := normalizeProxyURL("http://proxy.local:8080") - key2, _ := normalizeProxyURL("http://proxy.local:8080/") + key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") + require.NoError(s.T(), err1) + key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") + require.NoError(s.T(), err2) require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} svc := s.newService() // 同一账户,不同代理 - entry1 := 
svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := svc.getOrCreateClient("", 1, 12) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := svc.getOrCreateClient("", 1, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") @@ -245,12 +247,12 @@ func (s 
*HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := svc.getOrCreateClient("", 1, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _ = svc.getOrCreateClient("", 2, 1) + _, _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } @@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } +// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 +func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { + t.Helper() + entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) + require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) + return entry +} + // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { diff --git 
a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 3e155971..dca0b612 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -23,7 +23,10 @@ type openaiOAuthService struct { } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } if redirectURI == "" { redirectURI = openai.DefaultRedirectURI @@ -74,7 +77,10 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -102,7 +108,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8..ee8e1749 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "strings" "time" 
@@ -16,14 +17,37 @@ type pricingRemoteClient struct { httpClient *http.Client } +// pricingRemoteClientError 代理初始化失败时的错误占位客户端 +// 所有请求直接返回初始化错误,禁止回退到直连 +type pricingRemoteClientError struct { + err error +} + +func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { + return nil, c.err +} + +func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) { + return "", c.err +} + // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议 -func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) + return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 6ea11211..ef2f214b 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) + client, ok := 
NewPricingRemoteClient("", false).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } +func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", false) + _, ok := client.(*pricingRemoteClientError) + require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") + + _, err := client.FetchPricingJSON(context.Background(), "http://example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "proxy client init failed") +} + +func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", true) + _, ok := client.(*pricingRemoteClient) + require.True(t, ok, "should fallback to direct client when allowed") +} + func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 54de2897..b4aeab71 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, - ProxyStrict: true, ValidateResolvedIP: s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index af71a7ee..79b24396 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/imroc/req/v3" ) @@ -33,11 +35,11 @@ var sharedReqClients sync.Map 
// getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) *req.Client { +func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c + return c, nil } } @@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { if opts.Impersonate { client = client.ImpersonateChrome() } - if strings.TrimSpace(opts.ProxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + trimmed, _, err := proxyurl.Parse(opts.ProxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); ok { - return c + return c, nil } - return client + return client, nil } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 904ed4d6..9067d012 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault := getSharedReqClient(base) + clientDefault, err := getSharedReqClient(base) + require.NoError(t, err) force := base force.ForceHTTP2 = true - clientForce := getSharedReqClient(force) + clientForce, err := getSharedReqClient(force) + require.NoError(t, err) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first := getSharedReqClient(opts) - second := getSharedReqClient(opts) + first, err := 
getSharedReqClient(opts) + require.NoError(t, err) + second, err := getSharedReqClient(opts) + require.NoError(t, err) require.Same(t, first, second) } @@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } +func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "://missing-scheme", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "proxy URL missing host") +} + func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client := createOpenAIReqClient("http://proxy.local:8080") + client, err := createOpenAIReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client := createGeminiReqClient("http://proxy.local:8080") + client, err := createGeminiReqClient("http://proxy.local:8080") + require.NoError(t, err) 
require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 2344035c..ee796d98 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient // ProvidePricingRemoteClient 创建定价数据远程客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { - return NewPricingRemoteClient(cfg.Update.ProxyURL) + return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvideSessionLimitCache 创建会话限制缓存 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index bdd1aa4a..7e6982d3 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2028,7 +2028,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr ProxyURL: proxyURL, Timeout: proxyQualityRequestTimeout, ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, - ProxyStrict: true, }) if err != nil { result.Items = append(result.Items, ProxyQualityCheckItem{ diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index b67c7faf..5f6691be 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, 
refreshToken time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } // 获取用户信息(email) - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) @@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return "", fmt.Errorf("create antigravity client failed: %w", err) + } loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index 07eb563d..e950ec1d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 调用 API 
获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 040b2357..6a916740 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 20 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e866bdc3..08a74a37 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return "", fmt.Errorf("create http client failed: %w", err) } resp, err := client.Do(req) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 07cb5472..72f4bbb0 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -7,7 +7,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "regexp" "sort" "strconv" @@ -15,6 +14,7 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -273,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi req.Header.Set("Referer", "https://sora.chatgpt.com/") req.Header.Set("User-Agent", 
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - client := newOpenAIOAuthHTTPClient(proxyURL) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + }) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) + } resp, err := client.Do(req) if err != nil { return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) @@ -530,19 +536,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64 return proxy.URL(), nil } -func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { - transport := &http.Transport{} - if strings.TrimSpace(proxyURL) != "" { - if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { - transport.Proxy = http.ProxyURL(parsed) - } - } - return &http.Client{ - Timeout: 120 * time.Second, - Transport: transport, - } -} - func normalizeOpenAIOAuthPlatform(platform string) string { switch strings.ToLower(strings.TrimSpace(platform)) { case PlatformSora: diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index faa85854..e2eb3130 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -134,6 +134,12 @@ security: # Allow skipping TLS verification for proxy probe (debug only) # 允许代理探测时跳过 TLS 证书验证(仅用于调试) insecure_skip_verify: false + proxy_fallback: + # Allow auxiliary services (update check, pricing data) to fallback to direct + # connection when proxy initialization fails. Does NOT affect AI gateway connections. 
+ # 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。 + # 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。 + allow_direct_on_error: false # ============================================================================= # Gateway Configuration From ec6bcfeb83bd362ec5d6ec5e5935b404cd2353c3 Mon Sep 17 00:00:00 2001 From: zqq61 <1852150449@qq.com> Date: Mon, 2 Mar 2026 22:54:38 +0800 Subject: [PATCH 002/286] =?UTF-8?q?fix:=20OAuth=20401=20=E4=B8=8D=E5=86=8D?= =?UTF-8?q?=E6=B0=B8=E4=B9=85=E9=94=81=E6=AD=BB=E8=B4=A6=E5=8F=B7=EF=BC=8C?= =?UTF-8?q?=E6=94=B9=E7=94=A8=E4=B8=B4=E6=97=B6=E4=B8=8D=E5=8F=AF=E8=B0=83?= =?UTF-8?q?=E5=BA=A6=E5=AE=9E=E7=8E=B0=E8=87=AA=E5=8A=A8=E6=81=A2=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OAuth 账号收到 401 时,原逻辑同时设置 expires_at=now() 和 SetError(), 但刷新服务只查询 status=active 的账号,导致 error 状态的账号永远无法 被刷新服务拾取,expires_at=now() 实际上是死代码。 修复: - OAuth 401 使用 SetTempUnschedulable 替代 SetError,保持 status=active - 新增 oauth_401_cooldown_minutes 配置项(默认 10 分钟) - 刷新成功后同步清除 DB 和 Redis 中的临时不可调度状态 - 不可重试错误检查(invalid_grant 等)从 Antigravity 推广到所有平台 - 可重试错误耗尽后不再标记 error,下个刷新周期继续重试 恢复流程: OAuth 401 → temp_unschedulable + expires_at=now → 刷新服务拾取 → 成功: 清除 temp_unschedulable → 自动恢复 → invalid_grant: SetError → 永久禁用 → 网络错误: 仅记日志 → 下周期重试 --- backend/cmd/server/wire_gen.go | 2 +- backend/internal/config/config.go | 2 + backend/internal/service/ratelimit_service.go | 28 +++- .../service/ratelimit_service_401_test.go | 10 +- .../internal/service/token_refresh_service.go | 53 +++++--- .../service/token_refresh_service_test.go | 126 +++++++++++++++--- backend/internal/service/wire.go | 3 +- 7 files changed, 175 insertions(+), 49 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 37ad5d9f..a66d7d05 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -219,7 +219,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsCleanupService := 
service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4f6fea37..980a6deb 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -814,6 +814,7 @@ type DefaultConfig struct { type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) + OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟) 
} // APIKeyAuthCacheConfig API Key 认证缓存配置 @@ -1190,6 +1191,7 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) + viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index d4d70536..84bf95ce 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -146,13 +146,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) } + // 3. 临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取) + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "OAuth 401: " + upstreamMsg + } + cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes + if cooldownMinutes <= 0 { + cooldownMinutes = 10 + } + until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil { + slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) + } + shouldDisable = true + } else { + // 非 OAuth 账号(APIKey):保持原有 SetError 行为 + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "Authentication failed (401): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true } - msg := "Authentication failed (401): invalid or expired credentials" - if upstreamMsg != "" { - msg = "Authentication failed (401): " + upstreamMsg - } - s.handleAuthError(ctx, account, msg) - shouldDisable = true 
case 402: // 支付要求:余额不足或计费问题,停止调度 msg := "Payment required (402): insufficient balance or billing issue" diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 36357a4b..7bced46f 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -41,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc return r.err } -func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { +func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { tests := []struct { name string platform string @@ -76,9 +76,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) - require.Equal(t, 0, repo.tempCalls) - require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) }) } @@ -98,7 +97,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index a37e0d0a..f069bb5e 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -18,7 +18,8 @@ type TokenRefreshService struct { refreshers []TokenRefresher 
cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator - schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 + schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 + tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存 stopCh chan struct{} wg sync.WaitGroup @@ -34,12 +35,14 @@ func NewTokenRefreshService( cacheInvalidator TokenCacheInvalidator, schedulerCache SchedulerCache, cfg *config.Config, + tempUnschedCache TempUnschedCache, ) *TokenRefreshService { s := &TokenRefreshService{ accountRepo: accountRepo, cfg: &cfg.TokenRefresh, cacheInvalidator: cacheInvalidator, schedulerCache: schedulerCache, + tempUnschedCache: tempUnschedCache, stopCh: make(chan struct{}), } @@ -231,6 +234,26 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) } } + // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unschedulable_failed", + "account_id", account.ID, + "error", clearErr, + ) + } else { + slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) + } + // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 + if s.tempUnschedCache != nil { + if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unsched_cache_failed", + "account_id", account.ID, + "error", clearErr, + ) + } + } + } // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { @@ -257,8 +280,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return nil } - // Antigravity 账户:不可重试错误直接标记 error 状态并返回 - if 
account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) { + // 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回 + if isNonRetryableRefreshError(err) { errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { slog.Error("token_refresh.set_error_status_failed", @@ -285,23 +308,13 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } } - // Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题) - // 其他平台账户:重试失败后标记 error - if account.Platform == PlatformAntigravity { - slog.Warn("token_refresh.retry_exhausted_antigravity", - "account_id", account.ID, - "max_retries", s.cfg.MaxRetries, - "error", lastErr, - ) - } else { - errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr) - if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { - slog.Error("token_refresh.set_error_status_failed", - "account_id", account.ID, - "error", err, - ) - } - } + // 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试) + slog.Warn("token_refresh.retry_exhausted", + "account_id", account.ID, + "platform", account.Platform, + "max_retries", s.cfg.MaxRetries, + "error", lastErr, + ) return lastErr } diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 8e16c6f5..bdef0ed7 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,10 +14,11 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - lastAccount *Account - updateErr error + updateCalls int + setErrorCalls int + clearTempCalls int + lastAccount *Account + updateErr error } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { @@ -31,6 +32,11 @@ func (r *tokenRefreshAccountRepo) SetError(ctx 
context.Context, id int64, errorM return nil } +func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempCalls++ + return nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -41,6 +47,23 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account return s.err } +type tempUnschedCacheStub struct { + deleteCalls int +} + +func (s *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { + return nil +} + +func (s *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { + return nil, nil +} + +func (s *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error { + s.deleteCalls++ + return nil +} + type tokenRefresherStub struct { credentials map[string]any err error @@ -70,7 +93,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 5, Platform: PlatformGemini, @@ -98,7 +121,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 6, Platform: PlatformGemini, @@ -124,7 +147,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) account := &Account{ ID: 7, Platform: PlatformGemini, @@ -151,7 +174,7 @@ func 
TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 8, Platform: PlatformAntigravity, @@ -179,7 +202,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 9, Platform: PlatformGemini, @@ -207,7 +230,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 10, Platform: PlatformOpenAI, // OpenAI OAuth 账户 @@ -235,7 +258,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 11, Platform: PlatformGemini, @@ -254,7 +277,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效 } -// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况 +// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试可重试错误耗尽不标记 error func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -264,7 +287,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { RetryBackoffSeconds: 0, }, 
} - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 12, Platform: PlatformGemini, @@ -278,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { require.Error(t, err) require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效 - require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态 + require.Equal(t, 0, repo.setErrorCalls) // 可重试错误耗尽不标记 error,下个周期继续重试 } // TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态 @@ -291,7 +314,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 13, Platform: PlatformAntigravity, @@ -318,7 +341,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 14, Platform: PlatformAntigravity, @@ -335,6 +358,77 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态 } +// TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable 测试刷新成功后清除临时不可调度(DB + Redis) +func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + tempCache := &tempUnschedCacheStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + 
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache) + until := time.Now().Add(10 * time.Minute) + account := &Account{ + ID: 15, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + TempUnschedulableUntil: &until, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.clearTempCalls) // DB 清除 + require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 +} + +// TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms 测试所有平台不可重试错误都 SetError +func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *testing.T) { + tests := []struct { + name string + platform string + }{ + {name: "gemini", platform: PlatformGemini}, + {name: "anthropic", platform: PlatformAnthropic}, + {name: "openai", platform: PlatformOpenAI}, + {name: "antigravity", platform: PlatformAntigravity}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 3, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + account := &Account{ + ID: 16, + Platform: tt.platform, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("invalid_grant: token revoked"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.Error(t, err) + require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError + }) + } +} + // TestIsNonRetryableRefreshError 测试不可重试错误判断 func TestIsNonRetryableRefreshError(t *testing.T) { tests := []struct { diff --git a/backend/internal/service/wire.go 
b/backend/internal/service/wire.go index 68deace9..ac90db27 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -48,8 +48,9 @@ func ProvideTokenRefreshService( cacheInvalidator TokenCacheInvalidator, schedulerCache SchedulerCache, cfg *config.Config, + tempUnschedCache TempUnschedCache, ) *TokenRefreshService { - svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg) + svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) svc.Start() From a9285b8a94d425820e4db975ee9a28ea24735d12 Mon Sep 17 00:00:00 2001 From: QTom Date: Tue, 3 Mar 2026 01:02:39 +0800 Subject: [PATCH 003/286] =?UTF-8?q?feat(gateway):=20=E5=8F=8C=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E7=94=A8=E6=88=B7=E6=B6=88=E6=81=AF=E9=98=9F=E5=88=97?= =?UTF-8?q?=20=E2=80=94=20=E4=B8=B2=E8=A1=8C=E9=98=9F=E5=88=97=20+=20?= =?UTF-8?q?=E8=BD=AF=E6=80=A7=E9=99=90=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 UMQ (User Message Queue) 双模式支持: - serialize: 账号级分布式串行锁 + RPM 自适应延迟(严格限流) - throttle: 仅 RPM 自适应前置延迟,不阻塞并发(软性限速) 后端: - config: 新增 Mode 字段,保留 Enabled 向后兼容 - service: 新增 UserMessageQueueService(Lua 锁/延迟算法/清理 worker) - repository: 新增 UserMsgQueueCache(Redis Lua acquire/release/force-release) - handler: 新增 UserMsgQueueHelper(SSE ping + 等待循环 + throttle) - gateway: 按 mode 分支集成 serialize/throttle 逻辑 - lint: 修复 gofmt rewrite rules、errcheck 类型断言、staticcheck QF1012 前端: - 三态选择器 UI(关闭/软性限速/串行队列)替代 toggle 开关 - BulkEdit 支持 null 语义(不修改) - i18n 中英文文案 通过 6 轮专家评审(42 次 review)、golangci-lint、单元测试、集成测试。 --- backend/cmd/server/wire_gen.go | 4 +- backend/internal/config/config.go | 70 ++++ backend/internal/handler/dto/mappers.go | 
4 + backend/internal/handler/dto/types.go | 7 +- backend/internal/handler/gateway_handler.go | 91 +++++ .../internal/handler/user_msg_queue_helper.go | 237 +++++++++++++ .../ops_repo_latency_histogram_buckets.go | 8 +- .../repository/user_msg_queue_cache.go | 186 ++++++++++ backend/internal/repository/wire.go | 1 + backend/internal/service/account.go | 21 ++ backend/internal/service/gateway_request.go | 4 + backend/internal/service/gateway_service.go | 6 + backend/internal/service/setting_service.go | 12 +- .../service/user_msg_queue_service.go | 318 ++++++++++++++++++ backend/internal/service/wire.go | 10 + .../account/BulkEditAccountModal.vue | 43 ++- .../components/account/CreateAccountModal.vue | 38 +++ .../components/account/EditAccountModal.vue | 39 +++ frontend/src/i18n/locales/en.ts | 7 +- frontend/src/i18n/locales/zh.ts | 7 +- frontend/src/types/index.ts | 1 + 21 files changed, 1099 insertions(+), 15 deletions(-) create mode 100644 backend/internal/handler/user_msg_queue_helper.go create mode 100644 backend/internal/repository/user_msg_queue_cache.go create mode 100644 backend/internal/service/user_msg_queue_service.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 90709f5b..2e9afc26 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -196,7 +196,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) usageRecordWorkerPool := 
service.NewUsageRecordWorkerPool(configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig, settingService) + userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) + userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 763ed829..c1f54ab6 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -30,6 +30,14 @@ const ( // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +// UMQ(用户消息队列)模式常量 +const ( + // 
UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟 + UMQModeSerialize = "serialize" + // UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发 + UMQModeThrottle = "throttle" +) + // 连接池隔离策略常量 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( @@ -455,6 +463,52 @@ type GatewayConfig struct { UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` + + // UserMessageQueue: 用户消息串行队列配置 + // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 + UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` +} + +// UserMessageQueueConfig 用户消息串行队列配置 +// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 +type UserMessageQueueConfig struct { + // Mode: 模式选择 + // "serialize" = 账号级串行锁 + RPM 自适应延迟 + // "throttle" = 仅 RPM 自适应前置延迟,不阻塞并发 + // "" = 禁用(默认) + Mode string `mapstructure:"mode"` + // Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize") + Enabled bool `mapstructure:"enabled"` + // LockTTLMs: 串行锁 TTL(毫秒),应大于最长请求时间 + LockTTLMs int `mapstructure:"lock_ttl_ms"` + // WaitTimeoutMs: 等待获取锁的超时时间(毫秒) + WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` + // MinDelayMs: RPM 自适应延迟下限(毫秒) + MinDelayMs int `mapstructure:"min_delay_ms"` + // MaxDelayMs: RPM 自适应延迟上限(毫秒) + MaxDelayMs int `mapstructure:"max_delay_ms"` + // CleanupIntervalSeconds: 孤儿锁清理间隔(秒),0 表示禁用 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` +} + +// WaitTimeout 返回等待超时的 time.Duration +func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { + if c.WaitTimeoutMs <= 0 { + return 30 * time.Second + } + return time.Duration(c.WaitTimeoutMs) * time.Millisecond +} + +// GetEffectiveMode 返回生效的模式 +// 注意:Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证 +func (c *UserMessageQueueConfig) GetEffectiveMode() string { + if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { + return c.Mode + } + if c.Enabled { + return UMQModeSerialize // 向后兼容 + } + return "" } // GatewayOpenAIWSConfig OpenAI 
Responses WebSocket 配置。 @@ -994,6 +1048,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds } + // Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空 + if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { + slog.Warn("invalid user_message_queue mode, disabling", + "mode", m, + "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) + cfg.Gateway.UserMessageQueue.Mode = "" + } + // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) if cfg.Totp.EncryptionKey == "" { @@ -1372,6 +1434,14 @@ func setDefaults() { viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + // 用户消息串行队列默认值 + viper.SetDefault("gateway.user_message_queue.enabled", false) + viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) + viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) + viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) + viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) + viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f8298067..1c34f537 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -216,6 +216,10 @@ func AccountFromServiceShallow(a *service.Account) *Account { buffer := a.GetRPMStickyBuffer() out.RPMStickyBuffer = &buffer } + // 用户消息队列模式 + if mode := a.GetUserMsgQueueMode(); mode != "" { + out.UserMsgQueueMode = &mode + } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled 
:= true diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index b5c0640f..e9235797 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -155,9 +155,10 @@ type Account struct { // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 - BaseRPM *int `json:"base_rpm,omitempty"` - RPMStrategy *string `json:"rpm_strategy,omitempty"` - RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + BaseRPM *int `json:"base_rpm,omitempty"` + RPMStrategy *string `json:"rpm_strategy,omitempty"` + RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"` // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 2bd59f32..8d39c767 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -45,6 +45,7 @@ type GatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int maxAccountSwitchesGemini int cfg *config.Config @@ -63,6 +64,7 @@ func NewGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, ) *GatewayHandler { @@ -78,6 +80,13 @@ func NewGatewayHandler( maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini } } + + // 初始化用户消息串行队列 helper + var umqHelper *UserMsgQueueHelper + if userMsgQueueService != nil && cfg != nil { + umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval) + } + return &GatewayHandler{ 
gatewayService: gatewayService, geminiCompatService: geminiCompatService, @@ -89,6 +98,7 @@ func NewGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), + userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, cfg: cfg, @@ -566,6 +576,58 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + // ===== 用户消息串行队列 START ===== + var queueRelease func() + umqMode := h.getUserMsgQueueMode(account, parsedReq) + + switch umqMode { + case config.UMQModeSerialize: + // 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变) + baseRPM := account.GetBaseRPM() + release, qErr := h.userMsgQueueHelper.AcquireWithWait( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ) + if qErr != nil { + // fail-open: 记录 warn,不阻止请求 + reqLog.Warn("gateway.umq_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Error(qErr), + ) + } else { + queueRelease = release + } + + case config.UMQModeThrottle: + // 软性限速:仅施加 RPM 自适应延迟,不阻塞并发 + baseRPM := account.GetBaseRPM() + if tErr := h.userMsgQueueHelper.ThrottleWithPing( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ); tErr != nil { + reqLog.Warn("gateway.umq_throttle_failed", + zap.Int64("account_id", account.ID), + zap.Error(tErr), + ) + } + + default: + if umqMode != "" { + reqLog.Warn("gateway.umq_unknown_mode", + zap.String("mode", umqMode), + zap.Int64("account_id", account.ID), + ) + } + } + + // 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease) + queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease) + // 注入回调到 ParsedRequest:使用外层 wrapper 以便提前清理 AfterFunc + 
parsedReq.OnUpstreamAccepted = queueRelease + // ===== 用户消息串行队列 END ===== + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -577,6 +639,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } + + // 兜底释放串行锁(正常情况已通过回调提前释放) + if queueRelease != nil { + queueRelease() + } + // 清理回调引用,防止 failover 重试时旧回调被错误调用 + parsedReq.OnUpstreamAccepted = nil + if accountReleaseFunc != nil { accountReleaseFunc() } @@ -1431,3 +1501,24 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { }() task(ctx) } + +// getUserMsgQueueMode 获取当前请求的 UMQ 模式 +// 返回 "serialize" | "throttle" | "" +func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string { + if h.userMsgQueueHelper == nil { + return "" + } + // 仅适用于 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return "" + } + if !service.IsRealUserMessage(parsed) { + return "" + } + // 账号级模式优先,fallback 到全局配置 + mode := account.GetUserMsgQueueMode() + if mode == "" { + mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode() + } + return mode +} diff --git a/backend/internal/handler/user_msg_queue_helper.go b/backend/internal/handler/user_msg_queue_helper.go new file mode 100644 index 00000000..50449b13 --- /dev/null +++ b/backend/internal/handler/user_msg_queue_helper.go @@ -0,0 +1,237 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助 +// 复用 ConcurrencyHelper 的退避 + SSE ping 模式 +type UserMsgQueueHelper struct { + queueService *service.UserMessageQueueService + pingFormat SSEPingFormat + pingInterval time.Duration +} + +// NewUserMsgQueueHelper 创建用户消息串行队列辅助 +func NewUserMsgQueueHelper( + queueService *service.UserMessageQueueService, + 
pingFormat SSEPingFormat, + pingInterval time.Duration, +) *UserMsgQueueHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } + return &UserMsgQueueHelper{ + queueService: queueService, + pingFormat: pingFormat, + pingInterval: pingInterval, + } +} + +// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping +// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放 +func (h *UserMsgQueueHelper) AcquireWithWait( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) (releaseFunc func(), err error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 先尝试立即获取 + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err // fail-open 已在 service 层处理 + } + + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil { + if ctx.Err() != nil { + // 延迟期间 context 取消,释放锁 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + + // 需要等待:指数退避轮询 + return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog) +} + +// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping +func (h *UserMsgQueueHelper) waitForLockWithPing( + c *gin.Context, + ctx context.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), error) { + needPing := isStream && h.pingFormat != "" + + var flusher http.Flusher + if needPing { + var ok bool + flusher, ok = c.Writer.(http.Flusher) + if !ok { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer 
pingTicker.Stop() + pingCh = pingTicker.C + } + + backoff := initialBackoff + timer := time.NewTimer(backoff) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("umq wait timeout for account %d", accountID) + + case <-pingCh: + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return nil, err + } + flusher.Flush() + + case <-timer.C: + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err + } + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil { + if ctx.Err() != nil { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + backoff = nextBackoff(backoff) + timer.Reset(backoff) + } + } +} + +// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次) +func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() { + var once sync.Once + return func() { + once.Do(func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer bgCancel() + if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil { + reqLog.Warn("gateway.umq_release_failed", + zap.Int64("account_id", accountID), + zap.Error(err), + ) + } else { + reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID)) + } + }) + } +} + +// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping +// 不获取串行锁,不阻塞并发。返回后即可转发请求。 +func (h *UserMsgQueueHelper) 
ThrottleWithPing( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) error { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + reqLog.Debug("gateway.umq_throttle_delay", + zap.Int64("account_id", accountID), + zap.Duration("delay", delay), + ) + + // 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑) + needPing := isStream && h.pingFormat != "" + var flusher http.Flusher + if needPing { + flusher, _ = c.Writer.(http.Flusher) + if flusher == nil { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-pingCh: + // SSE ping 逻辑(与 waitForLockWithPing 一致) + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return err + } + flusher.Flush() + case <-timer.C: + return nil + } + } +} diff --git a/backend/internal/repository/ops_repo_latency_histogram_buckets.go b/backend/internal/repository/ops_repo_latency_histogram_buckets.go index cd5bed37..e56903f1 100644 --- a/backend/internal/repository/ops_repo_latency_histogram_buckets.go +++ b/backend/internal/repository/ops_repo_latency_histogram_buckets.go @@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, 
b.upperMs, b.label) } // Default bucket. last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1] - _, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label)) + fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label) _, _ = sb.WriteString("END") return sb.String() } @@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order) order++ } - _, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order)) + fmt.Fprintf(&sb, "\tELSE %d\n", order) _, _ = sb.WriteString("END") return sb.String() } diff --git a/backend/internal/repository/user_msg_queue_cache.go b/backend/internal/repository/user_msg_queue_cache.go new file mode 100644 index 00000000..bb3ee698 --- /dev/null +++ b/backend/internal/repository/user_msg_queue_cache.go @@ -0,0 +1,186 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// Redis Key 模式(使用 hash tag 确保 Redis Cluster 下同一 accountID 的 key 落入同一 slot) +// 格式: umq:{accountID}:lock / umq:{accountID}:last +const ( + umqKeyPrefix = "umq:" + umqLockSuffix = ":lock" // STRING (requestID), PX lockTtlMs + umqLastSuffix = ":last" // STRING (毫秒时间戳), EX 60s +) + +// Lua 脚本:原子获取串行锁(SET NX PX + 重入安全) +var acquireLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2])) + return 1 +end +if cur ~= false then return 0 end +redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2])) +return 1 +`) + +// Lua 脚本:原子释放锁 + 记录完成时间(使用 Redis TIME 避免时钟偏差) +var releaseLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('DEL', KEYS[1]) + local t = redis.call('TIME') + local ms = 
tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000) + redis.call('SET', KEYS[2], ms, 'EX', 60) + return 1 +end +return 0 +`) + +// Lua 脚本:原子清理孤儿锁(仅在 PTTL == -1 时删除,避免 TOCTOU 竞态误删合法锁) +var forceReleaseLockScript = redis.NewScript(` +local pttl = redis.call('PTTL', KEYS[1]) +if pttl == -1 then + redis.call('DEL', KEYS[1]) + return 1 +end +return 0 +`) + +type userMsgQueueCache struct { + rdb *redis.Client +} + +// NewUserMsgQueueCache 创建用户消息队列缓存 +func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache { + return &userMsgQueueCache{rdb: rdb} +} + +func umqLockKey(accountID int64) string { + // 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效 + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix +} + +func umqLastKey(accountID int64) string { + // 格式: umq:{123}:last — 与 lockKey 同一 hash slot + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix +} + +// umqScanPattern 用于 SCAN 扫描锁 key +func umqScanPattern() string { + return umqKeyPrefix + "{*}" + umqLockSuffix +} + +// AcquireLock 尝试获取账号级串行锁 +func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) { + key := umqLockKey(accountID) + result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int() + if err != nil { + return false, fmt.Errorf("umq acquire lock: %w", err) + } + return result == 1, nil +} + +// ReleaseLock 释放锁并记录完成时间 +func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) { + lockKey := umqLockKey(accountID) + lastKey := umqLastKey(accountID) + result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int() + if err != nil { + return false, fmt.Errorf("umq release lock: %w", err) + } + return result == 1, nil +} + +// GetLastCompletedMs 获取上次完成时间(毫秒时间戳) +func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, 
error) { + key := umqLastKey(accountID) + val, err := c.rdb.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("umq get last completed: %w", err) + } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return 0, fmt.Errorf("umq parse last completed: %w", err) + } + return ms, nil +} + +// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁) +func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error { + key := umqLockKey(accountID) + _, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("umq force release lock: %w", err) + } + return nil +} + +// ScanLockKeys 扫描所有锁 key,仅返回 PTTL == -1(无过期时间)的孤儿锁 accountID 列表 +// 正常的锁都有 PX 过期时间,PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL) +func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) { + var accountIDs []int64 + var cursor uint64 + pattern := umqScanPattern() + + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return nil, fmt.Errorf("umq scan lock keys: %w", err) + } + for _, key := range keys { + // 检查 PTTL:只清理 PTTL == -1(无过期时间)的异常锁 + pttl, err := c.rdb.PTTL(ctx, key).Result() + if err != nil { + continue + } + // PTTL 返回值:-2 = key 不存在,-1 = 无过期时间,>0 = 剩余毫秒 + // go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2 + // 只删除 -1(无过期时间的异常锁),跳过正常持有的锁 + if pttl != time.Duration(-1) { + continue + } + + // 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字 + openBrace := strings.IndexByte(key, '{') + closeBrace := strings.IndexByte(key, '}') + if openBrace < 0 || closeBrace <= openBrace+1 { + continue + } + idStr := key[openBrace+1 : closeBrace] + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + if len(accountIDs) >= maxCount { + return accountIDs, nil + } + } + cursor = nextCursor + 
if cursor == 0 { + break + } + } + return accountIDs, nil +} + +// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致 +func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("umq get redis time: %w", err) + } + return t.UnixMilli(), nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index ee796d98..2e35e0a0 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -80,6 +80,7 @@ var ProviderSet = wire.NewSet( ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, + NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, NewIdentityCache, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index c76c817e..81e91aeb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -1032,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool { return false } +// GetUserMsgQueueMode 获取用户消息队列模式 +// "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置) +func (a *Account) GetUserMsgQueueMode() string { + if a.Extra == nil { + return "" + } + // 优先读取新字段 user_msg_queue_mode(白名单校验,非法值视为未设置) + if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" { + if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle { + return mode + } + return "" // 非法值 fallback 到全局配置 + } + // 向后兼容: user_msg_queue_enabled: true → "serialize" + if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled { + return config.UMQModeSerialize + } + return "" +} + // IsSessionIDMaskingEnabled 检查是否启用会话ID伪装 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID, diff --git a/backend/internal/service/gateway_request.go 
b/backend/internal/service/gateway_request.go index f8096a0e..b546fe85 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -61,6 +61,10 @@ type ParsedRequest struct { ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) MaxTokens int // max_tokens 值(用于探测请求拦截) SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) + + // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) + // 流式请求在收到 2xx 响应头后调用,避免持锁等流完成 + OnUpstreamAccepted func() } // ParseGatewayRequest 解析网关请求体并返回结构化结果。 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3323f868..48c69881 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4305,6 +4305,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + + // 触发上游接受回调(提前释放串行锁,不等流完成) + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 64871b9a..03556627 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -994,7 +994,7 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { } } // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 - result, _, _ := minVersionSF.Do("min_version", func() (any, error) { + result, err, _ := minVersionSF.Do("min_version", func() (any, error) { // 二次检查,避免排队的 goroutine 重复查询 if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { if time.Now().UnixNano() < cached.expiresAt { @@ -1020,10 +1020,14 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { }) return value, nil }) - if s, ok := result.(string); ok { - return s + if err != nil { + return "" } - return "" + ver, ok := result.(string) + if !ok { + return "" + 
} + return ver } // SetStreamTimeoutSettings 设置流超时处理配置 diff --git a/backend/internal/service/user_msg_queue_service.go b/backend/internal/service/user_msg_queue_service.go new file mode 100644 index 00000000..a0ce95a8 --- /dev/null +++ b/backend/internal/service/user_msg_queue_service.go @@ -0,0 +1,318 @@ +package service + +import ( + "context" + cryptorand "crypto/rand" + "encoding/hex" + "fmt" + "math" + "math/rand/v2" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// UserMsgQueueCache 用户消息串行队列 Redis 缓存接口 +type UserMsgQueueCache interface { + // AcquireLock 尝试获取账号级串行锁 + AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (acquired bool, err error) + // ReleaseLock 释放锁并记录完成时间 + ReleaseLock(ctx context.Context, accountID int64, requestID string) (released bool, err error) + // GetLastCompletedMs 获取上次完成时间(毫秒时间戳,Redis TIME 源) + GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) + // GetCurrentTimeMs 获取 Redis 服务器当前时间(毫秒),与 ReleaseLock 记录的时间源一致 + GetCurrentTimeMs(ctx context.Context) (int64, error) + // ForceReleaseLock 强制释放锁(孤儿锁清理) + ForceReleaseLock(ctx context.Context, accountID int64) error + // ScanLockKeys 扫描 PTTL == -1 的孤儿锁 key,返回 accountID 列表 + ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) +} + +// QueueLockResult 锁获取结果 +type QueueLockResult struct { + Acquired bool + RequestID string +} + +// UserMessageQueueService 用户消息串行队列服务 +// 对真实用户消息实施账号级串行化 + RPM 自适应延迟 +type UserMessageQueueService struct { + cache UserMsgQueueCache + rpmCache RPMCache + cfg *config.UserMessageQueueConfig + stopCh chan struct{} // graceful shutdown + stopOnce sync.Once // 确保 Stop() 并发安全 +} + +// NewUserMessageQueueService 创建用户消息串行队列服务 +func NewUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.UserMessageQueueConfig) *UserMessageQueueService { + return &UserMessageQueueService{ + cache: cache, + rpmCache: 
rpmCache, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +// IsRealUserMessage 检测是否为真实用户消息(非 tool_result) +// 与 claude-relay-service 的检测逻辑一致: +// 1. messages 非空 +// 2. 最后一条消息 role == "user" +// 3. 最后一条消息 content(如果是数组)中不含 type:"tool_result" / "tool_use_result" +func IsRealUserMessage(parsed *ParsedRequest) bool { + if parsed == nil || len(parsed.Messages) == 0 { + return false + } + + lastMsg := parsed.Messages[len(parsed.Messages)-1] + msgMap, ok := lastMsg.(map[string]any) + if !ok { + return false + } + + role, _ := msgMap["role"].(string) + if role != "user" { + return false + } + + // 检查 content 是否包含 tool_result 类型 + content, ok := msgMap["content"] + if !ok { + return true // 没有 content 字段,视为普通用户消息 + } + + contentArr, ok := content.([]any) + if !ok { + return true // content 不是数组(可能是 string),视为普通用户消息 + } + + for _, item := range contentArr { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "tool_result" || itemType == "tool_use_result" { + return false + } + } + return true +} + +// TryAcquire 尝试立即获取串行锁 +func (s *UserMessageQueueService) TryAcquire(ctx context.Context, accountID int64) (*QueueLockResult, error) { + if s.cache == nil { + return &QueueLockResult{Acquired: true}, nil // fail-open + } + + requestID := generateUMQRequestID() + lockTTL := s.cfg.LockTTLMs + if lockTTL <= 0 { + lockTTL = 120000 + } + + acquired, err := s.cache.AcquireLock(ctx, accountID, requestID, lockTTL) + if err != nil { + logger.LegacyPrintf("service.umq", "AcquireLock failed for account %d: %v", accountID, err) + return &QueueLockResult{Acquired: true}, nil // fail-open + } + + return &QueueLockResult{ + Acquired: acquired, + RequestID: requestID, + }, nil +} + +// Release 释放串行锁 +func (s *UserMessageQueueService) Release(ctx context.Context, accountID int64, requestID string) error { + if s.cache == nil || requestID == "" { + return nil + } + released, err := s.cache.ReleaseLock(ctx, accountID, 
requestID) + if err != nil { + logger.LegacyPrintf("service.umq", "ReleaseLock failed for account %d: %v", accountID, err) + return err + } + if !released { + logger.LegacyPrintf("service.umq", "ReleaseLock no-op for account %d (requestID mismatch or expired)", accountID) + } + return nil +} + +// EnforceDelay 根据 RPM 负载执行自适应延迟 +// 使用 Redis TIME 确保与 releaseLockScript 记录的时间源一致 +func (s *UserMessageQueueService) EnforceDelay(ctx context.Context, accountID int64, baseRPM int) error { + if s.cache == nil { + return nil + } + + // 先检查历史记录:没有历史则无需延迟,避免不必要的 RPM 查询 + lastMs, err := s.cache.GetLastCompletedMs(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.umq", "GetLastCompletedMs failed for account %d: %v", accountID, err) + return nil // fail-open + } + if lastMs == 0 { + return nil // 没有历史记录,无需延迟 + } + + delay := s.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + // 获取 Redis 当前时间(与 lastMs 同源,避免时钟偏差) + nowMs, err := s.cache.GetCurrentTimeMs(ctx) + if err != nil { + logger.LegacyPrintf("service.umq", "GetCurrentTimeMs failed: %v", err) + return nil // fail-open + } + + elapsed := time.Duration(nowMs-lastMs) * time.Millisecond + if elapsed < 0 { + // 时钟异常(Redis 故障转移等),fail-open + return nil + } + remaining := delay - elapsed + if remaining <= 0 { + return nil + } + + // 执行延迟 + timer := time.NewTimer(remaining) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +// CalculateRPMAwareDelay 根据当前 RPM 负载计算自适应延迟 +// ratio = currentRPM / baseRPM +// ratio < 0.5 → MinDelay +// 0.5 ≤ ratio < 0.8 → 线性插值 MinDelay..MaxDelay +// ratio ≥ 0.8 → MaxDelay +// 返回值包含 ±15% 随机抖动(anti-detection + 避免惊群效应) +func (s *UserMessageQueueService) CalculateRPMAwareDelay(ctx context.Context, accountID int64, baseRPM int) time.Duration { + minDelay := time.Duration(s.cfg.MinDelayMs) * time.Millisecond + maxDelay := time.Duration(s.cfg.MaxDelayMs) * time.Millisecond + + if minDelay <= 0 { + 
minDelay = 200 * time.Millisecond + } + if maxDelay <= 0 { + maxDelay = 2000 * time.Millisecond + } + // 防止配置错误:minDelay > maxDelay 时交换 + if minDelay > maxDelay { + minDelay, maxDelay = maxDelay, minDelay + } + + var baseDelay time.Duration + + if baseRPM <= 0 || s.rpmCache == nil { + baseDelay = minDelay + } else { + currentRPM, err := s.rpmCache.GetRPM(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.umq", "GetRPM failed for account %d: %v", accountID, err) + baseDelay = minDelay // fail-open + } else { + ratio := float64(currentRPM) / float64(baseRPM) + if ratio < 0.5 { + baseDelay = minDelay + } else if ratio >= 0.8 { + baseDelay = maxDelay + } else { + // 线性插值: 0.5 → minDelay, 0.8 → maxDelay + t := (ratio - 0.5) / 0.3 + interpolated := float64(minDelay) + t*(float64(maxDelay)-float64(minDelay)) + baseDelay = time.Duration(math.Round(interpolated)) + } + } + } + + // ±15% 随机抖动 + return applyJitter(baseDelay, 0.15) +} + +// StartCleanupWorker 启动孤儿锁清理 worker +// 定期 SCAN umq:*:lock 并清理 PTTL == -1 的异常锁(PTTL 检查在 cache.ScanLockKeys 内完成) +func (s *UserMessageQueueService) StartCleanupWorker(interval time.Duration) { + if s == nil || s.cache == nil || interval <= 0 { + return + } + + runCleanup := func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + accountIDs, err := s.cache.ScanLockKeys(ctx, 1000) + if err != nil { + logger.LegacyPrintf("service.umq", "Cleanup scan failed: %v", err) + return + } + + cleaned := 0 + for _, accountID := range accountIDs { + cleanCtx, cleanCancel := context.WithTimeout(context.Background(), 2*time.Second) + if err := s.cache.ForceReleaseLock(cleanCtx, accountID); err != nil { + logger.LegacyPrintf("service.umq", "Cleanup force release failed for account %d: %v", accountID, err) + } else { + cleaned++ + } + cleanCancel() + } + + if cleaned > 0 { + logger.LegacyPrintf("service.umq", "Cleanup completed: released %d orphaned locks", cleaned) + } + } + + go func() { + ticker 
:= time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + runCleanup() + } + } + }() +} + +// Stop 停止后台 cleanup worker +func (s *UserMessageQueueService) Stop() { + if s != nil && s.stopCh != nil { + s.stopOnce.Do(func() { + close(s.stopCh) + }) + } +} + +// applyJitter 对延迟值施加 ±jitterPct 的随机抖动 +// 使用 math/rand/v2(Go 1.22+ 自动使用 crypto/rand 种子),与 nextBackoff 一致 +// 例如 applyJitter(200ms, 0.15) 返回 170ms ~ 230ms +func applyJitter(d time.Duration, jitterPct float64) time.Duration { + if d <= 0 || jitterPct <= 0 { + return d + } + // [-jitterPct, +jitterPct] + jitter := (rand.Float64()*2 - 1) * jitterPct + return time.Duration(float64(d) * (1 + jitter)) +} + +// generateUMQRequestID 生成唯一请求 ID(与 generateRequestID 一致的 fallback 模式) +func generateUMQRequestID() string { + b := make([]byte, 16) + if _, err := cryptorand.Read(b); err != nil { + return fmt.Sprintf("%x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b0eccb71..c7185190 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -110,6 +110,15 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi return svc } +// ProvideUserMessageQueueService 创建用户消息串行队列服务并启动清理 worker +func ProvideUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.Config) *UserMessageQueueService { + svc := NewUserMessageQueueService(cache, rpmCache, &cfg.Gateway.UserMessageQueue) + if cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds > 0 { + svc.StartCleanupWorker(time.Duration(cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds) * time.Second) + } + return svc +} + // ProvideSchedulerSnapshotService creates and starts SchedulerSnapshotService. 
func ProvideSchedulerSnapshotService( cache SchedulerCache, @@ -348,6 +357,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionService, wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)), ProvideConcurrencyService, + ProvideUserMessageQueueService, NewUsageRecordWorkerPool, ProvideSchedulerSnapshotService, NewIdentityService, diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 30c3d739..1c83e658 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -686,6 +686,27 @@ />

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

+ + + + + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+
@@ -876,6 +897,12 @@ const rpmLimitEnabled = ref(false) const bulkBaseRpm = ref(null) const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const bulkRpmStickyBuffer = ref(null) +const userMsgQueueMode = ref(null) +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) // All models list (combined Anthropic + OpenAI + Gemini) const allModels = [ @@ -1249,6 +1276,14 @@ const buildUpdatePayload = (): Record | null => { updates.extra = extra } + // UMQ mode(独立于 RPM 保存) + if (userMsgQueueMode.value !== null) { + if (!updates.extra) updates.extra = {} + const umqExtra = updates.extra as Record + umqExtra.user_msg_queue_mode = userMsgQueueMode.value // '' = 清除账号级覆盖 + umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge) + } + return Object.keys(updates).length > 0 ? 
updates : null } @@ -1309,7 +1344,8 @@ const handleSubmit = async () => { enableRateMultiplier.value || enableStatus.value || enableGroups.value || - enableRpmLimit.value + enableRpmLimit.value || + userMsgQueueMode.value !== null if (!hasAnyFieldEnabled) { appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected')) @@ -1414,6 +1450,11 @@ watch( rateMultiplier.value = 1 status.value = 'active' groupIds.value = [] + rpmLimitEnabled.value = false + bulkBaseRpm.value = null + bulkRpmStrategy.value = 'tiered' + bulkRpmStickyBuffer.value = null + userMsgQueueMode.value = null // Reset mixed channel warning state showMixedChannelWarning.value = false diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 97a6fbce..75f04081 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1625,6 +1625,27 @@ />

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

+ + + + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+ +
@@ -2489,6 +2510,12 @@ const rpmLimitEnabled = ref(false) const baseRpm = ref(null) const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const rpmStickyBuffer = ref(null) +const userMsgQueueMode = ref('') +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -3117,6 +3144,7 @@ const resetForm = () => { baseRpm.value = null rpmStrategy.value = 'tiered' rpmStickyBuffer.value = null + userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false @@ -4035,6 +4063,11 @@ const handleAnthropicExchange = async (authCode: string) => { } } + // UMQ mode(独立于 RPM) + if (userMsgQueueMode.value) { + extra.user_msg_queue_mode = userMsgQueueMode.value + } + // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true @@ -4142,6 +4175,11 @@ const handleCookieAuth = async (sessionKey: string) => { } } + // UMQ mode(独立于 RPM) + if (userMsgQueueMode.value) { + extra.user_msg_queue_mode = userMsgQueueMode.value + } + // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 184eff98..24166a5c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1035,6 +1035,27 @@ />

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

+ + + + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+ +
@@ -1347,6 +1368,12 @@ const rpmLimitEnabled = ref(false) const baseRpm = ref(null) const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const rpmStickyBuffer = ref(null) +const userMsgQueueMode = ref('') +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -1810,6 +1837,7 @@ function loadQuotaControlSettings(account: Account) { baseRpm.value = null rpmStrategy.value = 'tiered' rpmStickyBuffer.value = null + userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false @@ -1841,6 +1869,9 @@ function loadQuotaControlSettings(account: Account) { rpmStickyBuffer.value = account.rpm_sticky_buffer ?? null } + // UMQ mode(独立于 RPM 加载,防止编辑无 RPM 账号时丢失已有配置) + userMsgQueueMode.value = account.user_msg_queue_mode ?? 
'' + // Load TLS fingerprint setting if (account.enable_tls_fingerprint === true) { tlsFingerprintEnabled.value = true @@ -2166,6 +2197,14 @@ const handleSubmit = async () => { delete newExtra.rpm_sticky_buffer } + // UMQ mode(独立于 RPM 保存) + if (userMsgQueueMode.value) { + newExtra.user_msg_queue_mode = userMsgQueueMode.value + } else { + delete newExtra.user_msg_queue_mode + } + delete newExtra.user_msg_queue_enabled // 清理旧字段 + // TLS fingerprint setting if (tlsFingerprintEnabled.value) { newExtra.enable_tls_fingerprint = true diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 01b7919a..fe51c331 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1864,7 +1864,12 @@ export default { strategyHint: 'Tiered: gradually restrict when exceeded; Sticky Exempt: existing sessions unrestricted', stickyBuffer: 'Sticky Buffer', stickyBufferPlaceholder: 'Default: 20% of base RPM', - stickyBufferHint: 'Extra requests allowed for sticky sessions after exceeding base RPM. Leave empty to use default (20% of base RPM, min 1)' + stickyBufferHint: 'Extra requests allowed for sticky sessions after exceeding base RPM. 
Leave empty to use default (20% of base RPM, min 1)', + userMsgQueue: 'User Message Rate Control', + userMsgQueueHint: 'Rate-limit user messages to avoid triggering upstream RPM limits', + umqModeOff: 'Off', + umqModeThrottle: 'Throttle', + umqModeSerialize: 'Serialize', }, tlsFingerprint: { label: 'TLS Fingerprint Simulation', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 3411d310..156349af 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2007,7 +2007,12 @@ export default { strategyHint: '三区模型: 超限后逐步限制; 粘性豁免: 已有会话不受限', stickyBuffer: '粘性缓冲区', stickyBufferPlaceholder: '默认: base RPM 的 20%', - stickyBufferHint: '超过 base RPM 后,粘性会话额外允许的请求数。为空则使用默认值(base RPM 的 20%,最小为 1)' + stickyBufferHint: '超过 base RPM 后,粘性会话额外允许的请求数。为空则使用默认值(base RPM 的 20%,最小为 1)', + userMsgQueue: '用户消息限速', + userMsgQueueHint: '对用户消息施加发送限制,避免触发上游 RPM 限制', + umqModeOff: '关闭', + umqModeThrottle: '软性限速', + umqModeSerialize: '串行队列', }, tlsFingerprint: { label: 'TLS 指纹模拟', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index ccdde8ae..f8c73bbd 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -665,6 +665,7 @@ export interface Account { base_rpm?: number | null rpm_strategy?: string | null rpm_sticky_buffer?: number | null + user_msg_queue_mode?: string | null // "serialize" | "throttle" | null // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) enable_tls_fingerprint?: boolean | null From 067810fa9888e2fa226718f54e86cdf712d04267 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 2 Mar 2026 19:37:40 +0800 Subject: [PATCH 004/286] feat: custom menu pages with iframe embedding and CSP injection Add configurable custom menu items that appear in sidebar, each rendering an iframe-embedded external page. Includes shared URL builder with src_host/src_url tracking, CSP frame-src multi-origin deduplication, admin settings UI, and i18n support. 
chore: bump version to 0.1.87.19 Co-Authored-By: Claude Opus 4.6 --- .../internal/handler/admin/setting_handler.go | 62 +++++++ backend/internal/handler/dto/settings.go | 12 ++ backend/internal/handler/setting_handler.go | 17 ++ .../server/middleware/security_headers.go | 19 +- backend/internal/server/router.go | 91 +++++++++- backend/internal/service/domain_constants.go | 5 +- backend/internal/service/setting_service.go | 5 + backend/internal/service/settings_view.go | 2 + frontend/src/api/admin/settings.ts | 3 + frontend/src/components/layout/AppSidebar.vue | 74 +++++++- frontend/src/i18n/locales/en.ts | 29 +++ frontend/src/i18n/locales/zh.ts | 29 +++ frontend/src/router/index.ts | 11 ++ frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 10 ++ frontend/src/utils/embedded-url.ts | 46 +++++ frontend/src/views/admin/SettingsView.vue | 164 +++++++++++++++++ frontend/src/views/user/CustomPageView.vue | 166 ++++++++++++++++++ .../views/user/PurchaseSubscriptionView.vue | 37 +--- tmp_api_admin_orders/[id]/cancel/route.ts | 25 +++ tmp_api_admin_orders/[id]/retry/route.ts | 25 +++ tmp_api_admin_orders/[id]/route.ts | 31 ++++ tmp_api_admin_orders/route.ts | 60 +++++++ tmp_api_orders/[id]/cancel/route.ts | 37 ++++ tmp_api_orders/[id]/route.ts | 50 ++++++ tmp_api_orders/my/route.ts | 46 +++++ tmp_api_orders/route.ts | 68 +++++++ 27 files changed, 1071 insertions(+), 54 deletions(-) create mode 100644 frontend/src/utils/embedded-url.ts create mode 100644 frontend/src/views/user/CustomPageView.vue create mode 100644 tmp_api_admin_orders/[id]/cancel/route.ts create mode 100644 tmp_api_admin_orders/[id]/retry/route.ts create mode 100644 tmp_api_admin_orders/[id]/route.ts create mode 100644 tmp_api_admin_orders/route.ts create mode 100644 tmp_api_orders/[id]/cancel/route.ts create mode 100644 tmp_api_orders/[id]/route.ts create mode 100644 tmp_api_orders/my/route.ts create mode 100644 tmp_api_orders/route.ts diff --git 
a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e7da042c..eec403dc 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,6 +1,9 @@ package admin import ( + "crypto/rand" + "encoding/hex" + "encoding/json" "fmt" "log" "net/http" @@ -20,6 +23,27 @@ import ( // semverPattern 预编译 semver 格式校验正则 var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) +// generateMenuItemID generates a short random hex ID for a custom menu item. +func generateMenuItemID() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + +// parseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. +// Returns empty slice on empty/invalid input. +func parseCustomMenuItems(raw string) []dto.CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []dto.CustomMenuItem{} + } + var items []dto.CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []dto.CustomMenuItem{} + } + return items +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -92,6 +116,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: parseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -152,6 +177,7 @@ type UpdateSettingsRequest struct { PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` // 默认配置 DefaultConcurrency 
int `json:"default_concurrency"` @@ -299,6 +325,40 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // 自定义菜单项验证 + customMenuJSON := previousSettings.CustomMenuItems + if req.CustomMenuItems != nil { + items := *req.CustomMenuItems + for i, item := range items { + if strings.TrimSpace(item.Label) == "" { + response.BadRequest(c, "Custom menu item label is required") + return + } + if strings.TrimSpace(item.URL) == "" { + response.BadRequest(c, "Custom menu item URL is required") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") + return + } + if item.Visibility != "user" && item.Visibility != "admin" { + response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") + return + } + // Auto-generate ID if missing + if strings.TrimSpace(item.ID) == "" { + items[i].ID = generateMenuItemID() + } + } + menuBytes, err := json.Marshal(items) + if err != nil { + response.BadRequest(c, "Failed to serialize custom menu items") + return + } + customMenuJSON = string(menuBytes) + } + // Ops metrics collector interval validation (seconds). 
if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -358,6 +418,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, SoraClientEnabled: req.SoraClientEnabled, + CustomMenuItems: customMenuJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -449,6 +510,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, + CustomMenuItems: parseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index e9086010..a7d5da22 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -1,5 +1,15 @@ package dto +// CustomMenuItem represents a user-configured custom menu entry. +type CustomMenuItem struct { + ID string `json:"id"` + Label string `json:"label"` + IconSVG string `json:"icon_svg"` + URL string `json:"url"` + Visibility string `json:"visibility"` // "user" or "admin" + SortOrder int `json:"sort_order"` +} + // SystemSettings represents the admin settings API response payload. 
type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` @@ -38,6 +48,7 @@ type SystemSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -87,6 +98,7 @@ type PublicSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"` Version string `json:"version"` diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2141a9ee..1b8e33a8 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -1,6 +1,9 @@ package handler import ( + "encoding/json" + "strings" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -50,8 +53,22 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + CustomMenuItems: parsePublicCustomMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, }) } + +// parsePublicCustomMenuItems parses a JSON string into a slice of CustomMenuItem. 
+func parsePublicCustomMenuItems(raw string) []dto.CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []dto.CustomMenuItem{} + } + var items []dto.CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []dto.CustomMenuItem{} + } + return items +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index f061db90..d9ec951e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string { } // SecurityHeaders sets baseline security headers for all responses. -func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { +// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src; +// pass nil to disable dynamic frame-src injection. +func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc { policy := strings.TrimSpace(cfg.Policy) if policy == "" { policy = config.DefaultCSPPolicy @@ -51,6 +53,15 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { policy = enhanceCSPPolicy(policy) return func(c *gin.Context) { + finalPolicy := policy + if getFrameSrcOrigins != nil { + for _, origin := range getFrameSrcOrigins() { + if origin != "" { + finalPolicy = addToDirective(finalPolicy, "frame-src", origin) + } + } + } + c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") @@ -65,12 +76,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { if err != nil { // crypto/rand 失败时降级为无 nonce 的 CSP 策略 log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", 
strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'")) } else { c.Set(CSPNonceKey, nonce) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'")) } } c.Next() diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 07b51f23..c44a4608 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,7 +1,13 @@ package server import ( + "context" + "encoding/json" "log" + "net/url" + "strings" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -14,6 +20,25 @@ import ( "github.com/redis/go-redis/v9" ) +// extractOrigin returns the scheme+host origin from rawURL, or "" on error. +// Only http and https schemes are accepted; other values (e.g. "//host/path") return "". +func extractOrigin(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + return u.Scheme + "://" + u.Host +} + +const paymentOriginFetchTimeout = 5 * time.Second + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -28,11 +53,65 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { + // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src + // 包含 purchase_subscription_url 和所有 custom_menu_items 的 origin(去重) + var cachedFrameOrigins atomic.Pointer[[]string] + emptyOrigins := []string{} + cachedFrameOrigins.Store(&emptyOrigins) + + refreshFrameOrigins := func() { + ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout) + defer cancel() + settings, err := settingService.GetPublicSettings(ctx) + if err != nil { + // 获取失败时保留已有缓存,避免 frame-src 
被意外清空 + return + } + + seen := make(map[string]struct{}) + var origins []string + + // purchase subscription URL + if settings.PurchaseSubscriptionEnabled { + if origin := extractOrigin(settings.PurchaseSubscriptionURL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + + // custom menu items + if raw := strings.TrimSpace(settings.CustomMenuItems); raw != "" && raw != "[]" { + var items []struct { + URL string `json:"url"` + } + if err := json.Unmarshal([]byte(raw), &items); err == nil { + for _, item := range items { + if origin := extractOrigin(item.URL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + } + } + + cachedFrameOrigins.Store(&origins) + } + refreshFrameOrigins() // 启动时初始化 + // 应用中间件 r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) - r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string { + if p := cachedFrameOrigins.Load(); p != nil { + return *p + } + return nil + })) // Serve embedded frontend with settings injection if available if web.HasEmbeddedFrontend() { @@ -40,11 +119,17 @@ func SetupRouter( if err != nil { log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err) r.Use(web.ServeEmbeddedFrontend()) + settingService.SetOnUpdateCallback(refreshFrameOrigins) } else { - // Register cache invalidation callback - settingService.SetOnUpdateCallback(frontendServer.InvalidateCache) + // Register combined callback: invalidate HTML cache + refresh frame origins + settingService.SetOnUpdateCallback(func() { + frontendServer.InvalidateCache() + refreshFrameOrigins() + }) r.Use(frontendServer.Middleware()) } + } else { + settingService.SetOnUpdateCallback(refreshFrameOrigins) } // 注册路由 diff --git 
a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index b304bc9f..cf61e3d1 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -113,8 +113,9 @@ const ( SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 - SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口 - SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) + SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口 + SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) + SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 64871b9a..04f49273 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -124,6 +124,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, + SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, } @@ -163,6 +164,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -293,6 +295,7 @@ func (s *SettingService) UpdateSettings(ctx
context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) + updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -509,6 +512,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", SettingKeySoraClientEnabled: "false", + SettingKeyCustomMenuItems: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -567,6 +571,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], } // 解析整数类型 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 5a441ea1..9f0de600 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -40,6 +40,7 @@ type SystemSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items DefaultConcurrency int DefaultBalance float64 @@ -92,6 +93,7 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string SoraClientEnabled bool + CustomMenuItems string // JSON array of 
custom menu items LinuxDoOAuthEnabled bool Version string diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index c1b767ba..52855a04 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -4,6 +4,7 @@ */ import { apiClient } from '../client' +import type { CustomMenuItem } from '@/types' export interface DefaultSubscriptionSetting { group_id: number @@ -38,6 +39,7 @@ export interface SystemSettings { purchase_subscription_enabled: boolean purchase_subscription_url: string sora_client_enabled: boolean + custom_menu_items: CustomMenuItem[] // SMTP settings smtp_host: string smtp_port: number @@ -99,6 +101,7 @@ export interface UpdateSettingsRequest { purchase_subscription_enabled?: boolean purchase_subscription_url?: string sora_client_enabled?: boolean + custom_menu_items?: CustomMenuItem[] smtp_host?: string smtp_port?: number smtp_username?: string diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index b356e3e5..5b5db67e 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -47,7 +47,8 @@ " @click="handleMenuItemClick(item.path)" > - + + {{ item.label }} @@ -71,7 +72,8 @@ :data-tour="item.path === '/keys' ? 'sidebar-my-keys' : undefined" @click="handleMenuItemClick(item.path)" > - + + {{ item.label }} @@ -92,7 +94,8 @@ :data-tour="item.path === '/keys' ? 
'sidebar-my-keys' : undefined" @click="handleMenuItemClick(item.path)" > - + + {{ item.label }} @@ -150,6 +153,14 @@ import { useI18n } from 'vue-i18n' import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores' import VersionBadge from '@/components/common/VersionBadge.vue' +interface NavItem { + path: string + label: string + icon: unknown + iconSvg?: string + hideInSimpleMode?: boolean +} + const { t } = useI18n() const route = useRoute() @@ -496,8 +507,8 @@ const ChevronDoubleRightIcon = { } // User navigation items (for regular users) -const userNavItems = computed(() => { - const items = [ +const userNavItems = computed((): NavItem[] => { + const items: NavItem[] = [ { path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon }, { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, @@ -515,6 +526,13 @@ const userNavItems = computed(() => { } ] : []), + ...customMenuItemsForUser.value.map((item): NavItem => ({ + path: `/custom/${item.id}`, + label: item.label, + icon: null, + iconSvg: item.icon_svg, + hideInSimpleMode: true, + })), { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, { path: '/profile', label: t('nav.profile'), icon: UserIcon } ] @@ -522,8 +540,8 @@ const userNavItems = computed(() => { }) // Personal navigation items (for admin's "My Account" section, without Dashboard) -const personalNavItems = computed(() => { - const items = [ +const personalNavItems = computed((): NavItem[] => { + const items: NavItem[] = [ { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, @@ -540,15 +558,37 @@ const personalNavItems = computed(() => { } ] : []), + ...customMenuItemsForUser.value.map((item): 
NavItem => ({ + path: `/custom/${item.id}`, + label: item.label, + icon: null, + iconSvg: item.icon_svg, + hideInSimpleMode: true, + })), { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, { path: '/profile', label: t('nav.profile'), icon: UserIcon } ] return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items }) +// Custom menu items filtered by visibility +const customMenuItemsForUser = computed(() => { + const items = appStore.cachedPublicSettings?.custom_menu_items ?? [] + return items + .filter((item) => item.visibility === 'user') + .sort((a, b) => a.sort_order - b.sort_order) +}) + +const customMenuItemsForAdmin = computed(() => { + const items = appStore.cachedPublicSettings?.custom_menu_items ?? [] + return items + .filter((item) => item.visibility === 'admin') + .sort((a, b) => a.sort_order - b.sort_order) +}) + // Admin navigation items -const adminNavItems = computed(() => { - const baseItems = [ +const adminNavItems = computed((): NavItem[] => { + const baseItems: NavItem[] = [ { path: '/admin/dashboard', label: t('nav.dashboard'), icon: DashboardIcon }, ...(adminSettingsStore.opsMonitoringEnabled ? 
[{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }] @@ -567,6 +607,10 @@ const adminNavItems = computed(() => { // 简单模式下,在系统设置前插入 API密钥 if (authStore.isSimpleMode) { const filtered = baseItems.filter(item => !item.hideInSimpleMode) + // Add admin custom menu items + for (const cm of customMenuItemsForAdmin.value) { + filtered.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg }) + } filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }) filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) @@ -574,6 +618,10 @@ const adminNavItems = computed(() => { } baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) + // Add admin custom menu items before settings + for (const cm of customMenuItemsForAdmin.value) { + baseItems.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg }) + } baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) return baseItems }) @@ -654,4 +702,12 @@ onMounted(() => { .fade-leave-to { opacity: 0; } + +/* Custom SVG icon in sidebar: inherit color, constrain size */ +.sidebar-svg-icon :deep(svg) { + width: 1.25rem; + height: 1.25rem; + stroke: currentColor; + fill: none; +} diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 01b7919a..42cf9765 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3625,6 +3625,25 @@ export default { enabled: 'Enable Sora Client', enabledHint: 'When enabled, the Sora entry will be shown in the sidebar for users to access Sora features' }, + customMenu: { + title: 'Custom Menu Pages', + description: 'Add custom iframe pages to the sidebar navigation. 
Each page can be visible to regular users or administrators.', + itemLabel: 'Menu Item #{n}', + name: 'Menu Name', + namePlaceholder: 'e.g. Help Center', + url: 'Page URL', + urlPlaceholder: 'https://example.com/page', + iconSvg: 'SVG Icon', + iconSvgPlaceholder: '...', + iconPreview: 'Icon Preview', + visibility: 'Visible To', + visibilityUser: 'Regular Users', + visibilityAdmin: 'Administrators', + add: 'Add Menu Item', + remove: 'Remove', + moveUp: 'Move Up', + moveDown: 'Move Down', + }, smtp: { title: 'SMTP Settings', description: 'Configure email sending for verification codes', @@ -3913,6 +3932,16 @@ export default { 'The administrator enabled the entry but has not configured a recharge/subscription URL. Please contact admin.' }, + // Custom Page (iframe embed) + customPage: { + title: 'Custom Page', + openInNewTab: 'Open in new tab', + notFoundTitle: 'Page not found', + notFoundDesc: 'This custom page does not exist or has been removed.', + notConfiguredTitle: 'Page URL not configured', + notConfiguredDesc: 'The URL for this custom page has not been properly configured.', + }, + // Announcements Page announcements: { title: 'Announcements', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 3411d310..a0632fd9 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3795,6 +3795,25 @@ export default { enabled: '启用 Sora 客户端', enabledHint: '开启后,侧边栏将显示 Sora 入口,用户可访问 Sora 功能' }, + customMenu: { + title: '自定义菜单页面', + description: '添加自定义 iframe 页面到侧边栏导航。每个页面可以设置为普通用户或管理员可见。', + itemLabel: '菜单项 #{n}', + name: '菜单名称', + namePlaceholder: '如:帮助中心', + url: '页面 URL', + urlPlaceholder: 'https://example.com/page', + iconSvg: 'SVG 图标', + iconSvgPlaceholder: '...', + iconPreview: '图标预览', + visibility: '可见角色', + visibilityUser: '普通用户', + visibilityAdmin: '管理员', + add: '添加菜单项', + remove: '删除', + moveUp: '上移', + moveDown: '下移', + }, smtp: { title: 'SMTP 设置', description: '配置用于发送验证码的邮件服务', @@ -4081,6 +4100,16 @@ 
export default { notConfiguredDesc: '管理员已开启入口,但尚未配置充值/订阅链接,请联系管理员。' }, + // Custom Page (iframe embed) + customPage: { + title: '自定义页面', + openInNewTab: '新窗口打开', + notFoundTitle: '页面不存在', + notFoundDesc: '该自定义页面不存在或已被删除。', + notConfiguredTitle: '页面链接未配置', + notConfiguredDesc: '该自定义页面的 URL 未正确配置。', + }, + // Announcements Page announcements: { title: '公告', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index cb81d160..142828cb 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -203,6 +203,17 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'sora.description' } }, + { + path: '/custom/:id', + name: 'CustomPage', + component: () => import('@/views/user/CustomPageView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: false, + title: 'Custom Page', + titleKey: 'customPage.title', + } + }, // ==================== Admin Routes ==================== { diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 42a42272..37439a4c 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -327,6 +327,7 @@ export const useAppStore = defineStore('app', () => { hide_ccs_import_button: false, purchase_subscription_enabled: false, purchase_subscription_url: '', + custom_menu_items: [], linuxdo_oauth_enabled: false, sora_client_enabled: false, version: siteVersion.value diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index ccdde8ae..7f2f5f51 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -75,6 +75,15 @@ export interface SendVerifyCodeResponse { countdown: number } +export interface CustomMenuItem { + id: string + label: string + icon_svg: string + url: string + visibility: 'user' | 'admin' + sort_order: number +} + export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean @@ -93,6 +102,7 @@ export interface PublicSettings { hide_ccs_import_button: boolean purchase_subscription_enabled: boolean 
purchase_subscription_url: string + custom_menu_items: CustomMenuItem[] linuxdo_oauth_enabled: boolean sora_client_enabled: boolean version: string diff --git a/frontend/src/utils/embedded-url.ts b/frontend/src/utils/embedded-url.ts new file mode 100644 index 00000000..9319ee07 --- /dev/null +++ b/frontend/src/utils/embedded-url.ts @@ -0,0 +1,46 @@ +/** + * Shared URL builder for iframe-embedded pages. + * Used by PurchaseSubscriptionView and CustomPageView to build consistent URLs + * with user_id, token, theme, ui_mode, src_host, and src parameters. + */ + +const EMBEDDED_USER_ID_QUERY_KEY = 'user_id' +const EMBEDDED_AUTH_TOKEN_QUERY_KEY = 'token' +const EMBEDDED_THEME_QUERY_KEY = 'theme' +const EMBEDDED_UI_MODE_QUERY_KEY = 'ui_mode' +const EMBEDDED_UI_MODE_VALUE = 'embedded' +const EMBEDDED_SRC_HOST_QUERY_KEY = 'src_host' +const EMBEDDED_SRC_QUERY_KEY = 'src_url' + +export function buildEmbeddedUrl( + baseUrl: string, + userId?: number, + authToken?: string | null, + theme: 'light' | 'dark' = 'light', +): string { + if (!baseUrl) return baseUrl + try { + const url = new URL(baseUrl) + if (userId) { + url.searchParams.set(EMBEDDED_USER_ID_QUERY_KEY, String(userId)) + } + if (authToken) { + url.searchParams.set(EMBEDDED_AUTH_TOKEN_QUERY_KEY, authToken) + } + url.searchParams.set(EMBEDDED_THEME_QUERY_KEY, theme) + url.searchParams.set(EMBEDDED_UI_MODE_QUERY_KEY, EMBEDDED_UI_MODE_VALUE) + // Source tracking: let the embedded page know where it's being loaded from + if (typeof window !== 'undefined') { + url.searchParams.set(EMBEDDED_SRC_HOST_QUERY_KEY, window.location.origin) + url.searchParams.set(EMBEDDED_SRC_QUERY_KEY, window.location.href) + } + return url.toString() + } catch { + return baseUrl + } +} + +export function detectTheme(): 'light' | 'dark' { + if (typeof document === 'undefined') return 'light' + return document.documentElement.classList.contains('dark') ? 
'dark' : 'light' +} diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 39e1a6b5..02f7f449 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1160,6 +1160,135 @@ + +
+
+

+ {{ t('admin.settings.customMenu.title') }} +

+

+ {{ t('admin.settings.customMenu.description') }} +

+
+
+ +
+
+ + {{ t('admin.settings.customMenu.itemLabel', { n: index + 1 }) }} + +
+ + + + + + +
+
+ +
+ +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ +
+ + +
+ +
+
+
+
+
+ + + +
+
+
@@ -1332,6 +1461,7 @@ const form = reactive({ purchase_subscription_enabled: false, purchase_subscription_url: '', sora_client_enabled: false, + custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, smtp_host: '', smtp_port: 587, smtp_username: '', @@ -1396,6 +1526,39 @@ async function setAndCopyLinuxdoRedirectUrl() { await copyToClipboard(url, t('admin.settings.linuxdo.redirectUrlSetAndCopied')) } +// Custom menu item management +function addMenuItem() { + form.custom_menu_items.push({ + id: '', + label: '', + icon_svg: '', + url: '', + visibility: 'user', + sort_order: form.custom_menu_items.length, + }) +} + +function removeMenuItem(index: number) { + form.custom_menu_items.splice(index, 1) + // Re-index sort_order + form.custom_menu_items.forEach((item, i) => { + item.sort_order = i + }) +} + +function moveMenuItem(index: number, direction: -1 | 1) { + const targetIndex = index + direction + if (targetIndex < 0 || targetIndex >= form.custom_menu_items.length) return + const items = form.custom_menu_items + const temp = items[index] + items[index] = items[targetIndex] + items[targetIndex] = temp + // Re-index sort_order + items.forEach((item, i) => { + item.sort_order = i + }) +} + function handleLogoUpload(event: Event) { const input = event.target as HTMLInputElement const file = input.files?.[0] @@ -1534,6 +1697,7 @@ async function saveSettings() { purchase_subscription_enabled: form.purchase_subscription_enabled, purchase_subscription_url: form.purchase_subscription_url, sora_client_enabled: form.sora_client_enabled, + custom_menu_items: form.custom_menu_items, smtp_host: form.smtp_host, smtp_port: form.smtp_port, smtp_username: form.smtp_username, diff --git a/frontend/src/views/user/CustomPageView.vue b/frontend/src/views/user/CustomPageView.vue new file mode 100644 index 00000000..45e61e17 --- /dev/null +++ b/frontend/src/views/user/CustomPageView.vue @@ -0,0 +1,166 @@ 
+ + + + + diff --git a/frontend/src/views/user/PurchaseSubscriptionView.vue b/frontend/src/views/user/PurchaseSubscriptionView.vue index fdcd0d34..d6d356f5 100644 --- a/frontend/src/views/user/PurchaseSubscriptionView.vue +++ b/frontend/src/views/user/PurchaseSubscriptionView.vue @@ -74,17 +74,12 @@ import { useAppStore } from '@/stores' import { useAuthStore } from '@/stores/auth' import AppLayout from '@/components/layout/AppLayout.vue' import Icon from '@/components/icons/Icon.vue' +import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url' const { t } = useI18n() const appStore = useAppStore() const authStore = useAuthStore() -const PURCHASE_USER_ID_QUERY_KEY = 'user_id' -const PURCHASE_AUTH_TOKEN_QUERY_KEY = 'token' -const PURCHASE_THEME_QUERY_KEY = 'theme' -const PURCHASE_UI_MODE_QUERY_KEY = 'ui_mode' -const PURCHASE_UI_MODE_EMBEDDED = 'embedded' - const loading = ref(false) const purchaseTheme = ref<'light' | 'dark'>('light') let themeObserver: MutationObserver | null = null @@ -93,37 +88,9 @@ const purchaseEnabled = computed(() => { return appStore.cachedPublicSettings?.purchase_subscription_enabled ?? false }) -function detectTheme(): 'light' | 'dark' { - if (typeof document === 'undefined') return 'light' - return document.documentElement.classList.contains('dark') ? 
'dark' : 'light' -} - -function buildPurchaseUrl( - baseUrl: string, - userId?: number, - authToken?: string | null, - theme: 'light' | 'dark' = 'light', -): string { - if (!baseUrl) return baseUrl - try { - const url = new URL(baseUrl) - if (userId) { - url.searchParams.set(PURCHASE_USER_ID_QUERY_KEY, String(userId)) - } - if (authToken) { - url.searchParams.set(PURCHASE_AUTH_TOKEN_QUERY_KEY, authToken) - } - url.searchParams.set(PURCHASE_THEME_QUERY_KEY, theme) - url.searchParams.set(PURCHASE_UI_MODE_QUERY_KEY, PURCHASE_UI_MODE_EMBEDDED) - return url.toString() - } catch { - return baseUrl - } -} - const purchaseUrl = computed(() => { const baseUrl = (appStore.cachedPublicSettings?.purchase_subscription_url || '').trim() - return buildPurchaseUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value) + return buildEmbeddedUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value) }) const isValidUrl = computed(() => { diff --git a/tmp_api_admin_orders/[id]/cancel/route.ts b/tmp_api_admin_orders/[id]/cancel/route.ts new file mode 100644 index 00000000..0857b4e0 --- /dev/null +++ b/tmp_api_admin_orders/[id]/cancel/route.ts @@ -0,0 +1,25 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; +import { adminCancelOrder, OrderError } from '@/lib/order/service'; + +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ id: string }> }, +) { + if (!verifyAdminToken(request)) return unauthorizedResponse(); + + try { + const { id } = await params; + await adminCancelOrder(id); + return NextResponse.json({ success: true }); + } catch (error) { + if (error instanceof OrderError) { + return NextResponse.json( + { error: error.message, code: error.code }, + { status: error.statusCode }, + ); + } + console.error('Admin cancel order error:', error); + return NextResponse.json({ error: '取消订单失败' }, { status: 500 }); + } +} diff --git 
a/tmp_api_admin_orders/[id]/retry/route.ts b/tmp_api_admin_orders/[id]/retry/route.ts new file mode 100644 index 00000000..07a3c0d0 --- /dev/null +++ b/tmp_api_admin_orders/[id]/retry/route.ts @@ -0,0 +1,25 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; +import { retryRecharge, OrderError } from '@/lib/order/service'; + +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ id: string }> }, +) { + if (!verifyAdminToken(request)) return unauthorizedResponse(); + + try { + const { id } = await params; + await retryRecharge(id); + return NextResponse.json({ success: true }); + } catch (error) { + if (error instanceof OrderError) { + return NextResponse.json( + { error: error.message, code: error.code }, + { status: error.statusCode }, + ); + } + console.error('Retry recharge error:', error); + return NextResponse.json({ error: '重试充值失败' }, { status: 500 }); + } +} diff --git a/tmp_api_admin_orders/[id]/route.ts b/tmp_api_admin_orders/[id]/route.ts new file mode 100644 index 00000000..941ed839 --- /dev/null +++ b/tmp_api_admin_orders/[id]/route.ts @@ -0,0 +1,31 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { prisma } from '@/lib/db'; +import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; + +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ id: string }> }, +) { + if (!verifyAdminToken(request)) return unauthorizedResponse(); + + const { id } = await params; + + const order = await prisma.order.findUnique({ + where: { id }, + include: { + auditLogs: { + orderBy: { createdAt: 'desc' }, + }, + }, + }); + + if (!order) { + return NextResponse.json({ error: '订单不存在' }, { status: 404 }); + } + + return NextResponse.json({ + ...order, + amount: Number(order.amount), + refundAmount: order.refundAmount ? 
Number(order.refundAmount) : null, + }); +} diff --git a/tmp_api_admin_orders/route.ts b/tmp_api_admin_orders/route.ts new file mode 100644 index 00000000..110560bf --- /dev/null +++ b/tmp_api_admin_orders/route.ts @@ -0,0 +1,60 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { prisma } from '@/lib/db'; +import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; +import { Prisma } from '@prisma/client'; + +export async function GET(request: NextRequest) { + if (!verifyAdminToken(request)) return unauthorizedResponse(); + + const searchParams = request.nextUrl.searchParams; + const page = Math.max(1, Number(searchParams.get('page') || '1')); + const pageSize = Math.min(100, Math.max(1, Number(searchParams.get('page_size') || '20'))); + const status = searchParams.get('status'); + const userId = searchParams.get('user_id'); + const dateFrom = searchParams.get('date_from'); + const dateTo = searchParams.get('date_to'); + + const where: Prisma.OrderWhereInput = {}; + if (status) where.status = status as any; + if (userId) where.userId = Number(userId); + if (dateFrom || dateTo) { + where.createdAt = {}; + if (dateFrom) where.createdAt.gte = new Date(dateFrom); + if (dateTo) where.createdAt.lte = new Date(dateTo); + } + + const [orders, total] = await Promise.all([ + prisma.order.findMany({ + where, + orderBy: { createdAt: 'desc' }, + skip: (page - 1) * pageSize, + take: pageSize, + select: { + id: true, + userId: true, + userName: true, + userEmail: true, + amount: true, + status: true, + paymentType: true, + createdAt: true, + paidAt: true, + completedAt: true, + failedReason: true, + expiresAt: true, + }, + }), + prisma.order.count({ where }), + ]); + + return NextResponse.json({ + orders: orders.map(o => ({ + ...o, + amount: Number(o.amount), + })), + total, + page, + page_size: pageSize, + total_pages: Math.ceil(total / pageSize), + }); +} diff --git a/tmp_api_orders/[id]/cancel/route.ts b/tmp_api_orders/[id]/cancel/route.ts new 
file mode 100644 index 00000000..4e0b0dc6 --- /dev/null +++ b/tmp_api_orders/[id]/cancel/route.ts @@ -0,0 +1,37 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { z } from 'zod'; +import { cancelOrder, OrderError } from '@/lib/order/service'; + +const cancelSchema = z.object({ + user_id: z.number().int().positive(), +}); + +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ id: string }> }, +) { + try { + const { id } = await params; + const body = await request.json(); + const parsed = cancelSchema.safeParse(body); + + if (!parsed.success) { + return NextResponse.json( + { error: '参数错误', details: parsed.error.flatten().fieldErrors }, + { status: 400 }, + ); + } + + await cancelOrder(id, parsed.data.user_id); + return NextResponse.json({ success: true }); + } catch (error) { + if (error instanceof OrderError) { + return NextResponse.json( + { error: error.message, code: error.code }, + { status: error.statusCode }, + ); + } + console.error('Cancel order error:', error); + return NextResponse.json({ error: '取消订单失败' }, { status: 500 }); + } +} diff --git a/tmp_api_orders/[id]/route.ts b/tmp_api_orders/[id]/route.ts new file mode 100644 index 00000000..08448607 --- /dev/null +++ b/tmp_api_orders/[id]/route.ts @@ -0,0 +1,50 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { prisma } from '@/lib/db'; + +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ id: string }> }, +) { + const { id } = await params; + + const order = await prisma.order.findUnique({ + where: { id }, + select: { + id: true, + userId: true, + userName: true, + amount: true, + status: true, + paymentType: true, + payUrl: true, + qrCode: true, + qrCodeImg: true, + expiresAt: true, + paidAt: true, + completedAt: true, + failedReason: true, + createdAt: true, + }, + }); + + if (!order) { + return NextResponse.json({ error: '订单不存在' }, { status: 404 }); + } + + return NextResponse.json({ + order_id: 
order.id, + user_id: order.userId, + user_name: order.userName, + amount: Number(order.amount), + status: order.status, + payment_type: order.paymentType, + pay_url: order.payUrl, + qr_code: order.qrCode, + qr_code_img: order.qrCodeImg, + expires_at: order.expiresAt, + paid_at: order.paidAt, + completed_at: order.completedAt, + failed_reason: order.failedReason, + created_at: order.createdAt, + }); +} diff --git a/tmp_api_orders/my/route.ts b/tmp_api_orders/my/route.ts new file mode 100644 index 00000000..43ca2f0a --- /dev/null +++ b/tmp_api_orders/my/route.ts @@ -0,0 +1,46 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { prisma } from '@/lib/db'; +import { getCurrentUserByToken } from '@/lib/sub2api/client'; + +export async function GET(request: NextRequest) { + const token = request.nextUrl.searchParams.get('token')?.trim(); + if (!token) { + return NextResponse.json({ error: 'token is required' }, { status: 400 }); + } + + try { + const user = await getCurrentUserByToken(token); + const orders = await prisma.order.findMany({ + where: { userId: user.id }, + orderBy: { createdAt: 'desc' }, + take: 20, + select: { + id: true, + amount: true, + status: true, + paymentType: true, + createdAt: true, + }, + }); + + return NextResponse.json({ + user: { + id: user.id, + username: user.username, + email: user.email, + displayName: user.username || user.email || `用户 #${user.id}`, + balance: user.balance, + }, + orders: orders.map((item) => ({ + id: item.id, + amount: Number(item.amount), + status: item.status, + paymentType: item.paymentType, + createdAt: item.createdAt, + })), + }); + } catch (error) { + console.error('Get my orders error:', error); + return NextResponse.json({ error: 'unauthorized' }, { status: 401 }); + } +} diff --git a/tmp_api_orders/route.ts b/tmp_api_orders/route.ts new file mode 100644 index 00000000..0fd93aa4 --- /dev/null +++ b/tmp_api_orders/route.ts @@ -0,0 +1,68 @@ +import { NextRequest, NextResponse } from 'next/server'; 
+import { z } from 'zod'; +import { createOrder, OrderError } from '@/lib/order/service'; +import { getEnv } from '@/lib/config'; + +const createOrderSchema = z.object({ + user_id: z.number().int().positive(), + amount: z.number().positive(), + payment_type: z.enum(['alipay', 'wxpay']), +}); + +export async function POST(request: NextRequest) { + try { + const env = getEnv(); + const body = await request.json(); + const parsed = createOrderSchema.safeParse(body); + + if (!parsed.success) { + return NextResponse.json( + { error: '参数错误', details: parsed.error.flatten().fieldErrors }, + { status: 400 }, + ); + } + + const { user_id, amount, payment_type } = parsed.data; + + // Validate amount range + if (amount < env.MIN_RECHARGE_AMOUNT || amount > env.MAX_RECHARGE_AMOUNT) { + return NextResponse.json( + { error: `充值金额需在 ${env.MIN_RECHARGE_AMOUNT} - ${env.MAX_RECHARGE_AMOUNT} 之间` }, + { status: 400 }, + ); + } + + // Validate payment type is enabled + if (!env.ENABLED_PAYMENT_TYPES.includes(payment_type)) { + return NextResponse.json( + { error: `不支持的支付方式: ${payment_type}` }, + { status: 400 }, + ); + } + + const clientIp = request.headers.get('x-forwarded-for')?.split(',')[0]?.trim() + || request.headers.get('x-real-ip') + || '127.0.0.1'; + + const result = await createOrder({ + userId: user_id, + amount, + paymentType: payment_type, + clientIp, + }); + + return NextResponse.json(result); + } catch (error) { + if (error instanceof OrderError) { + return NextResponse.json( + { error: error.message, code: error.code }, + { status: error.statusCode }, + ); + } + console.error('Create order error:', error); + return NextResponse.json( + { error: '创建订单失败,请稍后重试' }, + { status: 500 }, + ); + } +} From a50d5d351b81c0caad53f5d15358e1cc12d0cd2d Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 2 Mar 2026 19:44:25 +0800 Subject: [PATCH 005/286] fix: replace curly quotes with straight quotes in domain_constants.go Co-Authored-By: Claude Opus 4.6 --- 
backend/internal/service/domain_constants.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cf61e3d1..df213002 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -113,9 +113,9 @@ const ( SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 - SettingKeyPurchaseSubscriptionEnabled = “purchase_subscription_enabled” // 是否展示”购买订阅”页面入口 - SettingKeyPurchaseSubscriptionURL = “purchase_subscription_url” // “购买订阅”页面 URL(作为 iframe src) - SettingKeyCustomMenuItems = “custom_menu_items” // 自定义菜单项(JSON 数组) + SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 + SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) + SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 From 1f955249965d66f084cbbe3f94b1a8b31ead0c10 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 06:20:10 +0800 Subject: [PATCH 006/286] feat: ImageUpload component, custom page title, sidebar menu order --- .../src/components/common/ImageUpload.vue | 141 ++++++++++++++++++ frontend/src/components/layout/AppHeader.vue | 7 + frontend/src/components/layout/AppSidebar.vue | 22 ++- frontend/src/i18n/locales/en.ts | 2 + frontend/src/i18n/locales/zh.ts | 2 + frontend/src/router/index.ts | 15 +- frontend/src/views/admin/SettingsView.vue | 130 +++------------- tmp_api_admin_orders/[id]/cancel/route.ts | 25 ---- tmp_api_admin_orders/[id]/retry/route.ts | 25 ---- tmp_api_admin_orders/[id]/route.ts | 31 ---- tmp_api_admin_orders/route.ts | 60 -------- tmp_api_orders/[id]/cancel/route.ts | 37 ----- 
tmp_api_orders/[id]/route.ts | 50 ------- tmp_api_orders/my/route.ts | 46 ------ tmp_api_orders/route.ts | 68 --------- 15 files changed, 193 insertions(+), 468 deletions(-) create mode 100644 frontend/src/components/common/ImageUpload.vue delete mode 100644 tmp_api_admin_orders/[id]/cancel/route.ts delete mode 100644 tmp_api_admin_orders/[id]/retry/route.ts delete mode 100644 tmp_api_admin_orders/[id]/route.ts delete mode 100644 tmp_api_admin_orders/route.ts delete mode 100644 tmp_api_orders/[id]/cancel/route.ts delete mode 100644 tmp_api_orders/[id]/route.ts delete mode 100644 tmp_api_orders/my/route.ts delete mode 100644 tmp_api_orders/route.ts diff --git a/frontend/src/components/common/ImageUpload.vue b/frontend/src/components/common/ImageUpload.vue new file mode 100644 index 00000000..b77ab64e --- /dev/null +++ b/frontend/src/components/common/ImageUpload.vue @@ -0,0 +1,141 @@ + + + diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue index a6b4030f..ffc7c5e2 100644 --- a/frontend/src/components/layout/AppHeader.vue +++ b/frontend/src/components/layout/AppHeader.vue @@ -254,6 +254,13 @@ const displayName = computed(() => { }) const pageTitle = computed(() => { + // For custom pages, use the menu item's label instead of generic "自定义页面" + if (route.name === 'CustomPage') { + const id = route.params.id as string + const items = appStore.cachedPublicSettings?.custom_menu_items ?? 
[] + const menuItem = items.find((item) => item.id === id) + if (menuItem?.label) return menuItem.label + } const titleKey = route.meta.titleKey as string if (titleKey) { return t(titleKey) diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 5b5db67e..40b8c8de 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -526,15 +526,14 @@ const userNavItems = computed((): NavItem[] => { } ] : []), + { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, + { path: '/profile', label: t('nav.profile'), icon: UserIcon }, ...customMenuItemsForUser.value.map((item): NavItem => ({ path: `/custom/${item.id}`, label: item.label, icon: null, iconSvg: item.icon_svg, - hideInSimpleMode: true, })), - { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, - { path: '/profile', label: t('nav.profile'), icon: UserIcon } ] return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items }) @@ -558,15 +557,14 @@ const personalNavItems = computed((): NavItem[] => { } ] : []), + { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, + { path: '/profile', label: t('nav.profile'), icon: UserIcon }, ...customMenuItemsForUser.value.map((item): NavItem => ({ path: `/custom/${item.id}`, label: item.label, icon: null, iconSvg: item.icon_svg, - hideInSimpleMode: true, })), - { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, - { path: '/profile', label: t('nav.profile'), icon: UserIcon } ] return authStore.isSimpleMode ? 
items.filter(item => !item.hideInSimpleMode) : items }) @@ -607,22 +605,22 @@ const adminNavItems = computed((): NavItem[] => { // 简单模式下,在系统设置前插入 API密钥 if (authStore.isSimpleMode) { const filtered = baseItems.filter(item => !item.hideInSimpleMode) - // Add admin custom menu items - for (const cm of customMenuItemsForAdmin.value) { - filtered.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg }) - } filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }) filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) + // Add admin custom menu items after settings + for (const cm of customMenuItemsForAdmin.value) { + filtered.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg }) + } return filtered } baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) - // Add admin custom menu items before settings + baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) + // Add admin custom menu items after settings for (const cm of customMenuItemsForAdmin.value) { baseItems.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg }) } - baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) return baseItems }) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 42cf9765..7357c3f1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3636,6 +3636,8 @@ export default { iconSvg: 'SVG Icon', iconSvgPlaceholder: '...', iconPreview: 'Icon Preview', + uploadSvg: 'Upload SVG', + removeSvg: 'Remove', visibility: 'Visible To', visibilityUser: 'Regular Users', visibilityAdmin: 'Administrators', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 
a0632fd9..9f2fb639 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3806,6 +3806,8 @@ export default { iconSvg: 'SVG 图标', iconSvgPlaceholder: '...', iconPreview: '图标预览', + uploadSvg: '上传 SVG', + removeSvg: '清除', visibility: '可见角色', visibilityUser: '普通用户', visibilityAdmin: '管理员', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 142828cb..08f492d4 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -428,7 +428,20 @@ router.beforeEach((to, _from, next) => { // Set page title const appStore = useAppStore() - document.title = resolveDocumentTitle(to.meta.title, appStore.siteName, to.meta.titleKey as string) + // For custom pages, use menu item label as document title + if (to.name === 'CustomPage') { + const id = to.params.id as string + const items = appStore.cachedPublicSettings?.custom_menu_items ?? [] + const menuItem = items.find((item) => item.id === id) + if (menuItem?.label) { + const siteName = appStore.siteName || 'Sub2API' + document.title = `${menuItem.label} - ${siteName}` + } else { + document.title = resolveDocumentTitle(to.meta.title, appStore.siteName, to.meta.titleKey as string) + } + } else { + document.title = resolveDocumentTitle(to.meta.title, appStore.siteName, to.meta.titleKey as string) + } // Check if route requires authentication const requiresAuth = to.meta.requiresAuth !== false // Default to true diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 02f7f449..3a42a5b7 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -832,64 +832,14 @@ -
- -
-
- Site Logo - - - -
-
- -
-
- - -
-

- {{ t('admin.settings.site.logoHint') }} -

-

{{ logoError }}

-
-
+
@@ -1257,22 +1207,14 @@ -
- - -
- -
-
+
@@ -1390,6 +1332,7 @@ import Select from '@/components/common/Select.vue' import GroupBadge from '@/components/common/GroupBadge.vue' import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import Toggle from '@/components/common/Toggle.vue' +import ImageUpload from '@/components/common/ImageUpload.vue' import { useClipboard } from '@/composables/useClipboard' import { useAppStore } from '@/stores' @@ -1402,7 +1345,6 @@ const saving = ref(false) const testingSmtp = ref(false) const sendingTestEmail = ref(false) const testEmailAddress = ref('') -const logoError = ref('') // Admin API Key 状态 const adminApiKeyLoading = ref(true) @@ -1559,44 +1501,6 @@ function moveMenuItem(index: number, direction: -1 | 1) { }) } -function handleLogoUpload(event: Event) { - const input = event.target as HTMLInputElement - const file = input.files?.[0] - logoError.value = '' - - if (!file) return - - // Check file size (300KB = 307200 bytes) - const maxSize = 300 * 1024 - if (file.size > maxSize) { - logoError.value = t('admin.settings.site.logoSizeError', { - size: (file.size / 1024).toFixed(1) - }) - input.value = '' - return - } - - // Check file type - if (!file.type.startsWith('image/')) { - logoError.value = t('admin.settings.site.logoTypeError') - input.value = '' - return - } - - // Convert to base64 - const reader = new FileReader() - reader.onload = (e) => { - form.site_logo = e.target?.result as string - } - reader.onerror = () => { - logoError.value = t('admin.settings.site.logoReadError') - } - reader.readAsDataURL(file) - - // Reset input - input.value = '' -} - async function loadSettings() { loading.value = true try { diff --git a/tmp_api_admin_orders/[id]/cancel/route.ts b/tmp_api_admin_orders/[id]/cancel/route.ts deleted file mode 100644 index 0857b4e0..00000000 --- a/tmp_api_admin_orders/[id]/cancel/route.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { verifyAdminToken, unauthorizedResponse } from 
'@/lib/admin-auth'; -import { adminCancelOrder, OrderError } from '@/lib/order/service'; - -export async function POST( - request: NextRequest, - { params }: { params: Promise<{ id: string }> }, -) { - if (!verifyAdminToken(request)) return unauthorizedResponse(); - - try { - const { id } = await params; - await adminCancelOrder(id); - return NextResponse.json({ success: true }); - } catch (error) { - if (error instanceof OrderError) { - return NextResponse.json( - { error: error.message, code: error.code }, - { status: error.statusCode }, - ); - } - console.error('Admin cancel order error:', error); - return NextResponse.json({ error: '取消订单失败' }, { status: 500 }); - } -} diff --git a/tmp_api_admin_orders/[id]/retry/route.ts b/tmp_api_admin_orders/[id]/retry/route.ts deleted file mode 100644 index 07a3c0d0..00000000 --- a/tmp_api_admin_orders/[id]/retry/route.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; -import { retryRecharge, OrderError } from '@/lib/order/service'; - -export async function POST( - request: NextRequest, - { params }: { params: Promise<{ id: string }> }, -) { - if (!verifyAdminToken(request)) return unauthorizedResponse(); - - try { - const { id } = await params; - await retryRecharge(id); - return NextResponse.json({ success: true }); - } catch (error) { - if (error instanceof OrderError) { - return NextResponse.json( - { error: error.message, code: error.code }, - { status: error.statusCode }, - ); - } - console.error('Retry recharge error:', error); - return NextResponse.json({ error: '重试充值失败' }, { status: 500 }); - } -} diff --git a/tmp_api_admin_orders/[id]/route.ts b/tmp_api_admin_orders/[id]/route.ts deleted file mode 100644 index 941ed839..00000000 --- a/tmp_api_admin_orders/[id]/route.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { prisma } from '@/lib/db'; -import 
{ verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; - -export async function GET( - request: NextRequest, - { params }: { params: Promise<{ id: string }> }, -) { - if (!verifyAdminToken(request)) return unauthorizedResponse(); - - const { id } = await params; - - const order = await prisma.order.findUnique({ - where: { id }, - include: { - auditLogs: { - orderBy: { createdAt: 'desc' }, - }, - }, - }); - - if (!order) { - return NextResponse.json({ error: '订单不存在' }, { status: 404 }); - } - - return NextResponse.json({ - ...order, - amount: Number(order.amount), - refundAmount: order.refundAmount ? Number(order.refundAmount) : null, - }); -} diff --git a/tmp_api_admin_orders/route.ts b/tmp_api_admin_orders/route.ts deleted file mode 100644 index 110560bf..00000000 --- a/tmp_api_admin_orders/route.ts +++ /dev/null @@ -1,60 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { prisma } from '@/lib/db'; -import { verifyAdminToken, unauthorizedResponse } from '@/lib/admin-auth'; -import { Prisma } from '@prisma/client'; - -export async function GET(request: NextRequest) { - if (!verifyAdminToken(request)) return unauthorizedResponse(); - - const searchParams = request.nextUrl.searchParams; - const page = Math.max(1, Number(searchParams.get('page') || '1')); - const pageSize = Math.min(100, Math.max(1, Number(searchParams.get('page_size') || '20'))); - const status = searchParams.get('status'); - const userId = searchParams.get('user_id'); - const dateFrom = searchParams.get('date_from'); - const dateTo = searchParams.get('date_to'); - - const where: Prisma.OrderWhereInput = {}; - if (status) where.status = status as any; - if (userId) where.userId = Number(userId); - if (dateFrom || dateTo) { - where.createdAt = {}; - if (dateFrom) where.createdAt.gte = new Date(dateFrom); - if (dateTo) where.createdAt.lte = new Date(dateTo); - } - - const [orders, total] = await Promise.all([ - prisma.order.findMany({ - where, - orderBy: { 
createdAt: 'desc' }, - skip: (page - 1) * pageSize, - take: pageSize, - select: { - id: true, - userId: true, - userName: true, - userEmail: true, - amount: true, - status: true, - paymentType: true, - createdAt: true, - paidAt: true, - completedAt: true, - failedReason: true, - expiresAt: true, - }, - }), - prisma.order.count({ where }), - ]); - - return NextResponse.json({ - orders: orders.map(o => ({ - ...o, - amount: Number(o.amount), - })), - total, - page, - page_size: pageSize, - total_pages: Math.ceil(total / pageSize), - }); -} diff --git a/tmp_api_orders/[id]/cancel/route.ts b/tmp_api_orders/[id]/cancel/route.ts deleted file mode 100644 index 4e0b0dc6..00000000 --- a/tmp_api_orders/[id]/cancel/route.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { z } from 'zod'; -import { cancelOrder, OrderError } from '@/lib/order/service'; - -const cancelSchema = z.object({ - user_id: z.number().int().positive(), -}); - -export async function POST( - request: NextRequest, - { params }: { params: Promise<{ id: string }> }, -) { - try { - const { id } = await params; - const body = await request.json(); - const parsed = cancelSchema.safeParse(body); - - if (!parsed.success) { - return NextResponse.json( - { error: '参数错误', details: parsed.error.flatten().fieldErrors }, - { status: 400 }, - ); - } - - await cancelOrder(id, parsed.data.user_id); - return NextResponse.json({ success: true }); - } catch (error) { - if (error instanceof OrderError) { - return NextResponse.json( - { error: error.message, code: error.code }, - { status: error.statusCode }, - ); - } - console.error('Cancel order error:', error); - return NextResponse.json({ error: '取消订单失败' }, { status: 500 }); - } -} diff --git a/tmp_api_orders/[id]/route.ts b/tmp_api_orders/[id]/route.ts deleted file mode 100644 index 08448607..00000000 --- a/tmp_api_orders/[id]/route.ts +++ /dev/null @@ -1,50 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; 
-import { prisma } from '@/lib/db'; - -export async function GET( - request: NextRequest, - { params }: { params: Promise<{ id: string }> }, -) { - const { id } = await params; - - const order = await prisma.order.findUnique({ - where: { id }, - select: { - id: true, - userId: true, - userName: true, - amount: true, - status: true, - paymentType: true, - payUrl: true, - qrCode: true, - qrCodeImg: true, - expiresAt: true, - paidAt: true, - completedAt: true, - failedReason: true, - createdAt: true, - }, - }); - - if (!order) { - return NextResponse.json({ error: '订单不存在' }, { status: 404 }); - } - - return NextResponse.json({ - order_id: order.id, - user_id: order.userId, - user_name: order.userName, - amount: Number(order.amount), - status: order.status, - payment_type: order.paymentType, - pay_url: order.payUrl, - qr_code: order.qrCode, - qr_code_img: order.qrCodeImg, - expires_at: order.expiresAt, - paid_at: order.paidAt, - completed_at: order.completedAt, - failed_reason: order.failedReason, - created_at: order.createdAt, - }); -} diff --git a/tmp_api_orders/my/route.ts b/tmp_api_orders/my/route.ts deleted file mode 100644 index 43ca2f0a..00000000 --- a/tmp_api_orders/my/route.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { prisma } from '@/lib/db'; -import { getCurrentUserByToken } from '@/lib/sub2api/client'; - -export async function GET(request: NextRequest) { - const token = request.nextUrl.searchParams.get('token')?.trim(); - if (!token) { - return NextResponse.json({ error: 'token is required' }, { status: 400 }); - } - - try { - const user = await getCurrentUserByToken(token); - const orders = await prisma.order.findMany({ - where: { userId: user.id }, - orderBy: { createdAt: 'desc' }, - take: 20, - select: { - id: true, - amount: true, - status: true, - paymentType: true, - createdAt: true, - }, - }); - - return NextResponse.json({ - user: { - id: user.id, - username: user.username, - email: 
user.email, - displayName: user.username || user.email || `用户 #${user.id}`, - balance: user.balance, - }, - orders: orders.map((item) => ({ - id: item.id, - amount: Number(item.amount), - status: item.status, - paymentType: item.paymentType, - createdAt: item.createdAt, - })), - }); - } catch (error) { - console.error('Get my orders error:', error); - return NextResponse.json({ error: 'unauthorized' }, { status: 401 }); - } -} diff --git a/tmp_api_orders/route.ts b/tmp_api_orders/route.ts deleted file mode 100644 index 0fd93aa4..00000000 --- a/tmp_api_orders/route.ts +++ /dev/null @@ -1,68 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { z } from 'zod'; -import { createOrder, OrderError } from '@/lib/order/service'; -import { getEnv } from '@/lib/config'; - -const createOrderSchema = z.object({ - user_id: z.number().int().positive(), - amount: z.number().positive(), - payment_type: z.enum(['alipay', 'wxpay']), -}); - -export async function POST(request: NextRequest) { - try { - const env = getEnv(); - const body = await request.json(); - const parsed = createOrderSchema.safeParse(body); - - if (!parsed.success) { - return NextResponse.json( - { error: '参数错误', details: parsed.error.flatten().fieldErrors }, - { status: 400 }, - ); - } - - const { user_id, amount, payment_type } = parsed.data; - - // Validate amount range - if (amount < env.MIN_RECHARGE_AMOUNT || amount > env.MAX_RECHARGE_AMOUNT) { - return NextResponse.json( - { error: `充值金额需在 ${env.MIN_RECHARGE_AMOUNT} - ${env.MAX_RECHARGE_AMOUNT} 之间` }, - { status: 400 }, - ); - } - - // Validate payment type is enabled - if (!env.ENABLED_PAYMENT_TYPES.includes(payment_type)) { - return NextResponse.json( - { error: `不支持的支付方式: ${payment_type}` }, - { status: 400 }, - ); - } - - const clientIp = request.headers.get('x-forwarded-for')?.split(',')[0]?.trim() - || request.headers.get('x-real-ip') - || '127.0.0.1'; - - const result = await createOrder({ - userId: user_id, - amount, - 
paymentType: payment_type, - clientIp, - }); - - return NextResponse.json(result); - } catch (error) { - if (error instanceof OrderError) { - return NextResponse.json( - { error: error.message, code: error.code }, - { status: error.statusCode }, - ); - } - console.error('Create order error:', error); - return NextResponse.json( - { error: '创建订单失败,请稍后重试' }, - { status: 500 }, - ); - } -} From e4f8799323da4f6dde79be6a55f2a4fd41c262c3 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 06:21:23 +0800 Subject: [PATCH 007/286] fix: include custom_menu_items in GetPublicSettingsForInjection --- backend/internal/service/setting_service.go | 27 ++++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 04f49273..63a873d1 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -210,12 +210,13 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ContactInfo string `json:"contact_info,omitempty"` DocURL string `json:"doc_url,omitempty"` HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` 
}{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, @@ -236,11 +237,25 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: sanitizeCustomMenuItemsJSON(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil } +// sanitizeCustomMenuItemsJSON validates a raw JSON string and returns it as json.RawMessage. +// Returns "[]" if the input is empty or invalid JSON. +func sanitizeCustomMenuItemsJSON(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return json.RawMessage("[]") + } + if json.Valid([]byte(raw)) { + return json.RawMessage(raw) + } + return json.RawMessage("[]") +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { From bf6fe5e9626c2f32ef5d12cc86003f5c9585c1fd Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 02:18:19 +0800 Subject: [PATCH 008/286] fix: custom menu security hardening and code quality improvements - Add admin menu permission check in CustomPageView (visibility + role) - Sanitize SVG content with DOMPurify before v-html rendering (XSS prevention) - Decouple router.go from dto package using anonymous struct - Consolidate duplicate parseCustomMenuItems into dto.ParseCustomMenuItems - Enhance menu item validation (count, length, ID uniqueness limits) - Add audit logging for purchase_subscription and custom_menu_items changes - Update API contract test to include custom_menu_items field Co-Authored-By: Claude Opus 4.6 --- .../internal/handler/admin/setting_handler.go | 85 +++++++++++++------ 
backend/internal/handler/dto/settings.go | 83 +++++++++++------- backend/internal/handler/setting_handler.go | 18 +--- backend/internal/server/api_contract_test.go | 3 +- .../src/components/common/ImageUpload.vue | 7 +- frontend/src/components/layout/AppSidebar.vue | 7 +- frontend/src/utils/sanitize.ts | 6 ++ frontend/src/views/user/CustomPageView.vue | 6 +- 8 files changed, 133 insertions(+), 82 deletions(-) create mode 100644 frontend/src/utils/sanitize.ts diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index eec403dc..26cd3128 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -30,20 +30,6 @@ func generateMenuItemID() string { return hex.EncodeToString(b) } -// parseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. -// Returns empty slice on empty/invalid input. -func parseCustomMenuItems(raw string) []dto.CustomMenuItem { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return []dto.CustomMenuItem{} - } - var items []dto.CustomMenuItem - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return []dto.CustomMenuItem{} - } - return items -} - // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -116,7 +102,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, - CustomMenuItems: parseCustomMenuItems(settings.CustomMenuItems), + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -166,17 +152,17 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` // OEM设置 - 
SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` // 默认配置 @@ -326,18 +312,38 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } // 自定义菜单项验证 + const ( + maxCustomMenuItems = 20 + maxMenuItemLabelLen = 50 + maxMenuItemURLLen = 2048 + maxMenuItemIconSVGLen = 10 * 1024 // 10KB + maxMenuItemIDLen = 32 + ) + customMenuJSON := previousSettings.CustomMenuItems if req.CustomMenuItems != nil { items := *req.CustomMenuItems + if len(items) > maxCustomMenuItems { + response.BadRequest(c, "Too many custom menu items (max 20)") + return + } for i, item := range items { if strings.TrimSpace(item.Label) == "" { response.BadRequest(c, "Custom menu item label is required") return } + if len(item.Label) > maxMenuItemLabelLen { + response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") + return + } if strings.TrimSpace(item.URL) == "" { response.BadRequest(c, "Custom menu item URL is 
required") return } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") return @@ -346,11 +352,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") return } + if len(item.IconSVG) > maxMenuItemIconSVGLen { + response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)") + return + } // Auto-generate ID if missing if strings.TrimSpace(item.ID) == "" { items[i].ID = generateMenuItemID() + } else if len(item.ID) > maxMenuItemIDLen { + response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") + return } } + // ID uniqueness check + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + if _, exists := seen[item.ID]; exists { + response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID) + return + } + seen[item.ID] = struct{}{} + } menuBytes, err := json.Marshal(items) if err != nil { response.BadRequest(c, "Failed to serialize custom menu items") @@ -510,7 +532,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, - CustomMenuItems: parseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, @@ -674,6 +696,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion 
{ changed = append(changed, "min_claude_code_version") } + if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { + changed = append(changed, "purchase_subscription_enabled") + } + if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { + changed = append(changed, "purchase_subscription_url") + } + if before.CustomMenuItems != after.CustomMenuItems { + changed = append(changed, "custom_menu_items") + } return changed } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index a7d5da22..f3c21be5 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -1,5 +1,10 @@ package dto +import ( + "encoding/json" + "strings" +) + // CustomMenuItem represents a user-configured custom menu entry. type CustomMenuItem struct { ID string `json:"id"` @@ -37,17 +42,17 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool 
`json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` DefaultConcurrency int `json:"default_concurrency"` @@ -80,28 +85,28 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string 
`json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - SoraClientEnabled bool `json:"sora_client_enabled"` - Version string `json:"version"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` + Version string `json:"version"` } // SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) @@ -150,3 +155,17 @@ type StreamTimeoutSettings struct { ThresholdCount int `json:"threshold_count"` ThresholdWindowMinutes int `json:"threshold_window_minutes"` } + +// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. +// Returns empty slice on empty/invalid input. +func ParseCustomMenuItems(raw string) []CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomMenuItem{} + } + var items []CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomMenuItem{} + } + return items +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1b8e33a8..40277a48 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -1,9 +1,6 @@ package handler import ( - "encoding/json" - "strings" - "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -53,22 +50,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - CustomMenuItems: parsePublicCustomMenuItems(settings.CustomMenuItems), + 
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, }) } - -// parsePublicCustomMenuItems parses a JSON string into a slice of CustomMenuItem. -func parsePublicCustomMenuItems(raw string) []dto.CustomMenuItem { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return []dto.CustomMenuItem{} - } - var items []dto.CustomMenuItem - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return []dto.CustomMenuItem{} - } - return items -} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a8845d9b..f15a2074 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -513,7 +513,8 @@ func TestAPIContracts(t *testing.T) { "hide_ccs_import_button": false, "purchase_subscription_enabled": false, "purchase_subscription_url": "", - "min_claude_code_version": "" + "min_claude_code_version": "", + "custom_menu_items": [] } }`, }, diff --git a/frontend/src/components/common/ImageUpload.vue b/frontend/src/components/common/ImageUpload.vue index b77ab64e..6ef84079 100644 --- a/frontend/src/components/common/ImageUpload.vue +++ b/frontend/src/components/common/ImageUpload.vue @@ -11,7 +11,7 @@ v-if="mode === 'svg' && modelValue" class="text-gray-600 dark:text-gray-300 [&>svg]:h-full [&>svg]:w-full" :class="innerSizeClass" - v-html="modelValue" + v-html="sanitizedValue" > import { ref, computed } from 'vue' import Icon from '@/components/icons/Icon.vue' +import { sanitizeSvg } from '@/utils/sanitize' const props = withDefaults(defineProps<{ modelValue: string @@ -97,6 +98,10 @@ const error = ref('') const acceptTypes = computed(() => props.mode === 'svg' ? '.svg' : 'image/*') +const sanitizedValue = computed(() => + props.mode === 'svg' ? sanitizeSvg(props.modelValue ?? 
'') : '' +) + const previewSizeClass = computed(() => props.size === 'sm' ? 'h-14 w-14' : 'h-20 w-20') const innerSizeClass = computed(() => props.size === 'sm' ? 'h-7 w-7' : 'h-12 w-12') const placeholderSizeClass = computed(() => props.size === 'sm' ? 'h-5 w-5' : 'h-8 w-8') diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 40b8c8de..dcfc60bb 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -47,7 +47,7 @@ " @click="handleMenuItemClick(item.path)" > - + {{ item.label }} @@ -72,7 +72,7 @@ :data-tour="item.path === '/keys' ? 'sidebar-my-keys' : undefined" @click="handleMenuItemClick(item.path)" > - + {{ item.label }} @@ -94,7 +94,7 @@ :data-tour="item.path === '/keys' ? 'sidebar-my-keys' : undefined" @click="handleMenuItemClick(item.path)" > - + {{ item.label }} @@ -152,6 +152,7 @@ import { useRoute } from 'vue-router' import { useI18n } from 'vue-i18n' import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores' import VersionBadge from '@/components/common/VersionBadge.vue' +import { sanitizeSvg } from '@/utils/sanitize' interface NavItem { path: string diff --git a/frontend/src/utils/sanitize.ts b/frontend/src/utils/sanitize.ts new file mode 100644 index 00000000..a61a52e1 --- /dev/null +++ b/frontend/src/utils/sanitize.ts @@ -0,0 +1,6 @@ +import DOMPurify from 'dompurify' + +export function sanitizeSvg(svg: string): string { + if (!svg) return '' + return DOMPurify.sanitize(svg, { USE_PROFILES: { svg: true, svgFilters: true } }) +} diff --git a/frontend/src/views/user/CustomPageView.vue b/frontend/src/views/user/CustomPageView.vue index 45e61e17..ed1c11d7 100644 --- a/frontend/src/views/user/CustomPageView.vue +++ b/frontend/src/views/user/CustomPageView.vue @@ -87,7 +87,11 @@ const menuItemId = computed(() => route.params.id as string) const menuItem = computed(() => { const items = 
appStore.cachedPublicSettings?.custom_menu_items ?? [] - return items.find((item) => item.id === menuItemId.value) ?? null + const found = items.find((item) => item.id === menuItemId.value) ?? null + if (found && found.visibility === 'admin' && !authStore.isAdmin) { + return null + } + return found }) const embeddedUrl = computed(() => { From 50a8116ae9223a15b49f13a66a9b5072050c7109 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 06:37:50 +0800 Subject: [PATCH 009/286] fix: update SecurityHeaders call sites to match new signature --- backend/cmd/server/main.go | 2 +- .../middleware/security_headers_test.go | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 63095209..46edcb69 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -100,7 +100,7 @@ func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) r.Use(middleware.CORS(config.CORSConfig{})) - r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil)) // Register setup routes setup.RegisterRoutes(r) diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index 5a779825..031385d0 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) { func TestSecurityHeaders(t *testing.T) { t.Run("sets_basic_security_headers", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("csp_disabled_no_csp_header", func(t 
*testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -156,7 +156,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -180,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -199,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: " \t\n ", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -217,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -235,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("calls_next_handler", func(t *testing.T) { cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nextCalled 
:= false router := gin.New() @@ -258,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nonces := make(map[string]bool) for i := 0; i < 10; i++ { @@ -376,7 +376,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) b.ResetTimer() for i := 0; i < b.N; i++ { From 7541e243bc5706d282b93dc0a179bead3659b2be Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 06:38:04 +0800 Subject: [PATCH 010/286] style: fix gofmt alignment in setting_service.go --- backend/internal/service/setting_service.go | 30 ++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 63a873d1..2311e150 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -195,21 +195,21 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string 
`json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` From e97c376681b32ea526e04686c8a5c3e8904298d8 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 07:05:01 +0800 Subject: [PATCH 011/286] fix: security hardening and architectural improvements for custom menu 1. (Critical) Filter admin-only menu items from public API responses - both GetPublicSettings handler and GetPublicSettingsForInjection now exclude visibility=admin items, preventing unauthorized access to admin menu URLs. 2. (Medium) Validate JSON array structure in sanitizeCustomMenuItemsJSON - use json.Unmarshal into []json.RawMessage instead of json.Valid to reject non-array JSON values that would cause frontend runtime errors. 3. (Medium) Decouple router from business JSON parsing - move origin extraction logic from router.go to SettingService.GetFrameSrcOrigins, eliminating direct JSON parsing of custom_menu_items in the routing layer. 4. 
(Low) Restrict custom menu item ID charset to [a-zA-Z0-9_-] via regex validation, preventing route-breaking characters like / ? # or spaces. 5. (Low) Handle crypto/rand error in generateMenuItemID - return error instead of silently ignoring, preventing potential duplicate IDs. Co-Authored-By: Claude Opus 4.6 --- .../internal/handler/admin/setting_handler.go | 21 ++- backend/internal/handler/dto/settings.go | 12 ++ backend/internal/handler/setting_handler.go | 2 +- backend/internal/server/router.go | 58 +-------- backend/internal/service/setting_service.go | 123 +++++++++++++++++- 5 files changed, 150 insertions(+), 66 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 26cd3128..e32c142f 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -23,11 +23,16 @@ import ( // semverPattern 预编译 semver 格式校验正则 var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) +// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only. +var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + // generateMenuItemID generates a short random hex ID for a custom menu item. 
-func generateMenuItemID() string { +func generateMenuItemID() (string, error) { b := make([]byte, 8) - _, _ = rand.Read(b) - return hex.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate menu item ID: %w", err) + } + return hex.EncodeToString(b), nil } // SettingHandler 系统设置处理器 @@ -358,10 +363,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } // Auto-generate ID if missing if strings.TrimSpace(item.ID) == "" { - items[i].ID = generateMenuItemID() + id, err := generateMenuItemID() + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID") + return + } + items[i].ID = id } else if len(item.ID) > maxMenuItemIDLen { response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") return + } else if !menuItemIDPattern.MatchString(item.ID) { + response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)") + return } } // ID uniqueness check diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index f3c21be5..beb03e67 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -169,3 +169,15 @@ func ParseCustomMenuItems(raw string) []CustomMenuItem { } return items } + +// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries. 
+func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { + items := ParseCustomMenuItems(raw) + filtered := make([]CustomMenuItem, 0, len(items)) + for _, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, item) + } + } + return filtered +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 40277a48..a48eaf31 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -50,7 +50,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index c44a4608..430edcf8 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -2,10 +2,7 @@ package server import ( "context" - "encoding/json" "log" - "net/url" - "strings" "sync/atomic" "time" @@ -20,24 +17,7 @@ import ( "github.com/redis/go-redis/v9" ) -// extractOrigin returns the scheme+host origin from rawURL, or "" on error. -// Only http and https schemes are accepted; other values (e.g. "//host/path") return "". 
-func extractOrigin(rawURL string) string { - rawURL = strings.TrimSpace(rawURL) - if rawURL == "" { - return "" - } - u, err := url.Parse(rawURL) - if err != nil || u.Host == "" { - return "" - } - if u.Scheme != "http" && u.Scheme != "https" { - return "" - } - return u.Scheme + "://" + u.Host -} - -const paymentOriginFetchTimeout = 5 * time.Second +const frameSrcRefreshTimeout = 5 * time.Second // SetupRouter 配置路由器中间件和路由 func SetupRouter( @@ -54,50 +34,18 @@ func SetupRouter( redisClient *redis.Client, ) *gin.Engine { // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src - // 包含 purchase_subscription_url 和所有 custom_menu_items 的 origin(去重) var cachedFrameOrigins atomic.Pointer[[]string] emptyOrigins := []string{} cachedFrameOrigins.Store(&emptyOrigins) refreshFrameOrigins := func() { - ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout) + ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout) defer cancel() - settings, err := settingService.GetPublicSettings(ctx) + origins, err := settingService.GetFrameSrcOrigins(ctx) if err != nil { // 获取失败时保留已有缓存,避免 frame-src 被意外清空 return } - - seen := make(map[string]struct{}) - var origins []string - - // purchase subscription URL - if settings.PurchaseSubscriptionEnabled { - if origin := extractOrigin(settings.PurchaseSubscriptionURL); origin != "" { - if _, ok := seen[origin]; !ok { - seen[origin] = struct{}{} - origins = append(origins, origin) - } - } - } - - // custom menu items - if raw := strings.TrimSpace(settings.CustomMenuItems); raw != "" && raw != "[]" { - var items []struct { - URL string `json:"url"` - } - if err := json.Unmarshal([]byte(raw), &items); err == nil { - for _, item := range items { - if origin := extractOrigin(item.URL); origin != "" { - if _, ok := seen[origin]; !ok { - seen[origin] = struct{}{} - origins = append(origins, origin) - } - } - } - } - } - cachedFrameOrigins.Store(&origins) } refreshFrameOrigins() // 启动时初始化 diff --git 
a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 2311e150..a2bb06a4 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "net/url" "strconv" "strings" "sync/atomic" @@ -237,23 +238,133 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, - CustomMenuItems: sanitizeCustomMenuItemsJSON(settings.CustomMenuItems), + CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil } -// sanitizeCustomMenuItemsJSON validates a raw JSON string and returns it as json.RawMessage. -// Returns "[]" if the input is empty or invalid JSON. +// sanitizeCustomMenuItemsJSON validates a raw JSON string is a valid JSON array +// and returns it as json.RawMessage. Returns "[]" if the input is empty, not a +// valid JSON array, or is a non-array JSON value (e.g. object, string). func sanitizeCustomMenuItemsJSON(raw string) json.RawMessage { raw = strings.TrimSpace(raw) if raw == "" || raw == "[]" { return json.RawMessage("[]") } - if json.Valid([]byte(raw)) { - return json.RawMessage(raw) + // Verify it's actually a JSON array, not an object or other type + var arr []json.RawMessage + if err := json.Unmarshal([]byte(raw), &arr); err != nil { + return json.RawMessage("[]") } - return json.RawMessage("[]") + return json.RawMessage(raw) +} + +// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON +// array string, returning only items with visibility != "admin". 
+func filterUserVisibleMenuItems(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return json.RawMessage("[]") + } + var items []struct { + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return json.RawMessage("[]") + } + + // Parse full items to preserve all fields + var fullItems []json.RawMessage + if err := json.Unmarshal([]byte(raw), &fullItems); err != nil { + return json.RawMessage("[]") + } + + var filtered []json.RawMessage + for i, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, fullItems[i]) + } + } + if len(filtered) == 0 { + return json.RawMessage("[]") + } + result, err := json.Marshal(filtered) + if err != nil { + return json.RawMessage("[]") + } + return result +} + +// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url +// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. +func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { + settings, err := s.GetPublicSettings(ctx) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}) + var origins []string + + addOrigin := func(rawURL string) { + if origin := extractOriginFromURL(rawURL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + + // purchase subscription URL + if settings.PurchaseSubscriptionEnabled { + addOrigin(settings.PurchaseSubscriptionURL) + } + + // all custom menu items (including admin-only, since CSP must allow all iframes) + for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) { + addOrigin(item) + } + + return origins, nil +} + +// extractOriginFromURL returns the scheme+host origin from rawURL. +// Only http and https schemes are accepted. 
+func extractOriginFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + return u.Scheme + "://" + u.Host +} + +// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items. +func parseCustomMenuItemURLs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return nil + } + var items []struct { + URL string `json:"url"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + urls := make([]string, 0, len(items)) + for _, item := range items { + if item.URL != "" { + urls = append(urls, item.URL) + } + } + return urls } // UpdateSettings 更新系统设置 From 451a85111882ead466382ce8a3b25c8678e8422f Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 07:13:08 +0800 Subject: [PATCH 012/286] fix: remove unused sanitizeCustomMenuItemsJSON function Replaced by filterUserVisibleMenuItems which includes both array validation and admin-item filtering. Co-Authored-By: Claude Opus 4.6 --- backend/internal/service/setting_service.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a2bb06a4..3809c9d0 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -244,22 +244,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any }, nil } -// sanitizeCustomMenuItemsJSON validates a raw JSON string is a valid JSON array -// and returns it as json.RawMessage. Returns "[]" if the input is empty, not a -// valid JSON array, or is a non-array JSON value (e.g. object, string). 
-func sanitizeCustomMenuItemsJSON(raw string) json.RawMessage { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return json.RawMessage("[]") - } - // Verify it's actually a JSON array, not an object or other type - var arr []json.RawMessage - if err := json.Unmarshal([]byte(raw), &arr); err != nil { - return json.RawMessage("[]") - } - return json.RawMessage(raw) -} - // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { From 5ba71cd2f1a6ab7d08f85b03ba68243b76e648e1 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 3 Mar 2026 10:45:35 +0800 Subject: [PATCH 013/286] fix(frontend): admin custom menu items not showing in sidebar The public settings API filters out menu items with visibility='admin', so customMenuItemsForAdmin was always empty when reading from cachedPublicSettings. Fix by loading custom menu items from the admin settings API (via adminSettingsStore) which returns all items unfiltered. 
Changes: - adminSettings store: store custom_menu_items from admin settings API - AppSidebar: read admin menu items from adminSettingsStore instead of cachedPublicSettings - CustomPageView: merge public and admin menu items so admin users can access admin-only custom pages --- frontend/src/components/layout/AppSidebar.vue | 2 +- frontend/src/stores/adminSettings.ts | 7 ++++ frontend/src/views/user/CustomPageView.vue | 33 ++++++++++++++----- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index dcfc60bb..f6e31f36 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -579,7 +579,7 @@ const customMenuItemsForUser = computed(() => { }) const customMenuItemsForAdmin = computed(() => { - const items = appStore.cachedPublicSettings?.custom_menu_items ?? [] + const items = adminSettingsStore.customMenuItems ?? [] return items .filter((item) => item.visibility === 'admin') .sort((a, b) => a.sort_order - b.sort_order) diff --git a/frontend/src/stores/adminSettings.ts b/frontend/src/stores/adminSettings.ts index 460cc92b..3696b560 100644 --- a/frontend/src/stores/adminSettings.ts +++ b/frontend/src/stores/adminSettings.ts @@ -1,6 +1,7 @@ import { defineStore } from 'pinia' import { ref } from 'vue' import { adminAPI } from '@/api' +import type { CustomMenuItem } from '@/types' export const useAdminSettingsStore = defineStore('adminSettings', () => { const loaded = ref(false) @@ -43,6 +44,9 @@ export const useAdminSettingsStore = defineStore('adminSettings', () => { } } + // Custom menu items (all items including admin-only, loaded from admin settings API) + const customMenuItems = ref([]) + // Default open, but honor cached value to reduce UI flicker on first paint. 
const opsMonitoringEnabled = ref(readCachedBool('ops_monitoring_enabled_cached', true)) const opsRealtimeMonitoringEnabled = ref(readCachedBool('ops_realtime_monitoring_enabled_cached', true)) @@ -64,6 +68,8 @@ export const useAdminSettingsStore = defineStore('adminSettings', () => { opsQueryModeDefault.value = settings.ops_query_mode_default || 'auto' writeCachedString('ops_query_mode_default_cached', opsQueryModeDefault.value) + customMenuItems.value = settings.custom_menu_items ?? [] + loaded.value = true } catch (err) { // Keep cached/default value: do not "flip" the UI based on a transient fetch failure. @@ -122,6 +128,7 @@ export const useAdminSettingsStore = defineStore('adminSettings', () => { opsMonitoringEnabled, opsRealtimeMonitoringEnabled, opsQueryModeDefault, + customMenuItems, fetch, setOpsMonitoringEnabledLocal, setOpsRealtimeMonitoringEnabledLocal, diff --git a/frontend/src/views/user/CustomPageView.vue b/frontend/src/views/user/CustomPageView.vue index ed1c11d7..daea29f1 100644 --- a/frontend/src/views/user/CustomPageView.vue +++ b/frontend/src/views/user/CustomPageView.vue @@ -70,6 +70,7 @@ import { useRoute } from 'vue-router' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores' import { useAuthStore } from '@/stores/auth' +import { useAdminSettingsStore } from '@/stores/adminSettings' import AppLayout from '@/components/layout/AppLayout.vue' import Icon from '@/components/icons/Icon.vue' import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url' @@ -78,6 +79,7 @@ const { t } = useI18n() const route = useRoute() const appStore = useAppStore() const authStore = useAuthStore() +const adminSettingsStore = useAdminSettingsStore() const loading = ref(false) const pageTheme = ref<'light' | 'dark'>('light') @@ -86,8 +88,15 @@ let themeObserver: MutationObserver | null = null const menuItemId = computed(() => route.params.id as string) const menuItem = computed(() => { - const items = 
appStore.cachedPublicSettings?.custom_menu_items ?? [] - const found = items.find((item) => item.id === menuItemId.value) ?? null + const publicItems = appStore.cachedPublicSettings?.custom_menu_items ?? [] + const adminItems = authStore.isAdmin ? (adminSettingsStore.customMenuItems ?? []) : [] + const allItems = [...publicItems] + for (const item of adminItems) { + if (!allItems.some((existing) => existing.id === item.id)) { + allItems.push(item) + } + } + const found = allItems.find((item) => item.id === menuItemId.value) ?? null if (found && found.visibility === 'admin' && !authStore.isAdmin) { return null } @@ -122,12 +131,20 @@ onMounted(async () => { }) } - if (appStore.publicSettingsLoaded) return - loading.value = true - try { - await appStore.fetchPublicSettings() - } finally { - loading.value = false + const promises: Promise[] = [] + if (!appStore.publicSettingsLoaded) { + promises.push(appStore.fetchPublicSettings()) + } + if (authStore.isAdmin) { + promises.push(adminSettingsStore.fetch()) + } + if (promises.length > 0) { + loading.value = true + try { + await Promise.all(promises) + } finally { + loading.value = false + } } }) From 99f1e3ff3591e3c7977dc27fce99c7a018f94c60 Mon Sep 17 00:00:00 2001 From: ius Date: Tue, 3 Mar 2026 11:01:22 +0800 Subject: [PATCH 014/286] fix(migrations): avoid startup outage from 061 full-table backfill --- .../internal/repository/migrations_runner.go | 6 +++ .../migrations_runner_checksum_test.go | 9 ++++ .../061_add_usage_log_request_type.sql | 50 ++++++++++++++++--- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index a60ba294..017d578c 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -66,6 +66,12 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil 
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, }, }, + "061_add_usage_log_request_type.sql": { + fileChecksum: "97bdd9a32d921986f74a0231ab90735567a9234fb7062f4d9d1baf108ba59769", + acceptedDBChecksum: map[string]struct{}{ + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {}, + }, + }, } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 54f5b0ec..67cad963 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -25,6 +25,15 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.False(t, ok) }) + t.Run("061历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "061_add_usage_log_request_type.sql", + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", + "97bdd9a32d921986f74a0231ab90735567a9234fb7062f4d9d1baf108ba59769", + ) + require.True(t, ok) + }) + t.Run("非白名单迁移不兼容", func(t *testing.T) { ok := isMigrationChecksumCompatible( "001_init.sql", diff --git a/backend/migrations/061_add_usage_log_request_type.sql b/backend/migrations/061_add_usage_log_request_type.sql index 68a33d51..d2a9f446 100644 --- a/backend/migrations/061_add_usage_log_request_type.sql +++ b/backend/migrations/061_add_usage_log_request_type.sql @@ -19,11 +19,47 @@ $$; CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at ON usage_logs (request_type, created_at); --- Backfill from legacy fields. openai_ws_mode has higher priority than stream. -UPDATE usage_logs -SET request_type = CASE - WHEN openai_ws_mode = TRUE THEN 3 - WHEN stream = TRUE THEN 2 - ELSE 1 +-- Backfill from legacy fields in bounded batches. +-- Why bounded: +-- 1) Full-table UPDATE on large usage_logs can block startup for a long time. 
+-- 2) request_type=0 rows remain query-compatible via legacy fallback logic +-- (stream/openai_ws_mode) in repository filters. +-- 3) Subsequent writes will use explicit request_type and gradually dilute +-- historical unknown rows. +-- +-- openai_ws_mode has higher priority than stream. +DO $$ +DECLARE + v_rows INTEGER := 0; + v_total_rows INTEGER := 0; + v_batch_size INTEGER := 5000; + v_started_at TIMESTAMPTZ := clock_timestamp(); + v_max_duration INTERVAL := INTERVAL '8 seconds'; +BEGIN + LOOP + WITH batch AS ( + SELECT id + FROM usage_logs + WHERE request_type = 0 + ORDER BY id + LIMIT v_batch_size + ) + UPDATE usage_logs ul + SET request_type = CASE + WHEN ul.openai_ws_mode = TRUE THEN 3 + WHEN ul.stream = TRUE THEN 2 + ELSE 1 + END + FROM batch + WHERE ul.id = batch.id; + + GET DIAGNOSTICS v_rows = ROW_COUNT; + EXIT WHEN v_rows = 0; + + v_total_rows := v_total_rows + v_rows; + EXIT WHEN clock_timestamp() - v_started_at >= v_max_duration; + END LOOP; + + RAISE NOTICE 'usage_logs.request_type startup backfill rows=%', v_total_rows; END -WHERE request_type = 0; +$$; From 7be8f4dc6e0d63c299efcf100cf6e966ac9c1d56 Mon Sep 17 00:00:00 2001 From: xvhuan <50939413+xvhuan@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:49:24 +0800 Subject: [PATCH 015/286] perf(admin-dashboard): accelerate trend load with pre-aggregation and async user trend --- backend/internal/repository/usage_log_repo.go | 79 +++++++++++++++++++ frontend/src/views/admin/DashboardView.vue | 36 +++++++-- 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index d30cc7dd..ff40e97d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1655,6 +1655,13 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe // GetUsageTrendWithFilters returns usage trend data with optional filters func (r *usageLogRepository) 
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { + if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) { + aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity) + if aggregatedErr == nil && len(aggregated) > 0 { + return aggregated, nil + } + } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` @@ -1719,6 +1726,78 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start return results, nil } +func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool { + if granularity != "day" && granularity != "hour" { + return false + } + return userID == 0 && + apiKeyID == 0 && + accountID == 0 && + groupID == 0 && + model == "" && + requestType == nil && + stream == nil && + billingType == nil +} + +func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { + dateFormat := safeDateFormat(granularity) + query := "" + args := []any{startTime, endTime} + + switch granularity { + case "hour": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_start, '%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + (cache_creation_tokens + cache_read_tokens) as cache_tokens, + (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_hourly + WHERE bucket_start >= $1 AND bucket_start < $2 + ORDER BY bucket_start ASC + `, dateFormat) + case "day": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_date::timestamp, 
'%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + (cache_creation_tokens + cache_read_tokens) as cache_tokens, + (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_daily + WHERE bucket_date >= $1::date AND bucket_date < $2::date + ORDER BY bucket_date ASC + `, dateFormat) + default: + return nil, nil + } + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" diff --git a/frontend/src/views/admin/DashboardView.vue b/frontend/src/views/admin/DashboardView.vue index d4f1fbb0..ac5c44af 100644 --- a/frontend/src/views/admin/DashboardView.vue +++ b/frontend/src/views/admin/DashboardView.vue @@ -246,7 +246,10 @@ {{ t('admin.dashboard.recentUsage') }} (Top 12)
- +
+ +
+
(null) const loading = ref(false) const chartsLoading = ref(false) +const userTrendLoading = ref(false) // Chart data const trendData = ref([]) const modelStats = ref([]) const userTrend = ref([]) +let chartLoadSeq = 0 // Helper function to format date in local timezone const formatLocalDate = (date: Date): string => { @@ -531,7 +536,9 @@ const loadDashboardStats = async () => { } const loadChartData = async () => { + const currentSeq = ++chartLoadSeq chartsLoading.value = true + userTrendLoading.value = true try { const params = { start_date: startDate.value, @@ -539,20 +546,39 @@ const loadChartData = async () => { granularity: granularity.value } - const [trendResponse, modelResponse, userResponse] = await Promise.all([ + const [trendResponse, modelResponse] = await Promise.all([ adminAPI.dashboard.getUsageTrend(params), - adminAPI.dashboard.getModelStats({ start_date: startDate.value, end_date: endDate.value }), - adminAPI.dashboard.getUserUsageTrend({ ...params, limit: 12 }) + adminAPI.dashboard.getModelStats({ start_date: startDate.value, end_date: endDate.value }) ]) + if (currentSeq !== chartLoadSeq) return trendData.value = trendResponse.trend || [] modelStats.value = modelResponse.models || [] - userTrend.value = userResponse.trend || [] } catch (error) { + if (currentSeq !== chartLoadSeq) return console.error('Error loading chart data:', error) } finally { + if (currentSeq !== chartLoadSeq) return chartsLoading.value = false } + + try { + const params = { + start_date: startDate.value, + end_date: endDate.value, + granularity: granularity.value, + limit: 12 + } + const userResponse = await adminAPI.dashboard.getUserUsageTrend(params) + if (currentSeq !== chartLoadSeq) return + userTrend.value = userResponse.trend || [] + } catch (error) { + if (currentSeq !== chartLoadSeq) return + console.error('Error loading user trend:', error) + } finally { + if (currentSeq !== chartLoadSeq) return + userTrendLoading.value = false + } } onMounted(() => { From 
530a16291cf50942cc8254721d72dbb90ca4d422 Mon Sep 17 00:00:00 2001 From: QTom Date: Tue, 3 Mar 2026 13:10:26 +0800 Subject: [PATCH 016/286] =?UTF-8?q?fix(gateway):=20=E5=88=86=E7=BB=84?= =?UTF-8?q?=E9=9A=94=E7=A6=BB=20=E2=80=94=20=E7=A6=81=E6=AD=A2=E6=9C=AA?= =?UTF-8?q?=E5=88=86=E7=BB=84=E8=B4=A6=E5=8F=B7=E8=A2=AB=E8=B7=A8=E7=BB=84?= =?UTF-8?q?=E8=B0=83=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当 API Key 无分组时,调度仅从未分组账号池中选取。 修复 isAccountInGroup 在 groupID==nil 时的逻辑, 同时补全 scheduler_snapshot_service 和 gemini_compat_service 中的 SimpleMode 保护,确保分组隔离在所有调度路径生效。 新增 ListSchedulableUngroupedByPlatform/s 方法, 使用 Ent 的 Not(HasAccountGroups()) 谓词实现未分组账号隔离。 新增 17 个单元和端到端隔离测试,覆盖所有分支和边界条件。 --- .../handler/sora_client_handler_test.go | 6 + .../handler/sora_gateway_handler_test.go | 6 + backend/internal/repository/account_repo.go | 45 +++ backend/internal/server/api_contract_test.go | 8 + backend/internal/service/account_service.go | 2 + .../service/account_service_delete_test.go | 8 + .../service/gateway_group_isolation_test.go | 363 ++++++++++++++++++ .../service/gateway_multiplatform_test.go | 6 + backend/internal/service/gateway_service.go | 15 +- .../service/gemini_messages_compat_service.go | 5 +- .../service/gemini_multiplatform_test.go | 6 + .../service/openai_gateway_service.go | 2 +- .../service/openai_gateway_service_test.go | 4 + .../service/scheduler_snapshot_service.go | 9 +- 14 files changed, 475 insertions(+), 10 deletions(-) create mode 100644 backend/internal/service/gateway_group_isolation_test.go diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 5df7fa0a..c2284ce2 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2089,6 +2089,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, func (r *stubAccountRepoForHandler) 
ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { return r.accounts, nil } +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { + return r.accounts, nil +} func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { return nil } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 355cdb7a..59ac34b1 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { return r.ListSchedulableByPlatforms(ctx, platforms) } +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 4aa74928..0669cbbd 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -829,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat return r.accountsToService(ctx, accounts) } 
+func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformEQ(platform), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformIn(platforms...), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). 
+ All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f15a2074..446ee20d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1027,6 +1027,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte return nil, errors.New("not implemented") } +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a3707184..18a70c5c 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -54,6 +54,8 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) + ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) + ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, 
id int64, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a466b68a..768cf7b7 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte panic("unexpected ListSchedulableByGroupIDAndPlatforms call") } +func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatform call") +} + +func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatforms call") +} + func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { panic("unexpected SetRateLimited call") } diff --git a/backend/internal/service/gateway_group_isolation_test.go b/backend/internal/service/gateway_group_isolation_test.go new file mode 100644 index 00000000..00508f0e --- /dev/null +++ b/backend/internal/service/gateway_group_isolation_test.go @@ -0,0 +1,363 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Part 1: isAccountInGroup 单元测试 +// ============================================================================ + +func TestIsAccountInGroup(t *testing.T) { + svc := &GatewayService{} + groupID100 := int64(100) + groupID200 := int64(200) + + tests := []struct { + name string + account *Account + groupID *int64 + expected bool + }{ + // groupID == nil(无分组 API Key) + { + 
"nil_groupID_ungrouped_account_nil_groups", + &Account{ID: 1, AccountGroups: nil}, + nil, true, + }, + { + "nil_groupID_ungrouped_account_empty_slice", + &Account{ID: 2, AccountGroups: []AccountGroup{}}, + nil, true, + }, + { + "nil_groupID_grouped_account_single", + &Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}}, + nil, false, + }, + { + "nil_groupID_grouped_account_multiple", + &Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + nil, false, + }, + // groupID != nil(有分组 API Key) + { + "with_groupID_account_in_group", + &Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}}, + &groupID100, true, + }, + { + "with_groupID_account_not_in_group", + &Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}}, + &groupID100, false, + }, + { + "with_groupID_ungrouped_account", + &Account{ID: 7, AccountGroups: nil}, + &groupID100, false, + }, + { + "with_groupID_multi_group_account_match_one", + &Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + &groupID200, true, + }, + { + "with_groupID_multi_group_account_no_match", + &Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}}, + &groupID100, false, + }, + // 防御性边界 + { + "nil_account_nil_groupID", + nil, + nil, false, + }, + { + "nil_account_with_groupID", + nil, + &groupID100, false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isAccountInGroup(tt.account, tt.groupID) + require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期") + }) + } +} + +// ============================================================================ +// Part 2: 分组隔离端到端调度测试 +// ============================================================================ + +// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。 +// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。 +type groupAwareMockAccountRepo struct { + *mockAccountRepoForPlatform + allAccounts []Account +} + +// 
ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号 +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// accountBelongsToGroup 检查账号是否属于指定分组 +func accountBelongsToGroup(acc Account, groupID int64) bool { + for 
_, ag := range acc.AccountGroups { + if ag.GroupID == groupID { + return true + } + } + return false +} + +// Verify interface implementation +var _ AccountRepository = (*groupAwareMockAccountRepo)(nil) + +// newGroupAwareMockRepo 创建分组感知的 mock repo +func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo { + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + return &groupAwareMockAccountRepo{ + mockAccountRepoForPlatform: &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + }, + allAccounts: accounts, + } +} + +func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.Error(t, err, "无分组 Key 不应调度到已分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc 
:= &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.Error(t, err, "有分组 Key 不应调度到未分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度未分组账号") + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2") +} + +func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中 + } + repo := 
newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度分组内账号") + require.NotNil(t, acc) + require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3") +} + +// ============================================================================ +// Part 3: SimpleMode 旁路测试 +// ============================================================================ + +func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) { + // SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。 + // 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。 + ctx := context.Background() + + // 混合未分组和已分组账号,SimpleMode 下应全部可调度 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组 + {ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组 + } + + // 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤) + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + // groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组) + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号") + require.NotNil(t, acc) + // 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组 + require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组") +} + +func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t 
*testing.T) { + // SimpleMode + groupID=nil 时,已分组账号也应该可被调度 + ctx := context.Background() + + // 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + } + + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 下已分组账号也应可调度") + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 067a0e08..1cb3c61e 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) 
error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 48c69881..fa9a3cb1 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte } // isAccountInGroup checks if the account belongs to the specified group. -// Returns true if groupID is nil (no group restriction) or account belongs to the group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). 
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { - if groupID == nil { - return true // 无分组限制 - } if account == nil { return false } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } for _, ag := range account.AccountGroups { if ag.GroupID == *groupID { return true diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 1c38b6c2..a003f636 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co if groupID != nil { return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) } - return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } + return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms) } func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 86bc9476..9476e984 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont } return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return 
m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f624d92a..02db384f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1343,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 89443b69..4f5f7f3c 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -57,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl return result, nil } +func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 9f8fa14a..4c9540f1 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke var err error if groupID > 0 { accounts, err = 
s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) - } else { + } else if s.isRunModeSimple() { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { return nil, err @@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke if groupID > 0 { return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) } - return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + if s.isRunModeSimple() { + return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + } + return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) } func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket { From a80ec5d8bb3da941438dddeb21fed73b45814c57 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 3 Mar 2026 15:01:10 +0800 Subject: [PATCH 017/286] =?UTF-8?q?feat:=20apikey=E6=94=AF=E6=8C=815h/1d/7?= =?UTF-8?q?d=E9=80=9F=E7=8E=87=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 5 +- backend/cmd/server/wire_gen_test.go | 2 +- backend/ent/apikey.go | 112 ++- backend/ent/apikey/apikey.go | 84 ++ backend/ent/apikey/where.go | 435 ++++++++++ backend/ent/apikey_create.go | 744 ++++++++++++++++++ backend/ent/apikey_update.go | 480 +++++++++++ backend/ent/migrate/schema.go | 17 +- backend/ent/mutation.go | 743 ++++++++++++++++- backend/ent/runtime/runtime.go | 24 + backend/ent/schema/api_key.go | 41 + backend/go.mod | 2 - backend/internal/handler/api_key_handler.go | 32 +- backend/internal/handler/dto/mappers.go | 41 +- backend/internal/handler/dto/types.go | 11 + backend/internal/handler/gateway_handler.go | 12 + backend/internal/repository/api_key_repo.go | 128 ++- 
backend/internal/repository/billing_cache.go | 105 ++- .../middleware/api_key_auth_google_test.go | 9 + backend/internal/service/api_key.go | 16 + .../internal/service/api_key_auth_cache.go | 5 + .../service/api_key_auth_cache_impl.go | 6 + backend/internal/service/api_key_service.go | 103 ++- .../internal/service/billing_cache_service.go | 185 ++++- .../service/billing_cache_service_test.go | 20 +- backend/internal/service/billing_service.go | 16 + backend/internal/service/gateway_service.go | 19 +- .../service/openai_gateway_service.go | 8 + frontend/src/api/keys.ts | 13 +- frontend/src/i18n/locales/en.ts | 13 + frontend/src/i18n/locales/zh.ts | 13 + frontend/src/types/index.ts | 16 + frontend/src/views/user/KeysView.vue | 338 +++++++- 33 files changed, 3715 insertions(+), 83 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 2e9afc26..cbeb9a69 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -58,11 +58,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoCodeRepository := repository.NewPromoCodeRepository(client) billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - apiKeyRepository := repository.NewAPIKeyRepository(client) + apiKeyRepository := repository.NewAPIKeyRepository(client, db) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) + 
apiKeyService.SetRateLimitCacheInvalidator(billingCache) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 9fb9888d..bd2e7f90 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -42,7 +42,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 760851c8..9ee660c2 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -48,6 +48,24 @@ type APIKey struct { QuotaUsed float64 `json:"quota_used,omitempty"` // Expiration time for this API key (null = never expires) ExpiresAt *time.Time `json:"expires_at,omitempty"` + // Rate limit in USD per 5 hours (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h,omitempty"` + // Rate limit in USD per day (0 = unlimited) + RateLimit1d float64 `json:"rate_limit_1d,omitempty"` + // Rate limit in USD per 7 days (0 = unlimited) + RateLimit7d float64 `json:"rate_limit_7d,omitempty"` + // Used amount in USD for the current 5h window + Usage5h float64 
`json:"usage_5h,omitempty"` + // Used amount in USD for the current 1d window + Usage1d float64 `json:"usage_1d,omitempty"` + // Used amount in USD for the current 7d window + Usage7d float64 `json:"usage_7d,omitempty"` + // Start time of the current 5h rate limit window + Window5hStart *time.Time `json:"window_5h_start,omitempty"` + // Start time of the current 1d rate limit window + Window1dStart *time.Time `json:"window_1d_start,omitempty"` + // Start time of the current 7d rate limit window + Window7dStart *time.Time `json:"window_7d_start,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -105,13 +123,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) - case apikey.FieldQuota, apikey.FieldQuotaUsed: + case apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldRateLimit5h, apikey.FieldRateLimit1d, apikey.FieldRateLimit7d, apikey.FieldUsage5h, apikey.FieldUsage1d, apikey.FieldUsage7d: values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt, apikey.FieldWindow5hStart, apikey.FieldWindow1dStart, apikey.FieldWindow7dStart: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -226,6 +244,63 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { _m.ExpiresAt = new(time.Time) *_m.ExpiresAt = value.Time } + case apikey.FieldRateLimit5h: + if value, ok := 
values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_5h", values[i]) + } else if value.Valid { + _m.RateLimit5h = value.Float64 + } + case apikey.FieldRateLimit1d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_1d", values[i]) + } else if value.Valid { + _m.RateLimit1d = value.Float64 + } + case apikey.FieldRateLimit7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_7d", values[i]) + } else if value.Valid { + _m.RateLimit7d = value.Float64 + } + case apikey.FieldUsage5h: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_5h", values[i]) + } else if value.Valid { + _m.Usage5h = value.Float64 + } + case apikey.FieldUsage1d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_1d", values[i]) + } else if value.Valid { + _m.Usage1d = value.Float64 + } + case apikey.FieldUsage7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_7d", values[i]) + } else if value.Valid { + _m.Usage7d = value.Float64 + } + case apikey.FieldWindow5hStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_5h_start", values[i]) + } else if value.Valid { + _m.Window5hStart = new(time.Time) + *_m.Window5hStart = value.Time + } + case apikey.FieldWindow1dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_1d_start", values[i]) + } else if value.Valid { + _m.Window1dStart = new(time.Time) + *_m.Window1dStart = value.Time + } + case apikey.FieldWindow7dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_7d_start", values[i]) + } else if value.Valid { + 
_m.Window7dStart = new(time.Time) + *_m.Window7dStart = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -326,6 +401,39 @@ func (_m *APIKey) String() string { builder.WriteString("expires_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("rate_limit_5h=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit5h)) + builder.WriteString(", ") + builder.WriteString("rate_limit_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit1d)) + builder.WriteString(", ") + builder.WriteString("rate_limit_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit7d)) + builder.WriteString(", ") + builder.WriteString("usage_5h=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage5h)) + builder.WriteString(", ") + builder.WriteString("usage_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage1d)) + builder.WriteString(", ") + builder.WriteString("usage_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage7d)) + builder.WriteString(", ") + if v := _m.Window5hStart; v != nil { + builder.WriteString("window_5h_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window1dStart; v != nil { + builder.WriteString("window_1d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window7dStart; v != nil { + builder.WriteString("window_7d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 6abea56b..d398a027 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -43,6 +43,24 @@ const ( FieldQuotaUsed = "quota_used" // FieldExpiresAt holds the string denoting the expires_at field in the database. FieldExpiresAt = "expires_at" + // FieldRateLimit5h holds the string denoting the rate_limit_5h field in the database. 
+ FieldRateLimit5h = "rate_limit_5h" + // FieldRateLimit1d holds the string denoting the rate_limit_1d field in the database. + FieldRateLimit1d = "rate_limit_1d" + // FieldRateLimit7d holds the string denoting the rate_limit_7d field in the database. + FieldRateLimit7d = "rate_limit_7d" + // FieldUsage5h holds the string denoting the usage_5h field in the database. + FieldUsage5h = "usage_5h" + // FieldUsage1d holds the string denoting the usage_1d field in the database. + FieldUsage1d = "usage_1d" + // FieldUsage7d holds the string denoting the usage_7d field in the database. + FieldUsage7d = "usage_7d" + // FieldWindow5hStart holds the string denoting the window_5h_start field in the database. + FieldWindow5hStart = "window_5h_start" + // FieldWindow1dStart holds the string denoting the window_1d_start field in the database. + FieldWindow1dStart = "window_1d_start" + // FieldWindow7dStart holds the string denoting the window_7d_start field in the database. + FieldWindow7dStart = "window_7d_start" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -91,6 +109,15 @@ var Columns = []string{ FieldQuota, FieldQuotaUsed, FieldExpiresAt, + FieldRateLimit5h, + FieldRateLimit1d, + FieldRateLimit7d, + FieldUsage5h, + FieldUsage1d, + FieldUsage7d, + FieldWindow5hStart, + FieldWindow1dStart, + FieldWindow7dStart, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -129,6 +156,18 @@ var ( DefaultQuota float64 // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. DefaultQuotaUsed float64 + // DefaultRateLimit5h holds the default value on creation for the "rate_limit_5h" field. + DefaultRateLimit5h float64 + // DefaultRateLimit1d holds the default value on creation for the "rate_limit_1d" field. 
+ DefaultRateLimit1d float64 + // DefaultRateLimit7d holds the default value on creation for the "rate_limit_7d" field. + DefaultRateLimit7d float64 + // DefaultUsage5h holds the default value on creation for the "usage_5h" field. + DefaultUsage5h float64 + // DefaultUsage1d holds the default value on creation for the "usage_1d" field. + DefaultUsage1d float64 + // DefaultUsage7d holds the default value on creation for the "usage_7d" field. + DefaultUsage7d float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -199,6 +238,51 @@ func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() } +// ByRateLimit5h orders the results by the rate_limit_5h field. +func ByRateLimit5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit5h, opts...).ToFunc() +} + +// ByRateLimit1d orders the results by the rate_limit_1d field. +func ByRateLimit1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit1d, opts...).ToFunc() +} + +// ByRateLimit7d orders the results by the rate_limit_7d field. +func ByRateLimit7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit7d, opts...).ToFunc() +} + +// ByUsage5h orders the results by the usage_5h field. +func ByUsage5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage5h, opts...).ToFunc() +} + +// ByUsage1d orders the results by the usage_1d field. +func ByUsage1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage1d, opts...).ToFunc() +} + +// ByUsage7d orders the results by the usage_7d field. +func ByUsage7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage7d, opts...).ToFunc() +} + +// ByWindow5hStart orders the results by the window_5h_start field. 
+func ByWindow5hStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow5hStart, opts...).ToFunc() +} + +// ByWindow1dStart orders the results by the window_1d_start field. +func ByWindow1dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow1dStart, opts...).ToFunc() +} + +// ByWindow7dStart orders the results by the window_7d_start field. +func ByWindow7dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow7dStart, opts...).ToFunc() +} + // ByUserField orders the results by user field. func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index c1900ee1..edd2652b 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -115,6 +115,51 @@ func ExpiresAt(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) } +// RateLimit5h applies equality check predicate on the "rate_limit_5h" field. It's identical to RateLimit5hEQ. +func RateLimit5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit1d applies equality check predicate on the "rate_limit_1d" field. It's identical to RateLimit1dEQ. +func RateLimit1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit7d applies equality check predicate on the "rate_limit_7d" field. It's identical to RateLimit7dEQ. +func RateLimit7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// Usage5h applies equality check predicate on the "usage_5h" field. It's identical to Usage5hEQ. +func Usage5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage1d applies equality check predicate on the "usage_1d" field. It's identical to Usage1dEQ. 
+func Usage1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage7d applies equality check predicate on the "usage_7d" field. It's identical to Usage7dEQ. +func Usage7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Window5hStart applies equality check predicate on the "window_5h_start" field. It's identical to Window5hStartEQ. +func Window5hStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window1dStart applies equality check predicate on the "window_1d_start" field. It's identical to Window1dStartEQ. +func Window1dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window7dStart applies equality check predicate on the "window_7d_start" field. It's identical to Window7dStartEQ. +func Window7dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -690,6 +735,396 @@ func ExpiresAtNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) } +// RateLimit5hEQ applies the EQ predicate on the "rate_limit_5h" field. +func RateLimit5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hNEQ applies the NEQ predicate on the "rate_limit_5h" field. +func RateLimit5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hIn applies the In predicate on the "rate_limit_5h" field. +func RateLimit5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hNotIn applies the NotIn predicate on the "rate_limit_5h" field. 
+func RateLimit5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hGT applies the GT predicate on the "rate_limit_5h" field. +func RateLimit5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit5h, v)) +} + +// RateLimit5hGTE applies the GTE predicate on the "rate_limit_5h" field. +func RateLimit5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit5h, v)) +} + +// RateLimit5hLT applies the LT predicate on the "rate_limit_5h" field. +func RateLimit5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit5h, v)) +} + +// RateLimit5hLTE applies the LTE predicate on the "rate_limit_5h" field. +func RateLimit5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit5h, v)) +} + +// RateLimit1dEQ applies the EQ predicate on the "rate_limit_1d" field. +func RateLimit1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dNEQ applies the NEQ predicate on the "rate_limit_1d" field. +func RateLimit1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dIn applies the In predicate on the "rate_limit_1d" field. +func RateLimit1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dNotIn applies the NotIn predicate on the "rate_limit_1d" field. +func RateLimit1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dGT applies the GT predicate on the "rate_limit_1d" field. +func RateLimit1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit1d, v)) +} + +// RateLimit1dGTE applies the GTE predicate on the "rate_limit_1d" field. 
+func RateLimit1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit1d, v)) +} + +// RateLimit1dLT applies the LT predicate on the "rate_limit_1d" field. +func RateLimit1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit1d, v)) +} + +// RateLimit1dLTE applies the LTE predicate on the "rate_limit_1d" field. +func RateLimit1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit1d, v)) +} + +// RateLimit7dEQ applies the EQ predicate on the "rate_limit_7d" field. +func RateLimit7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dNEQ applies the NEQ predicate on the "rate_limit_7d" field. +func RateLimit7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dIn applies the In predicate on the "rate_limit_7d" field. +func RateLimit7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dNotIn applies the NotIn predicate on the "rate_limit_7d" field. +func RateLimit7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dGT applies the GT predicate on the "rate_limit_7d" field. +func RateLimit7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit7d, v)) +} + +// RateLimit7dGTE applies the GTE predicate on the "rate_limit_7d" field. +func RateLimit7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit7d, v)) +} + +// RateLimit7dLT applies the LT predicate on the "rate_limit_7d" field. +func RateLimit7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit7d, v)) +} + +// RateLimit7dLTE applies the LTE predicate on the "rate_limit_7d" field. 
+func RateLimit7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit7d, v)) +} + +// Usage5hEQ applies the EQ predicate on the "usage_5h" field. +func Usage5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage5hNEQ applies the NEQ predicate on the "usage_5h" field. +func Usage5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage5h, v)) +} + +// Usage5hIn applies the In predicate on the "usage_5h" field. +func Usage5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage5h, vs...)) +} + +// Usage5hNotIn applies the NotIn predicate on the "usage_5h" field. +func Usage5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage5h, vs...)) +} + +// Usage5hGT applies the GT predicate on the "usage_5h" field. +func Usage5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage5h, v)) +} + +// Usage5hGTE applies the GTE predicate on the "usage_5h" field. +func Usage5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage5h, v)) +} + +// Usage5hLT applies the LT predicate on the "usage_5h" field. +func Usage5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage5h, v)) +} + +// Usage5hLTE applies the LTE predicate on the "usage_5h" field. +func Usage5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage5h, v)) +} + +// Usage1dEQ applies the EQ predicate on the "usage_1d" field. +func Usage1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage1dNEQ applies the NEQ predicate on the "usage_1d" field. +func Usage1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage1d, v)) +} + +// Usage1dIn applies the In predicate on the "usage_1d" field. 
+func Usage1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage1d, vs...)) +} + +// Usage1dNotIn applies the NotIn predicate on the "usage_1d" field. +func Usage1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage1d, vs...)) +} + +// Usage1dGT applies the GT predicate on the "usage_1d" field. +func Usage1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage1d, v)) +} + +// Usage1dGTE applies the GTE predicate on the "usage_1d" field. +func Usage1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage1d, v)) +} + +// Usage1dLT applies the LT predicate on the "usage_1d" field. +func Usage1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage1d, v)) +} + +// Usage1dLTE applies the LTE predicate on the "usage_1d" field. +func Usage1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage1d, v)) +} + +// Usage7dEQ applies the EQ predicate on the "usage_7d" field. +func Usage7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Usage7dNEQ applies the NEQ predicate on the "usage_7d" field. +func Usage7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage7d, v)) +} + +// Usage7dIn applies the In predicate on the "usage_7d" field. +func Usage7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage7d, vs...)) +} + +// Usage7dNotIn applies the NotIn predicate on the "usage_7d" field. +func Usage7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage7d, vs...)) +} + +// Usage7dGT applies the GT predicate on the "usage_7d" field. +func Usage7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage7d, v)) +} + +// Usage7dGTE applies the GTE predicate on the "usage_7d" field. 
+func Usage7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage7d, v)) +} + +// Usage7dLT applies the LT predicate on the "usage_7d" field. +func Usage7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage7d, v)) +} + +// Usage7dLTE applies the LTE predicate on the "usage_7d" field. +func Usage7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage7d, v)) +} + +// Window5hStartEQ applies the EQ predicate on the "window_5h_start" field. +func Window5hStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window5hStartNEQ applies the NEQ predicate on the "window_5h_start" field. +func Window5hStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow5hStart, v)) +} + +// Window5hStartIn applies the In predicate on the "window_5h_start" field. +func Window5hStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartNotIn applies the NotIn predicate on the "window_5h_start" field. +func Window5hStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartGT applies the GT predicate on the "window_5h_start" field. +func Window5hStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow5hStart, v)) +} + +// Window5hStartGTE applies the GTE predicate on the "window_5h_start" field. +func Window5hStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow5hStart, v)) +} + +// Window5hStartLT applies the LT predicate on the "window_5h_start" field. +func Window5hStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow5hStart, v)) +} + +// Window5hStartLTE applies the LTE predicate on the "window_5h_start" field. 
+func Window5hStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow5hStart, v)) +} + +// Window5hStartIsNil applies the IsNil predicate on the "window_5h_start" field. +func Window5hStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow5hStart)) +} + +// Window5hStartNotNil applies the NotNil predicate on the "window_5h_start" field. +func Window5hStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow5hStart)) +} + +// Window1dStartEQ applies the EQ predicate on the "window_1d_start" field. +func Window1dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window1dStartNEQ applies the NEQ predicate on the "window_1d_start" field. +func Window1dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow1dStart, v)) +} + +// Window1dStartIn applies the In predicate on the "window_1d_start" field. +func Window1dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartNotIn applies the NotIn predicate on the "window_1d_start" field. +func Window1dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartGT applies the GT predicate on the "window_1d_start" field. +func Window1dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow1dStart, v)) +} + +// Window1dStartGTE applies the GTE predicate on the "window_1d_start" field. +func Window1dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow1dStart, v)) +} + +// Window1dStartLT applies the LT predicate on the "window_1d_start" field. 
+func Window1dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow1dStart, v)) +} + +// Window1dStartLTE applies the LTE predicate on the "window_1d_start" field. +func Window1dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow1dStart, v)) +} + +// Window1dStartIsNil applies the IsNil predicate on the "window_1d_start" field. +func Window1dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow1dStart)) +} + +// Window1dStartNotNil applies the NotNil predicate on the "window_1d_start" field. +func Window1dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow1dStart)) +} + +// Window7dStartEQ applies the EQ predicate on the "window_7d_start" field. +func Window7dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + +// Window7dStartNEQ applies the NEQ predicate on the "window_7d_start" field. +func Window7dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow7dStart, v)) +} + +// Window7dStartIn applies the In predicate on the "window_7d_start" field. +func Window7dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartNotIn applies the NotIn predicate on the "window_7d_start" field. +func Window7dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartGT applies the GT predicate on the "window_7d_start" field. +func Window7dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow7dStart, v)) +} + +// Window7dStartGTE applies the GTE predicate on the "window_7d_start" field. 
+func Window7dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow7dStart, v)) +} + +// Window7dStartLT applies the LT predicate on the "window_7d_start" field. +func Window7dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow7dStart, v)) +} + +// Window7dStartLTE applies the LTE predicate on the "window_7d_start" field. +func Window7dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow7dStart, v)) +} + +// Window7dStartIsNil applies the IsNil predicate on the "window_7d_start" field. +func Window7dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow7dStart)) +} + +// Window7dStartNotNil applies the NotNil predicate on the "window_7d_start" field. +func Window7dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow7dStart)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index bc506585..4ec8aeaa 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -181,6 +181,132 @@ func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { return _c } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_c *APIKeyCreate) SetRateLimit5h(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit5h(v) + return _c +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit5h(*v) + } + return _c +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_c *APIKeyCreate) SetRateLimit1d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit1d(v) + return _c +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. 
+func (_c *APIKeyCreate) SetNillableRateLimit1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit1d(*v) + } + return _c +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (_c *APIKeyCreate) SetRateLimit7d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit7d(v) + return _c +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit7d(*v) + } + return _c +} + +// SetUsage5h sets the "usage_5h" field. +func (_c *APIKeyCreate) SetUsage5h(v float64) *APIKeyCreate { + _c.mutation.SetUsage5h(v) + return _c +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage5h(*v) + } + return _c +} + +// SetUsage1d sets the "usage_1d" field. +func (_c *APIKeyCreate) SetUsage1d(v float64) *APIKeyCreate { + _c.mutation.SetUsage1d(v) + return _c +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage1d(*v) + } + return _c +} + +// SetUsage7d sets the "usage_7d" field. +func (_c *APIKeyCreate) SetUsage7d(v float64) *APIKeyCreate { + _c.mutation.SetUsage7d(v) + return _c +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage7d(*v) + } + return _c +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_c *APIKeyCreate) SetWindow5hStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow5hStart(v) + return _c +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. 
+func (_c *APIKeyCreate) SetNillableWindow5hStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow5hStart(*v) + } + return _c +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_c *APIKeyCreate) SetWindow1dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow1dStart(v) + return _c +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow1dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow1dStart(*v) + } + return _c +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_c *APIKeyCreate) SetWindow7dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow7dStart(v) + return _c +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow7dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow7dStart(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. 
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -269,6 +395,30 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultQuotaUsed _c.mutation.SetQuotaUsed(v) } + if _, ok := _c.mutation.RateLimit5h(); !ok { + v := apikey.DefaultRateLimit5h + _c.mutation.SetRateLimit5h(v) + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + v := apikey.DefaultRateLimit1d + _c.mutation.SetRateLimit1d(v) + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + v := apikey.DefaultRateLimit7d + _c.mutation.SetRateLimit7d(v) + } + if _, ok := _c.mutation.Usage5h(); !ok { + v := apikey.DefaultUsage5h + _c.mutation.SetUsage5h(v) + } + if _, ok := _c.mutation.Usage1d(); !ok { + v := apikey.DefaultUsage1d + _c.mutation.SetUsage1d(v) + } + if _, ok := _c.mutation.Usage7d(); !ok { + v := apikey.DefaultUsage7d + _c.mutation.SetUsage7d(v) + } return nil } @@ -313,6 +463,24 @@ func (_c *APIKeyCreate) check() error { if _, ok := _c.mutation.QuotaUsed(); !ok { return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} } + if _, ok := _c.mutation.RateLimit5h(); !ok { + return &ValidationError{Name: "rate_limit_5h", err: errors.New(`ent: missing required field "APIKey.rate_limit_5h"`)} + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + return &ValidationError{Name: "rate_limit_1d", err: errors.New(`ent: missing required field "APIKey.rate_limit_1d"`)} + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + return &ValidationError{Name: "rate_limit_7d", err: errors.New(`ent: missing required field "APIKey.rate_limit_7d"`)} + } + if _, ok := _c.mutation.Usage5h(); !ok { + return &ValidationError{Name: "usage_5h", err: errors.New(`ent: missing required field "APIKey.usage_5h"`)} + } + if _, ok := _c.mutation.Usage1d(); !ok { + return &ValidationError{Name: "usage_1d", err: errors.New(`ent: missing required field "APIKey.usage_1d"`)} + } + if _, ok := _c.mutation.Usage7d(); !ok { + return &ValidationError{Name: 
"usage_7d", err: errors.New(`ent: missing required field "APIKey.usage_7d"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -391,6 +559,42 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) _node.ExpiresAt = &value } + if value, ok := _c.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + _node.RateLimit5h = value + } + if value, ok := _c.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + _node.RateLimit1d = value + } + if value, ok := _c.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + _node.RateLimit7d = value + } + if value, ok := _c.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + _node.Usage5h = value + } + if value, ok := _c.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + _node.Usage1d = value + } + if value, ok := _c.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + _node.Usage7d = value + } + if value, ok := _c.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + _node.Window5hStart = &value + } + if value, ok := _c.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + _node.Window1dStart = &value + } + if value, ok := _c.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + _node.Window7dStart = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -697,6 +901,168 @@ func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { return u } +// SetRateLimit5h sets the "rate_limit_5h" field. 
+func (u *APIKeyUpsert) SetRateLimit5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit5h, v) + return u +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit5h) + return u +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsert) AddRateLimit5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit5h, v) + return u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsert) SetRateLimit1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit1d, v) + return u +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit1d) + return u +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsert) AddRateLimit1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit1d, v) + return u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsert) SetRateLimit7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit7d, v) + return u +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit7d) + return u +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsert) AddRateLimit7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit7d, v) + return u +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsert) SetUsage5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage5h, v) + return u +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. 
+func (u *APIKeyUpsert) UpdateUsage5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage5h) + return u +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsert) AddUsage5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage5h, v) + return u +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsert) SetUsage1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage1d, v) + return u +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage1d) + return u +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsert) AddUsage1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage1d, v) + return u +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsert) SetUsage7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage7d, v) + return u +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage7d) + return u +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsert) AddUsage7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage7d, v) + return u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsert) SetWindow5hStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow5hStart, v) + return u +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow5hStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow5hStart) + return u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsert) ClearWindow5hStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow5hStart) + return u +} + +// SetWindow1dStart sets the "window_1d_start" field. 
+func (u *APIKeyUpsert) SetWindow1dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow1dStart, v) + return u +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow1dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow1dStart) + return u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsert) ClearWindow1dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow1dStart) + return u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsert) SetWindow7dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow7dStart, v) + return u +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow7dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow7dStart) + return u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsert) ClearWindow7dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow7dStart) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -980,6 +1346,195 @@ func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) SetRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) AddRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. 
+func (u *APIKeyUpsertOne) UpdateRateLimit5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) SetRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) AddRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsertOne) SetRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertOne) AddRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertOne) SetUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsertOne) AddUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. 
+func (u *APIKeyUpsertOne) UpdateUsage5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertOne) SetUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertOne) AddUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertOne) SetUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsertOne) AddUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsertOne) SetWindow5hStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. 
+func (u *APIKeyUpsertOne) ClearWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertOne) SetWindow1dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertOne) ClearWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertOne) SetWindow7dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertOne) ClearWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1429,6 +1984,195 @@ func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertBulk) SetRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. 
+func (u *APIKeyUpsertBulk) AddRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) SetRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) AddRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) SetRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) AddRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertBulk) SetUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. 
+func (u *APIKeyUpsertBulk) AddUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertBulk) SetUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertBulk) AddUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertBulk) SetUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsertBulk) AddUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsertBulk) SetWindow5hStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. 
+func (u *APIKeyUpsertBulk) UpdateWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsertBulk) ClearWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertBulk) SetWindow1dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertBulk) SetWindow7dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. 
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 6ca01854..db341e4c 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -252,6 +252,192 @@ func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdate) SetRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdate) AddRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdate) SetRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. +func (_u *APIKeyUpdate) AddRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (_u *APIKeyUpdate) SetRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. 
+func (_u *APIKeyUpdate) SetNillableRateLimit7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdate) AddRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdate) SetUsage5h(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdate) AddUsage5h(v float64) *APIKeyUpdate { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdate) SetUsage1d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdate) AddUsage1d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. +func (_u *APIKeyUpdate) SetUsage7d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. 
+func (_u *APIKeyUpdate) AddUsage7d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdate) SetWindow5hStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdate) ClearWindow5hStart() *APIKeyUpdate { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdate) SetWindow1dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdate) ClearWindow1dStart() *APIKeyUpdate { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdate) SetWindow7dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (_u *APIKeyUpdate) ClearWindow7dStart() *APIKeyUpdate { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. 
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -456,6 +642,60 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); 
ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -799,6 +1039,192 @@ func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) SetRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) AddRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) SetRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) AddRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (_u *APIKeyUpdateOne) SetRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdateOne) AddRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdateOne) SetUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdateOne) AddUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdateOne) SetUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdateOne) AddUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. +func (_u *APIKeyUpdateOne) SetUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. 
+func (_u *APIKeyUpdateOne) SetNillableUsage7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. +func (_u *APIKeyUpdateOne) AddUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdateOne) SetWindow5hStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdateOne) ClearWindow5hStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdateOne) SetWindow1dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdateOne) ClearWindow1dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdateOne) SetWindow7dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. 
+func (_u *APIKeyUpdateOne) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (_u *APIKeyUpdateOne) ClearWindow7dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -1033,6 +1459,60 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, 
value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 769dddce..85e94072 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -24,6 +24,15 @@ var ( {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "expires_at", Type: field.TypeTime, Nullable: true}, + {Name: "rate_limit_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rate_limit_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rate_limit_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": 
"decimal(20,8)"}}, + {Name: "usage_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "window_5h_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_1d_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_7d_start", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -35,13 +44,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[22]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[14]}, + Columns: []*schema.Column{APIKeysColumns[23]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -50,12 +59,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[14]}, + Columns: []*schema.Column{APIKeysColumns[23]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[22]}, }, { Name: "apikey_status", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 823cd389..85e2ea71 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -91,6 +91,21 @@ type APIKeyMutation struct { quota_used *float64 addquota_used *float64 expires_at *time.Time + rate_limit_5h *float64 + addrate_limit_5h *float64 + rate_limit_1d *float64 + addrate_limit_1d *float64 + rate_limit_7d *float64 + addrate_limit_7d *float64 + usage_5h *float64 + addusage_5h *float64 + usage_1d *float64 + addusage_1d *float64 + usage_7d *float64 + addusage_7d *float64 + window_5h_start *time.Time + window_1d_start *time.Time + window_7d_start *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -856,6 +871,489 @@ func (m 
*APIKeyMutation) ResetExpiresAt() { delete(m.clearedFields, apikey.FieldExpiresAt) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (m *APIKeyMutation) SetRateLimit5h(f float64) { + m.rate_limit_5h = &f + m.addrate_limit_5h = nil +} + +// RateLimit5h returns the value of the "rate_limit_5h" field in the mutation. +func (m *APIKeyMutation) RateLimit5h() (r float64, exists bool) { + v := m.rate_limit_5h + if v == nil { + return + } + return *v, true +} + +// OldRateLimit5h returns the old "rate_limit_5h" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit5h: %w", err) + } + return oldValue.RateLimit5h, nil +} + +// AddRateLimit5h adds f to the "rate_limit_5h" field. +func (m *APIKeyMutation) AddRateLimit5h(f float64) { + if m.addrate_limit_5h != nil { + *m.addrate_limit_5h += f + } else { + m.addrate_limit_5h = &f + } +} + +// AddedRateLimit5h returns the value that was added to the "rate_limit_5h" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit5h() (r float64, exists bool) { + v := m.addrate_limit_5h + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit5h resets all changes to the "rate_limit_5h" field. +func (m *APIKeyMutation) ResetRateLimit5h() { + m.rate_limit_5h = nil + m.addrate_limit_5h = nil +} + +// SetRateLimit1d sets the "rate_limit_1d" field. 
+func (m *APIKeyMutation) SetRateLimit1d(f float64) { + m.rate_limit_1d = &f + m.addrate_limit_1d = nil +} + +// RateLimit1d returns the value of the "rate_limit_1d" field in the mutation. +func (m *APIKeyMutation) RateLimit1d() (r float64, exists bool) { + v := m.rate_limit_1d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit1d returns the old "rate_limit_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit1d: %w", err) + } + return oldValue.RateLimit1d, nil +} + +// AddRateLimit1d adds f to the "rate_limit_1d" field. +func (m *APIKeyMutation) AddRateLimit1d(f float64) { + if m.addrate_limit_1d != nil { + *m.addrate_limit_1d += f + } else { + m.addrate_limit_1d = &f + } +} + +// AddedRateLimit1d returns the value that was added to the "rate_limit_1d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit1d() (r float64, exists bool) { + v := m.addrate_limit_1d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit1d resets all changes to the "rate_limit_1d" field. +func (m *APIKeyMutation) ResetRateLimit1d() { + m.rate_limit_1d = nil + m.addrate_limit_1d = nil +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (m *APIKeyMutation) SetRateLimit7d(f float64) { + m.rate_limit_7d = &f + m.addrate_limit_7d = nil +} + +// RateLimit7d returns the value of the "rate_limit_7d" field in the mutation. 
+func (m *APIKeyMutation) RateLimit7d() (r float64, exists bool) { + v := m.rate_limit_7d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit7d returns the old "rate_limit_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit7d: %w", err) + } + return oldValue.RateLimit7d, nil +} + +// AddRateLimit7d adds f to the "rate_limit_7d" field. +func (m *APIKeyMutation) AddRateLimit7d(f float64) { + if m.addrate_limit_7d != nil { + *m.addrate_limit_7d += f + } else { + m.addrate_limit_7d = &f + } +} + +// AddedRateLimit7d returns the value that was added to the "rate_limit_7d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit7d() (r float64, exists bool) { + v := m.addrate_limit_7d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit7d resets all changes to the "rate_limit_7d" field. +func (m *APIKeyMutation) ResetRateLimit7d() { + m.rate_limit_7d = nil + m.addrate_limit_7d = nil +} + +// SetUsage5h sets the "usage_5h" field. +func (m *APIKeyMutation) SetUsage5h(f float64) { + m.usage_5h = &f + m.addusage_5h = nil +} + +// Usage5h returns the value of the "usage_5h" field in the mutation. +func (m *APIKeyMutation) Usage5h() (r float64, exists bool) { + v := m.usage_5h + if v == nil { + return + } + return *v, true +} + +// OldUsage5h returns the old "usage_5h" field's value of the APIKey entity. 
+// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldUsage5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage5h: %w", err) + } + return oldValue.Usage5h, nil +} + +// AddUsage5h adds f to the "usage_5h" field. +func (m *APIKeyMutation) AddUsage5h(f float64) { + if m.addusage_5h != nil { + *m.addusage_5h += f + } else { + m.addusage_5h = &f + } +} + +// AddedUsage5h returns the value that was added to the "usage_5h" field in this mutation. +func (m *APIKeyMutation) AddedUsage5h() (r float64, exists bool) { + v := m.addusage_5h + if v == nil { + return + } + return *v, true +} + +// ResetUsage5h resets all changes to the "usage_5h" field. +func (m *APIKeyMutation) ResetUsage5h() { + m.usage_5h = nil + m.addusage_5h = nil +} + +// SetUsage1d sets the "usage_1d" field. +func (m *APIKeyMutation) SetUsage1d(f float64) { + m.usage_1d = &f + m.addusage_1d = nil +} + +// Usage1d returns the value of the "usage_1d" field in the mutation. +func (m *APIKeyMutation) Usage1d() (r float64, exists bool) { + v := m.usage_1d + if v == nil { + return + } + return *v, true +} + +// OldUsage1d returns the old "usage_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage1d: %w", err) + } + return oldValue.Usage1d, nil +} + +// AddUsage1d adds f to the "usage_1d" field. +func (m *APIKeyMutation) AddUsage1d(f float64) { + if m.addusage_1d != nil { + *m.addusage_1d += f + } else { + m.addusage_1d = &f + } +} + +// AddedUsage1d returns the value that was added to the "usage_1d" field in this mutation. +func (m *APIKeyMutation) AddedUsage1d() (r float64, exists bool) { + v := m.addusage_1d + if v == nil { + return + } + return *v, true +} + +// ResetUsage1d resets all changes to the "usage_1d" field. +func (m *APIKeyMutation) ResetUsage1d() { + m.usage_1d = nil + m.addusage_1d = nil +} + +// SetUsage7d sets the "usage_7d" field. +func (m *APIKeyMutation) SetUsage7d(f float64) { + m.usage_7d = &f + m.addusage_7d = nil +} + +// Usage7d returns the value of the "usage_7d" field in the mutation. +func (m *APIKeyMutation) Usage7d() (r float64, exists bool) { + v := m.usage_7d + if v == nil { + return + } + return *v, true +} + +// OldUsage7d returns the old "usage_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage7d: %w", err) + } + return oldValue.Usage7d, nil +} + +// AddUsage7d adds f to the "usage_7d" field. +func (m *APIKeyMutation) AddUsage7d(f float64) { + if m.addusage_7d != nil { + *m.addusage_7d += f + } else { + m.addusage_7d = &f + } +} + +// AddedUsage7d returns the value that was added to the "usage_7d" field in this mutation. +func (m *APIKeyMutation) AddedUsage7d() (r float64, exists bool) { + v := m.addusage_7d + if v == nil { + return + } + return *v, true +} + +// ResetUsage7d resets all changes to the "usage_7d" field. +func (m *APIKeyMutation) ResetUsage7d() { + m.usage_7d = nil + m.addusage_7d = nil +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (m *APIKeyMutation) SetWindow5hStart(t time.Time) { + m.window_5h_start = &t +} + +// Window5hStart returns the value of the "window_5h_start" field in the mutation. +func (m *APIKeyMutation) Window5hStart() (r time.Time, exists bool) { + v := m.window_5h_start + if v == nil { + return + } + return *v, true +} + +// OldWindow5hStart returns the old "window_5h_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow5hStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow5hStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow5hStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow5hStart: %w", err) + } + return oldValue.Window5hStart, nil +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (m *APIKeyMutation) ClearWindow5hStart() { + m.window_5h_start = nil + m.clearedFields[apikey.FieldWindow5hStart] = struct{}{} +} + +// Window5hStartCleared returns if the "window_5h_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window5hStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow5hStart] + return ok +} + +// ResetWindow5hStart resets all changes to the "window_5h_start" field. +func (m *APIKeyMutation) ResetWindow5hStart() { + m.window_5h_start = nil + delete(m.clearedFields, apikey.FieldWindow5hStart) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (m *APIKeyMutation) SetWindow1dStart(t time.Time) { + m.window_1d_start = &t +} + +// Window1dStart returns the value of the "window_1d_start" field in the mutation. +func (m *APIKeyMutation) Window1dStart() (r time.Time, exists bool) { + v := m.window_1d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow1dStart returns the old "window_1d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow1dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow1dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow1dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow1dStart: %w", err) + } + return oldValue.Window1dStart, nil +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (m *APIKeyMutation) ClearWindow1dStart() { + m.window_1d_start = nil + m.clearedFields[apikey.FieldWindow1dStart] = struct{}{} +} + +// Window1dStartCleared returns if the "window_1d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window1dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow1dStart] + return ok +} + +// ResetWindow1dStart resets all changes to the "window_1d_start" field. +func (m *APIKeyMutation) ResetWindow1dStart() { + m.window_1d_start = nil + delete(m.clearedFields, apikey.FieldWindow1dStart) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (m *APIKeyMutation) SetWindow7dStart(t time.Time) { + m.window_7d_start = &t +} + +// Window7dStart returns the value of the "window_7d_start" field in the mutation. +func (m *APIKeyMutation) Window7dStart() (r time.Time, exists bool) { + v := m.window_7d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow7dStart returns the old "window_7d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow7dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow7dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow7dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow7dStart: %w", err) + } + return oldValue.Window7dStart, nil +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (m *APIKeyMutation) ClearWindow7dStart() { + m.window_7d_start = nil + m.clearedFields[apikey.FieldWindow7dStart] = struct{}{} +} + +// Window7dStartCleared returns if the "window_7d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window7dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow7dStart] + return ok +} + +// ResetWindow7dStart resets all changes to the "window_7d_start" field. +func (m *APIKeyMutation) ResetWindow7dStart() { + m.window_7d_start = nil + delete(m.clearedFields, apikey.FieldWindow7dStart) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -998,7 +1496,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -1041,6 +1539,33 @@ func (m *APIKeyMutation) Fields() []string { if m.expires_at != nil { fields = append(fields, apikey.FieldExpiresAt) } + if m.rate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.rate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.rate_limit_7d != nil { + fields = append(fields, apikey.FieldRateLimit7d) + } + if m.usage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.usage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.usage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } + if m.window_5h_start != nil { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.window_1d_start != nil { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.window_7d_start != nil { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1077,6 +1602,24 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.QuotaUsed() case apikey.FieldExpiresAt: return m.ExpiresAt() + case apikey.FieldRateLimit5h: + return m.RateLimit5h() + case apikey.FieldRateLimit1d: + return m.RateLimit1d() + case apikey.FieldRateLimit7d: + return m.RateLimit7d() + case apikey.FieldUsage5h: + return m.Usage5h() + case apikey.FieldUsage1d: + return m.Usage1d() + case apikey.FieldUsage7d: + return m.Usage7d() + case apikey.FieldWindow5hStart: + return m.Window5hStart() + case apikey.FieldWindow1dStart: + return m.Window1dStart() + case apikey.FieldWindow7dStart: + return m.Window7dStart() } return nil, false } @@ -1114,6 +1657,24 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldQuotaUsed(ctx) case apikey.FieldExpiresAt: return m.OldExpiresAt(ctx) + case apikey.FieldRateLimit5h: + 
return m.OldRateLimit5h(ctx) + case apikey.FieldRateLimit1d: + return m.OldRateLimit1d(ctx) + case apikey.FieldRateLimit7d: + return m.OldRateLimit7d(ctx) + case apikey.FieldUsage5h: + return m.OldUsage5h(ctx) + case apikey.FieldUsage1d: + return m.OldUsage1d(ctx) + case apikey.FieldUsage7d: + return m.OldUsage7d(ctx) + case apikey.FieldWindow5hStart: + return m.OldWindow5hStart(ctx) + case apikey.FieldWindow1dStart: + return m.OldWindow1dStart(ctx) + case apikey.FieldWindow7dStart: + return m.OldWindow7dStart(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -1221,6 +1782,69 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetExpiresAt(v) return nil + case apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage7d(v) + return nil + case apikey.FieldWindow5hStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow5hStart(v) + return nil + case apikey.FieldWindow1dStart: + v, ok := value.(time.Time) + if !ok 
{ + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow1dStart(v) + return nil + case apikey.FieldWindow7dStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow7dStart(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -1235,6 +1859,24 @@ func (m *APIKeyMutation) AddedFields() []string { if m.addquota_used != nil { fields = append(fields, apikey.FieldQuotaUsed) } + if m.addrate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.addrate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.addrate_limit_7d != nil { + fields = append(fields, apikey.FieldRateLimit7d) + } + if m.addusage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.addusage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.addusage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } return fields } @@ -1247,6 +1889,18 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { return m.AddedQuota() case apikey.FieldQuotaUsed: return m.AddedQuotaUsed() + case apikey.FieldRateLimit5h: + return m.AddedRateLimit5h() + case apikey.FieldRateLimit1d: + return m.AddedRateLimit1d() + case apikey.FieldRateLimit7d: + return m.AddedRateLimit7d() + case apikey.FieldUsage5h: + return m.AddedUsage5h() + case apikey.FieldUsage1d: + return m.AddedUsage1d() + case apikey.FieldUsage7d: + return m.AddedUsage7d() } return nil, false } @@ -1270,6 +1924,48 @@ func (m *APIKeyMutation) AddField(name string, value ent.Value) error { } m.AddQuotaUsed(v) return nil + case apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } 
+ m.AddRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage7d(v) + return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -1296,6 +1992,15 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldExpiresAt) { fields = append(fields, apikey.FieldExpiresAt) } + if m.FieldCleared(apikey.FieldWindow5hStart) { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.FieldCleared(apikey.FieldWindow1dStart) { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.FieldCleared(apikey.FieldWindow7dStart) { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1328,6 +2033,15 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldExpiresAt: m.ClearExpiresAt() return nil + case apikey.FieldWindow5hStart: + m.ClearWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ClearWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ClearWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1378,6 +2092,33 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldExpiresAt: m.ResetExpiresAt() return nil + case apikey.FieldRateLimit5h: + m.ResetRateLimit5h() + return nil + case apikey.FieldRateLimit1d: + m.ResetRateLimit1d() + return nil + case 
apikey.FieldRateLimit7d: + m.ResetRateLimit7d() + return nil + case apikey.FieldUsage5h: + m.ResetUsage5h() + return nil + case apikey.FieldUsage1d: + m.ResetUsage1d() + return nil + case apikey.FieldUsage7d: + m.ResetUsage7d() + return nil + case apikey.FieldWindow5hStart: + m.ResetWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ResetWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ResetWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 65531aae..2c7467f6 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -102,6 +102,30 @@ func init() { apikeyDescQuotaUsed := apikeyFields[9].Descriptor() // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) + // apikeyDescRateLimit5h is the schema descriptor for rate_limit_5h field. + apikeyDescRateLimit5h := apikeyFields[11].Descriptor() + // apikey.DefaultRateLimit5h holds the default value on creation for the rate_limit_5h field. + apikey.DefaultRateLimit5h = apikeyDescRateLimit5h.Default.(float64) + // apikeyDescRateLimit1d is the schema descriptor for rate_limit_1d field. + apikeyDescRateLimit1d := apikeyFields[12].Descriptor() + // apikey.DefaultRateLimit1d holds the default value on creation for the rate_limit_1d field. + apikey.DefaultRateLimit1d = apikeyDescRateLimit1d.Default.(float64) + // apikeyDescRateLimit7d is the schema descriptor for rate_limit_7d field. + apikeyDescRateLimit7d := apikeyFields[13].Descriptor() + // apikey.DefaultRateLimit7d holds the default value on creation for the rate_limit_7d field. + apikey.DefaultRateLimit7d = apikeyDescRateLimit7d.Default.(float64) + // apikeyDescUsage5h is the schema descriptor for usage_5h field. 
+ apikeyDescUsage5h := apikeyFields[14].Descriptor() + // apikey.DefaultUsage5h holds the default value on creation for the usage_5h field. + apikey.DefaultUsage5h = apikeyDescUsage5h.Default.(float64) + // apikeyDescUsage1d is the schema descriptor for usage_1d field. + apikeyDescUsage1d := apikeyFields[15].Descriptor() + // apikey.DefaultUsage1d holds the default value on creation for the usage_1d field. + apikey.DefaultUsage1d = apikeyDescUsage1d.Default.(float64) + // apikeyDescUsage7d is the schema descriptor for usage_7d field. + apikeyDescUsage7d := apikeyFields[16].Descriptor() + // apikey.DefaultUsage7d holds the default value on creation for the usage_7d field. + apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index c1ac7ac3..5db51270 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -74,6 +74,47 @@ func (APIKey) Fields() []ent.Field { Optional(). Nillable(). Comment("Expiration time for this API key (null = never expires)"), + + // ========== Rate limit fields ========== + // Rate limit configuration (0 = unlimited) + field.Float("rate_limit_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 5 hours (0 = unlimited)"), + field.Float("rate_limit_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per day (0 = unlimited)"), + field.Float("rate_limit_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 7 days (0 = unlimited)"), + // Rate limit usage tracking + field.Float("usage_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). 
+ Comment("Used amount in USD for the current 5h window"), + field.Float("usage_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 1d window"), + field.Float("usage_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 7d window"), + // Window start times + field.Time("window_5h_start"). + Optional(). + Nillable(). + Comment("Start time of the current 5h rate limit window"), + field.Time("window_1d_start"). + Optional(). + Nillable(). + Comment("Start time of the current 1d rate limit window"), + field.Time("window_7d_start"). + Optional(). + Nillable(). + Comment("Start time of the current 7d rate limit window"), } } diff --git a/backend/go.mod b/backend/go.mod index a34c9fff..ab76258a 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -180,8 +180,6 @@ require ( golang.org/x/text v0.34.0 // indirect golang.org/x/tools v0.41.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect - google.golang.org/grpc v1.75.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 61762744..8db3ea2c 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -36,6 +36,11 @@ type CreateAPIKeyRequest struct { IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 Quota *float64 `json:"quota"` // 配额限制 (USD) ExpiresInDays *int `json:"expires_in_days"` // 过期天数 + + // Rate limit fields (0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest represents the update API key request payload 
@@ -48,6 +53,12 @@ type UpdateAPIKeyRequest struct { Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) ResetQuota *bool `json:"reset_quota"` // 重置已用配额 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量 } // List handles listing user's API keys with pagination @@ -131,6 +142,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } + if req.RateLimit5h != nil { + svcReq.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + svcReq.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + svcReq.RateLimit7d = *req.RateLimit7d + } executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) @@ -163,10 +183,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) { } svcReq := service.UpdateAPIKeyRequest{ - IPWhitelist: req.IPWhitelist, - IPBlacklist: req.IPBlacklist, - Quota: req.Quota, - ResetQuota: req.ResetQuota, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, + ResetRateLimitUsage: req.ResetRateLimitUsage, } if req.Name != "" { svcReq.Name = &req.Name diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 1c34f537..fe2a1d77 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -72,22 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey { return nil } return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: k.Key, - Name: k.Name, - GroupID: 
k.GroupID, - Status: k.Status, - IPWhitelist: k.IPWhitelist, - IPBlacklist: k.IPBlacklist, - LastUsedAt: k.LastUsedAt, - Quota: k.Quota, - QuotaUsed: k.QuotaUsed, - ExpiresAt: k.ExpiresAt, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + RateLimit5h: k.RateLimit5h, + RateLimit1d: k.RateLimit1d, + RateLimit7d: k.RateLimit7d, + Usage5h: k.Usage5h, + Usage1d: k.Usage1d, + Usage7d: k.Usage7d, + Window5hStart: k.Window5hStart, + Window1dStart: k.Window1dStart, + Window7dStart: k.Window7dStart, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index e9235797..920615f7 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -47,6 +47,17 @@ type APIKey struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + // Rate limit fields + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5hStart *time.Time `json:"window_5h_start"` + Window1dStart *time.Time `json:"window_1d_start"` + Window7dStart *time.Time `json:"window_7d_start"` + User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8d39c767..c47e66df 100644 --- a/backend/internal/handler/gateway_handler.go 
+++ b/backend/internal/handler/gateway_handler.go @@ -1445,6 +1445,18 @@ func billingErrorDetails(err error) (status int, code, message string) { } return http.StatusServiceUnavailable, "billing_service_error", msg } + if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } msg := pkgerrors.Message(err) if msg == "" { logger.L().With( diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index b9ce60a5..94de4f45 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -2,6 +2,7 @@ package repository import ( "context" + "database/sql" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -16,10 +17,11 @@ import ( type apiKeyRepository struct { client *dbent.Client + sql sqlExecutor } -func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository { - return &apiKeyRepository{client: client} +func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository { + return &apiKeyRepository{client: client, sql: sqlDB} } func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { @@ -37,7 +39,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetNillableLastUsedAt(key.LastUsedAt). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). - SetNillableExpiresAt(key.ExpiresAt) + SetNillableExpiresAt(key.ExpiresAt). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). 
+ SetRateLimit7d(key.RateLimit7d) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -118,6 +123,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldExpiresAt, + apikey.FieldRateLimit5h, + apikey.FieldRateLimit1d, + apikey.FieldRateLimit7d, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -179,6 +187,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro SetStatus(key.Status). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). + SetRateLimit7d(key.RateLimit7d). + SetUsage5h(key.Usage5h). + SetUsage1d(key.Usage1d). + SetUsage7d(key.Usage7d). SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -193,6 +207,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearExpiresAt() } + // Rate limit window start times + if key.Window5hStart != nil { + builder.SetWindow5hStart(*key.Window5hStart) + } else { + builder.ClearWindow5hStart() + } + if key.Window1dStart != nil { + builder.SetWindow1dStart(*key.Window1dStart) + } else { + builder.ClearWindow1dStart() + } + if key.Window7dStart != nil { + builder.SetWindow7dStart(*key.Window7dStart) + } else { + builder.ClearWindow7dStart() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -412,25 +443,88 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt return nil } +// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes +// window start times via COALESCE if not already set. 
+func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + _, err := r.sql.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = usage_5h + $1, + usage_1d = usage_1d + $1, + usage_7d = usage_7d + $1, + window_5h_start = COALESCE(window_5h_start, NOW()), + window_1d_start = COALESCE(window_1d_start, NOW()), + window_7d_start = COALESCE(window_7d_start, NOW()), + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL`, + cost, id) + return err +} + +// ResetRateLimitWindows resets expired rate limit windows atomically. +func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END, + window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END, + window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END, + window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL`, + id) + return err +} + +// GetRateLimitData returns the current rate limit usage and window start times for an API key. 
+func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start + FROM api_keys + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, service.ErrAPIKeyNotFound + } + data := &service.APIKeyRateLimitData{} + if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil { + return nil, err + } + return data, rows.Err() +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil } out := &service.APIKey{ - ID: m.ID, - UserID: m.UserID, - Key: m.Key, - Name: m.Name, - Status: m.Status, - IPWhitelist: m.IPWhitelist, - IPBlacklist: m.IPBlacklist, - LastUsedAt: m.LastUsedAt, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - GroupID: m.GroupID, - Quota: m.Quota, - QuotaUsed: m.QuotaUsed, - ExpiresAt: m.ExpiresAt, + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + Status: m.Status, + IPWhitelist: m.IPWhitelist, + IPBlacklist: m.IPBlacklist, + LastUsedAt: m.LastUsedAt, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, + RateLimit5h: m.RateLimit5h, + RateLimit1d: m.RateLimit1d, + RateLimit7d: m.RateLimit7d, + Usage5h: m.Usage5h, + Usage1d: m.Usage1d, + Usage7d: m.Usage7d, + Window5hStart: m.Window5hStart, + Window1dStart: m.Window1dStart, + Window7dStart: m.Window7dStart, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index e753e1b8..8a00237b 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -14,10 +14,12 
@@ import ( ) const ( - billingBalanceKeyPrefix = "billing:balance:" - billingSubKeyPrefix = "billing:sub:" - billingCacheTTL = 5 * time.Minute - billingCacheJitter = 30 * time.Second + billingBalanceKeyPrefix = "billing:balance:" + billingSubKeyPrefix = "billing:sub:" + billingRateLimitKeyPrefix = "apikey:rate:" + billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second + rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window ) // jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 @@ -49,6 +51,20 @@ const ( subFieldVersion = "version" ) +// billingRateLimitKey generates the Redis key for API key rate limit cache. +func billingRateLimitKey(keyID int64) string { + return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID) +} + +const ( + rateLimitFieldUsage5h = "usage_5h" + rateLimitFieldUsage1d = "usage_1d" + rateLimitFieldUsage7d = "usage_7d" + rateLimitFieldWindow5h = "window_5h" + rateLimitFieldWindow1d = "window_1d" + rateLimitFieldWindow7d = "window_7d" +) + var ( deductBalanceScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) @@ -73,6 +89,21 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 `) + + // updateRateLimitUsageScript atomically increments all three rate limit usage counters. + // Returns 0 if the key doesn't exist (cache miss), 1 on success. 
+ updateRateLimitUsageScript = redis.NewScript(` + local exists = redis.call('EXISTS', KEYS[1]) + if exists == 0 then + return 0 + end + local cost = tonumber(ARGV[1]) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + `) ) type billingCache struct { @@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, key := billingSubKey(userID, groupID) return c.rdb.Del(ctx, key).Err() } + +func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) { + key := billingRateLimitKey(keyID) + result, err := c.rdb.HGetAll(ctx, key).Result() + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, redis.Nil + } + data := &service.APIKeyRateLimitCacheData{} + if v, ok := result[rateLimitFieldUsage5h]; ok { + data.Usage5h, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage1d]; ok { + data.Usage1d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage7d]; ok { + data.Usage7d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldWindow5h]; ok { + data.Window5h, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow1d]; ok { + data.Window1d, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow7d]; ok { + data.Window7d, _ = strconv.ParseInt(v, 10, 64) + } + return data, nil +} + +func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error { + if data == nil { + return nil + } + key := billingRateLimitKey(keyID) + fields := map[string]any{ + rateLimitFieldUsage5h: data.Usage5h, + rateLimitFieldUsage1d: data.Usage1d, + rateLimitFieldUsage7d: data.Usage7d, + rateLimitFieldWindow5h: data.Window5h, + rateLimitFieldWindow1d: 
data.Window1d, + rateLimitFieldWindow7d: data.Window7d, + } + pipe := c.rdb.Pipeline() + pipe.HSet(ctx, key, fields) + pipe.Expire(ctx, key, rateLimitCacheTTL) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + key := billingRateLimitKey(keyID) + _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result() + if err != nil && !errors.Is(err, redis.Nil) { + log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) + return err + } + return nil +} + +func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + key := billingRateLimitKey(keyID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 2124c86c..9f8bb2d5 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -95,6 +95,15 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim } return nil } +func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return &service.APIKeyRateLimitData{}, nil +} func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 07523597..255e7679 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -36,12 +36,28 @@ type APIKey struct { 
Quota float64 // Quota limit in USD (0 = unlimited) QuotaUsed float64 // Used quota amount ExpiresAt *time.Time // Expiration time (nil = never expires) + + // Rate limit fields + RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited) + RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited) + RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited) + Usage5h float64 // Used amount in current 5h window + Usage1d float64 // Used amount in current 1d window + Usage7d float64 // Used amount in current 7d window + Window5hStart *time.Time // Start of current 5h window + Window1dStart *time.Time // Start of current 1d window + Window7dStart *time.Time // Start of current 7d window } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } +// HasRateLimits returns true if any rate limit window is configured +func (k *APIKey) HasRateLimits() bool { + return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0 +} + // IsExpired checks if the API key has expired func (k *APIKey) IsExpired() bool { if k.ExpiresAt == nil { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 4240be23..83933f42 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct { // Expiration field for API Key expiration feature ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) + + // Rate limit configuration (only limits, not usage - usage read from Redis at check time) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // APIKeyAuthUserSnapshot 用户快照 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 30eb8d74..0ca694af 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ 
b/backend/internal/service/api_key_auth_cache_impl.go @@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { Quota: apiKey.Quota, QuotaUsed: apiKey.QuotaUsed, ExpiresAt: apiKey.ExpiresAt, + RateLimit5h: apiKey.RateLimit5h, + RateLimit1d: apiKey.RateLimit1d, + RateLimit7d: apiKey.RateLimit7d, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -262,6 +265,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho Quota: snapshot.Quota, QuotaUsed: snapshot.QuotaUsed, ExpiresAt: snapshot.ExpiresAt, + RateLimit5h: snapshot.RateLimit5h, + RateLimit1d: snapshot.RateLimit1d, + RateLimit7d: snapshot.RateLimit7d, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 0d073077..aaa2403f 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -30,6 +30,11 @@ var ( ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") + + // Rate limit errors + ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完") + ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完") + ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完") ) const ( @@ -64,6 +69,21 @@ type APIKeyRepository interface { // Quota methods IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error + + // Rate limit methods + IncrementRateLimitUsage(ctx context.Context, id int64, cost 
float64) error + ResetRateLimitWindows(ctx context.Context, id int64) error + GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) +} + +// APIKeyRateLimitData holds rate limit usage and window state for an API key. +type APIKeyRateLimitData struct { + Usage5h float64 + Usage1d float64 + Usage7d float64 + Window5hStart *time.Time + Window1dStart *time.Time + Window7dStart *time.Time } // APIKeyCache defines cache operations for API key service @@ -102,6 +122,11 @@ type CreateAPIKeyRequest struct { // Quota fields Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) + + // Rate limit fields (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest 更新API Key请求 @@ -117,22 +142,34 @@ type UpdateAPIKeyRequest struct { ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) ClearExpiration bool `json:"-"` // Clear expiration (internal use) ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0 } // APIKeyService API Key服务 +// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset. 
+type RateLimitCacheInvalidator interface { + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error +} + type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group - lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) - lastUsedTouchSF singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -158,6 +195,12 @@ func NewAPIKeyService( return svc } +// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator. +// Called after construction (e.g. in wire) to avoid circular dependencies. 
+func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) { + s.rateLimitCacheInvalid = inv +} + func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { if apiKey == nil { return @@ -327,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK IPBlacklist: req.IPBlacklist, Quota: req.Quota, QuotaUsed: 0, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, } // Set expiration time if specified @@ -519,6 +565,26 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist + // Update rate limit configuration + if req.RateLimit5h != nil { + apiKey.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + apiKey.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + apiKey.RateLimit7d = *req.RateLimit7d + } + resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage + if resetRateLimit { + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + } + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } @@ -526,6 +592,11 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req s.InvalidateAuthCacheByKey(ctx, apiKey.Key) s.compileAPIKeyIPRules(apiKey) + // Invalidate Redis rate limit cache so reset takes effect immediately + if resetRateLimit && s.rateLimitCacheInvalid != nil { + _ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } + return apiKey, nil } @@ -746,3 +817,11 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + +// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB. 
+func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 1a76f5f6..e6c82cf1 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -40,6 +40,7 @@ const ( cacheWriteSetSubscription cacheWriteUpdateSubscriptionUsage cacheWriteDeductBalance + cacheWriteUpdateRateLimitUsage ) // 异步缓存写入工作池配置 @@ -68,19 +69,26 @@ type cacheWriteTask struct { kind cacheWriteKind userID int64 groupID int64 + apiKeyID int64 balance float64 amount float64 subscriptionData *subscriptionCacheData } +// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB. +type apiKeyRateLimitLoader interface { + GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error) +} + // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { - cache BillingCache - userRepo UserRepository - subRepo UserSubscriptionRepository - cfg *config.Config - circuitBreaker *billingCircuitBreaker + cache BillingCache + userRepo UserRepository + subRepo UserSubscriptionRepository + apiKeyRateLimitLoader apiKeyRateLimitLoader + cfg *config.Config + circuitBreaker *billingCircuitBreaker cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup @@ -96,12 +104,13 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { svc := &BillingCacheService{ - cache: cache, - 
userRepo: userRepo, - subRepo: subRepo, - cfg: cfg, + cache: cache, + userRepo: userRepo, + subRepo: subRepo, + apiKeyRateLimitLoader: apiKeyRepo, + cfg: cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() @@ -188,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } + case cacheWriteUpdateRateLimitUsage: + if s.cache != nil { + if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err) + } + } } cancel() } @@ -204,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string { return "update_subscription_usage" case cacheWriteDeductBalance: return "deduct_balance" + case cacheWriteUpdateRateLimitUsage: + return "update_rate_limit_usage" default: return "unknown" } @@ -476,6 +493,137 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } +// ============================================ +// API Key 限速缓存方法 +// ============================================ + +// checkAPIKeyRateLimits checks rate limit windows for an API key. +// It loads usage from Redis cache (falling back to DB on cache miss), +// resets expired windows in-memory and triggers async DB reset, +// and returns an error if any window limit is exceeded. 
+func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error { + if s.cache == nil { + // No cache: fall back to reading from DB directly + if s.apiKeyRateLimitLoader == nil { + return nil + } + data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if err != nil { + return nil // Don't block requests on DB errors + } + return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d, + data.Window5hStart, data.Window1dStart, data.Window7dStart) + } + + cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID) + if err != nil { + // Cache miss: load from DB and populate cache + if s.apiKeyRateLimitLoader == nil { + return nil + } + dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if dbErr != nil { + return nil // Don't block requests on DB errors + } + // Build cache entry from DB data + cacheEntry := &APIKeyRateLimitCacheData{ + Usage5h: dbData.Usage5h, + Usage1d: dbData.Usage1d, + Usage7d: dbData.Usage7d, + } + if dbData.Window5hStart != nil { + cacheEntry.Window5h = dbData.Window5hStart.Unix() + } + if dbData.Window1dStart != nil { + cacheEntry.Window1d = dbData.Window1dStart.Unix() + } + if dbData.Window7dStart != nil { + cacheEntry.Window7d = dbData.Window7dStart.Unix() + } + _ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry) + cacheData = cacheEntry + } + + var w5h, w1d, w7d *time.Time + if cacheData.Window5h > 0 { + t := time.Unix(cacheData.Window5h, 0) + w5h = &t + } + if cacheData.Window1d > 0 { + t := time.Unix(cacheData.Window1d, 0) + w1d = &t + } + if cacheData.Window7d > 0 { + t := time.Unix(cacheData.Window7d, 0) + w7d = &t + } + return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d) +} + +// evaluateRateLimits checks usage against limits, triggering async resets for expired windows. 
+func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error { + needsReset := false + + // Reset expired windows in-memory for check purposes + if w5h != nil && time.Since(*w5h) >= 5*time.Hour { + usage5h = 0 + needsReset = true + } + if w1d != nil && time.Since(*w1d) >= 24*time.Hour { + usage1d = 0 + needsReset = true + } + if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour { + usage7d = 0 + needsReset = true + } + + // Trigger async DB reset if any window expired + if needsReset { + keyID := apiKey.ID + go func() { + resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if s.apiKeyRateLimitLoader != nil { + // Use the repo directly - reset then reload cache + if loader, ok := s.apiKeyRateLimitLoader.(interface { + ResetRateLimitWindows(ctx context.Context, id int64) error + }); ok { + _ = loader.ResetRateLimitWindows(resetCtx, keyID) + } + } + // Invalidate cache so next request loads fresh data + if s.cache != nil { + _ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID) + } + }() + } + + // Check limits + if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h { + return ErrAPIKeyRateLimit5hExceeded + } + if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d { + return ErrAPIKeyRateLimit1dExceeded + } + if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d { + return ErrAPIKeyRateLimit7dExceeded + } + return nil +} + +// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache. 
+func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) { + if s.cache == nil { + return + } + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteUpdateRateLimitUsage, + apiKeyID: apiKeyID, + amount: cost, + }) +} + // ============================================ // 统一检查方法 // ============================================ @@ -496,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil if isSubscriptionMode { - return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription) + if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil { + return err + } + } else { + if err := s.checkBalanceEligibility(ctx, user.ID); err != nil { + return err + } } - return s.checkBalanceEligibility(ctx, user.ID) + // Check API Key rate limits (applies to both billing modes) + if apiKey != nil && apiKey.HasRateLimits() { + if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { + return err + } + } + + return nil } // checkBalanceEligibility 检查余额模式资格 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 4e5f50e2..7d7045e2 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context return nil } +func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + return nil, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + return nil +} + +func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + return nil +} + 
+func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + return nil +} + func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -76,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 6abd1e53..5d67c808 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -10,6 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) +// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis. 
+type APIKeyRateLimitCacheData struct { + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started + Window1d int64 `json:"window_1d"` + Window7d int64 `json:"window_7d"` +} + // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations @@ -23,6 +33,12 @@ type BillingCache interface { SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error + + // API Key rate limit operations + GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) + SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error + UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 48c69881..62ffdd4b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -6361,9 +6361,10 @@ type RecordUsageInput struct { APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } -// APIKeyQuotaUpdater defines the interface for updating API Key quota +// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage type APIKeyQuotaUpdater interface { UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error + UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -6557,6 +6558,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } } + // Update 
API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -6746,6 +6755,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } } + // Update API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f624d92a..41ce278f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -3492,6 +3492,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } } + // Update API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at 
s.deferredService.ScheduleLastUsedUpdate(account.ID) diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index c5943789..6a03e6aa 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -46,6 +46,7 @@ export async function getById(id: number): Promise { * @param ipBlacklist - Optional IP blacklist * @param quota - Optional quota limit in USD (0 = unlimited) * @param expiresInDays - Optional days until expiry (undefined = never expires) + * @param rateLimitData - Optional rate limit fields * @returns Created API key */ export async function create( @@ -55,7 +56,8 @@ export async function create( ipWhitelist?: string[], ipBlacklist?: string[], quota?: number, - expiresInDays?: number + expiresInDays?: number, + rateLimitData?: { rate_limit_5h?: number; rate_limit_1d?: number; rate_limit_7d?: number } ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -76,6 +78,15 @@ export async function create( if (expiresInDays !== undefined && expiresInDays > 0) { payload.expires_in_days = expiresInDays } + if (rateLimitData?.rate_limit_5h && rateLimitData.rate_limit_5h > 0) { + payload.rate_limit_5h = rateLimitData.rate_limit_5h + } + if (rateLimitData?.rate_limit_1d && rateLimitData.rate_limit_1d > 0) { + payload.rate_limit_1d = rateLimitData.rate_limit_1d + } + if (rateLimitData?.rate_limit_7d && rateLimitData.rate_limit_7d > 0) { + payload.rate_limit_7d = rateLimitData.rate_limit_7d + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index eedbb142..41edeb6a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -560,6 +560,19 @@ export default { resetQuotaConfirmMessage: 'Are you sure you want to reset the used quota (${used}) for key "{name}" to 0? 
This action cannot be undone.', quotaResetSuccess: 'Quota reset successfully', failedToResetQuota: 'Failed to reset quota', + rateLimitColumn: 'Rate Limit', + rateLimitSection: 'Rate Limit', + resetUsage: 'Reset', + rateLimit5h: '5-Hour Limit (USD)', + rateLimit1d: 'Daily Limit (USD)', + rateLimit7d: '7-Day Limit (USD)', + rateLimitHint: 'Set the maximum spending for this key within each time window. 0 = unlimited.', + rateLimitUsage: 'Rate Limit Usage', + resetRateLimitUsage: 'Reset Rate Limit Usage', + resetRateLimitTitle: 'Confirm Reset Rate Limit', + resetRateLimitConfirmMessage: 'Are you sure you want to reset the rate limit usage for key "{name}"? All time window usage will be reset to zero. This action cannot be undone.', + rateLimitResetSuccess: 'Rate limit usage reset successfully', + failedToResetRateLimit: 'Failed to reset rate limit usage', expiration: 'Expiration', expiresInDays: '{days} days', extendDays: '+{days} days', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index afd051f4..397ecbb2 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -566,6 +566,19 @@ export default { resetQuotaConfirmMessage: '确定要将密钥 "{name}" 的已用额度(${used})重置为 0 吗?此操作不可撤销。', quotaResetSuccess: '额度重置成功', failedToResetQuota: '重置额度失败', + rateLimitColumn: '速率限制', + rateLimitSection: '速率限制', + resetUsage: '重置', + rateLimit5h: '5小时限额 (USD)', + rateLimit1d: '日限额 (USD)', + rateLimit7d: '7天限额 (USD)', + rateLimitHint: '设置此密钥在指定时间窗口内的最大消费额。0 = 无限制。', + rateLimitUsage: '速率限制用量', + resetRateLimitUsage: '重置速率限制用量', + resetRateLimitTitle: '确认重置速率限制', + resetRateLimitConfirmMessage: '确定要重置密钥 "{name}" 的速率限制用量吗?所有时间窗口的已用额度将归零。此操作不可撤销。', + rateLimitResetSuccess: '速率限制已重置', + failedToResetRateLimit: '重置速率限制失败', expiration: '密钥有效期', expiresInDays: '{days} 天', extendDays: '+{days} 天', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index ed6430d5..6e5aa302 100644 --- a/frontend/src/types/index.ts +++ 
b/frontend/src/types/index.ts @@ -421,6 +421,15 @@ export interface ApiKey { created_at: string updated_at: string group?: Group + rate_limit_5h: number + rate_limit_1d: number + rate_limit_7d: number + usage_5h: number + usage_1d: number + usage_7d: number + window_5h_start: string | null + window_1d_start: string | null + window_7d_start: string | null } export interface CreateApiKeyRequest { @@ -431,6 +440,9 @@ export interface CreateApiKeyRequest { ip_blacklist?: string[] quota?: number // Quota limit in USD (0 = unlimited) expires_in_days?: number // Days until expiry (null = never expires) + rate_limit_5h?: number + rate_limit_1d?: number + rate_limit_7d?: number } export interface UpdateApiKeyRequest { @@ -442,6 +454,10 @@ export interface UpdateApiKeyRequest { quota?: number // Quota limit in USD (null = no change, 0 = unlimited) expires_at?: string | null // Expiration time (null = no change) reset_quota?: boolean // Reset quota_used to 0 + rate_limit_5h?: number + rate_limit_1d?: number + rate_limit_7d?: number + reset_rate_limit_usage?: boolean } export interface CreateGroupRequest { diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 6beb993b..3f599844 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -137,6 +137,97 @@
+ + + + @@ -96,6 +85,7 @@ import { computed } from 'vue' import { useI18n } from 'vue-i18n' import type { Account } from '@/types' +import QuotaBadge from './QuotaBadge.vue' const props = defineProps<{ account: Account @@ -304,46 +294,17 @@ const rpmTooltip = computed(() => { } }) -// 是否显示配额限制(仅 apikey 类型且设置了 quota_limit) -const showQuotaLimit = computed(() => { - return ( - props.account.type === 'apikey' && - props.account.quota_limit !== undefined && - props.account.quota_limit !== null && - props.account.quota_limit > 0 - ) +// 是否显示各维度配额(仅 apikey 类型) +const showDailyQuota = computed(() => { + return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0 }) -// 当前已用配额 -const currentQuotaUsed = computed(() => props.account.quota_used ?? 0) - -// 配额状态样式 -const quotaClass = computed(() => { - if (!showQuotaLimit.value) return '' - - const used = currentQuotaUsed.value - const limit = props.account.quota_limit || 0 - - if (used >= limit) { - return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400' - } - if (used >= limit * 0.8) { - return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400' - } - return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' +const showWeeklyQuota = computed(() => { + return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0 }) -// 配额提示文字 -const quotaTooltip = computed(() => { - if (!showQuotaLimit.value) return '' - - const used = currentQuotaUsed.value - const limit = props.account.quota_limit || 0 - - if (used >= limit) { - return t('admin.accounts.capacity.quota.exceeded') - } - return t('admin.accounts.capacity.quota.normal') +const showTotalQuota = computed(() => { + return props.account.type === 'apikey' && (props.account.quota_limit ?? 
0) > 0 }) // 格式化费用显示 diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 14064078..835ec853 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1228,7 +1228,22 @@
- +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
('oauth') // For oauth-based: 'oauth' or 'setup- const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -3272,6 +3289,8 @@ const resetForm = () => { apiKeyBaseUrl.value = 'https://api.anthropic.com' apiKeyValue.value = '' editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3686,10 +3705,22 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } - // Inject quota_limit for apikey accounts + // Inject quota limits for apikey accounts let finalExtra = extra - if (type === 'apikey' && editQuotaLimit.value != null && editQuotaLimit.value > 0) { - finalExtra = { ...(extra || {}), quota_limit: editQuotaLimit.value } + if (type === 'apikey') { + const quotaExtra: Record = { ...(extra || {}) } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + quotaExtra.quota_limit = editQuotaLimit.value + } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + quotaExtra.quota_daily_limit = editQuotaDailyLimit.value + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } + if (Object.keys(quotaExtra).length > 0) { + finalExtra = quotaExtra + } } await doCreateAccount({ name: form.name, diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 148f95b6..200f3c3c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -904,7 
+904,22 @@
- +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
(OPENAI_WS_MODE_OF const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 @@ -1704,8 +1721,14 @@ watch( if (newAccount.type === 'apikey') { const quotaVal = extra?.quota_limit as number | undefined editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null + const dailyVal = extra?.quota_daily_limit as number | undefined + editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null + const weeklyVal = extra?.quota_weekly_limit as number | undefined + editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null } else { editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null } // Load antigravity model mapping (Antigravity 只支持映射模式) @@ -2525,6 +2548,16 @@ const handleSubmit = async () => { } else { delete newExtra.quota_limit } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + newExtra.quota_daily_limit = editQuotaDailyLimit.value + } else { + delete newExtra.quota_daily_limit + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } else { + delete newExtra.quota_weekly_limit + } updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/QuotaBadge.vue b/frontend/src/components/account/QuotaBadge.vue new file mode 100644 index 00000000..7cf0f59d --- /dev/null +++ b/frontend/src/components/account/QuotaBadge.vue @@ -0,0 +1,49 @@ + + + diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 1be73a25..505118ba 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ 
b/frontend/src/components/account/QuotaLimitCard.vue @@ -1,50 +1,59 @@ diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index 02596b9f..683a2092 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -76,10 +76,11 @@ const isRateLimited = computed(() => { }) const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date()) const hasQuotaLimit = computed(() => { - return props.account?.type === 'apikey' && - props.account?.quota_limit !== undefined && - props.account?.quota_limit !== null && - props.account?.quota_limit > 0 + return props.account?.type === 'apikey' && ( + (props.account?.quota_limit ?? 0) > 0 || + (props.account?.quota_daily_limit ?? 0) > 0 || + (props.account?.quota_weekly_limit ?? 0) > 0 + ) }) const handleKeydown = (event: KeyboardEvent) => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1efff120..36dc790c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1794,11 +1794,17 @@ export default { resetQuota: 'Reset Quota', quotaLimit: 'Quota Limit', quotaLimitPlaceholder: '0 means unlimited', - quotaLimitHint: 'Set max spending limit (USD). Account will be paused when reached. Changing limit won\'t reset usage.', + quotaLimitHint: 'Set daily/weekly/total spending limits (USD). Account will be paused when any limit is reached. Changing limits won\'t reset usage.', quotaLimitToggle: 'Enable Quota Limit', quotaLimitToggleHint: 'When enabled, account will be paused when usage reaches the set limit', - quotaLimitAmount: 'Limit Amount', - quotaLimitAmountHint: 'Maximum spending limit (USD). Account will be auto-paused when reached. 
Changing limit won\'t reset usage.', + quotaDailyLimit: 'Daily Limit', + quotaDailyLimitHint: 'Automatically resets every 24 hours from first usage.', + quotaWeeklyLimit: 'Weekly Limit', + quotaWeeklyLimitHint: 'Automatically resets every 7 days from first usage.', + quotaTotalLimit: 'Total Limit', + quotaTotalLimitHint: 'Cumulative spending limit. Does not auto-reset — use "Reset Quota" to clear.', + quotaLimitAmount: 'Total Limit', + quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.', testConnection: 'Test Connection', reAuthorize: 'Re-Authorize', refreshToken: 'Refresh Token', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index b2c38928..017b2cea 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1801,11 +1801,17 @@ export default { resetQuota: '重置配额', quotaLimit: '配额限制', quotaLimitPlaceholder: '0 表示不限制', - quotaLimitHint: '设置最大使用额度(美元),达到后账号暂停调度。修改限额不会重置已用额度。', + quotaLimitHint: '设置日/周/总使用额度(美元),任一维度达到限额后账号暂停调度。修改限额不会重置已用额度。', quotaLimitToggle: '启用配额限制', quotaLimitToggleHint: '开启后,当账号用量达到设定额度时自动暂停调度', - quotaLimitAmount: '限额金额', - quotaLimitAmountHint: '账号最大可用额度(美元),达到后自动暂停。修改限额不会重置已用额度。', + quotaDailyLimit: '日限额', + quotaDailyLimitHint: '从首次使用起每 24 小时自动重置。', + quotaWeeklyLimit: '周限额', + quotaWeeklyLimitHint: '从首次使用起每 7 天自动重置。', + quotaTotalLimit: '总限额', + quotaTotalLimitHint: '累计消费上限,不会自动重置 — 使用「重置配额」手动清零。', + quotaLimitAmount: '总限额', + quotaLimitAmountHint: '累计消费上限,不会自动重置。', testConnection: '测试连接', reAuthorize: '重新授权', refreshToken: '刷新令牌', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 2d8a2487..46665742 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -719,6 +719,10 @@ export interface Account { // API Key 账号配额限制 quota_limit?: number | null quota_used?: number | null + quota_daily_limit?: number | null + quota_daily_used?: number | null + quota_weekly_limit?: number | null + quota_weekly_used?: number | null // 
运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 From d22e62ac8a451de016519df52072da469460963c Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Mar 2026 19:28:22 +0800 Subject: [PATCH 125/286] fix(test): add allow_messages_dispatch to group API contract test The recent upstream commit added allow_messages_dispatch to the Group DTO but did not update the API contract test expectation. Co-Authored-By: Claude Opus 4.6 --- backend/internal/server/api_contract_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index aafbbe21..32126791 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -212,6 +212,7 @@ func TestAPIContracts(t *testing.T) { "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, + "allow_messages_dispatch": false, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } From 0debe0a80c9ae53d848bc04c0b46e76473fbe11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A5=9E=E4=B9=90?= <6682635@qq.com> Date: Sat, 7 Mar 2026 20:02:58 +0800 Subject: [PATCH 126/286] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20OpenAI=20WS?= =?UTF-8?q?=20=E7=94=A8=E9=87=8F=E7=AA=97=E5=8F=A3=E5=88=B7=E6=96=B0?= =?UTF-8?q?=E4=B8=8E=E9=99=90=E9=A2=9D=E7=BA=A0=E5=81=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/service/account_test_service.go | 10 ++ .../account_test_service_openai_test.go | 66 +++++++ .../internal/service/account_usage_service.go | 64 ++++++- .../service/account_usage_service_test.go | 68 +++++++ backend/internal/service/ratelimit_service.go | 18 ++ .../service/ratelimit_service_openai_test.go | 46 +++++ .../components/account/AccountUsageCell.vue | 68 ++++++- .../components/account/UsageProgressBar.vue | 60 +------ .../__tests__/AccountUsageCell.spec.ts | 167 ++++++++++++++++++ 
.../__tests__/accountUsageRefresh.spec.ts | 39 ++++ frontend/src/utils/accountUsageRefresh.ts | 28 +++ frontend/src/views/admin/AccountsView.vue | 4 +- 12 files changed, 568 insertions(+), 70 deletions(-) create mode 100644 backend/internal/service/account_test_service_openai_test.go create mode 100644 backend/internal/service/account_usage_service_test.go create mode 100644 frontend/src/utils/__tests__/accountUsageRefresh.spec.ts create mode 100644 frontend/src/utils/accountUsageRefresh.ts diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 9557e175..5b22c645 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -408,6 +408,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if isOAuth && s.accountRepo != nil { + if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + mergeAccountExtra(account, updates) + } + if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go new file mode 100644 index 00000000..61a755a7 --- /dev/null +++ b/backend/internal/service/account_test_service_openai_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type openAIAccountTestRepo struct { + mockAccountRepoForGemini + updatedExtra 
map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time +} + +func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error { + r.rateLimitedID = id + r.rateLimitedAt = &resetAt + return nil +} + +func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newSoraTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp.Header.Set("x-codex-primary-used-percent", "100") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "100") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 88, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, int64(88), repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.NotNil(t, account.RateLimitResetAt) + if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil { + require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) + } +} diff --git a/backend/internal/service/account_usage_service.go 
b/backend/internal/service/account_usage_service.go index b0a4900d..9bb3aa0b 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -367,7 +367,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou usage.SevenDay = progress } - if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { + if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { mergeAccountExtra(account, updates) if usage.UpdatedAt == nil { @@ -409,6 +409,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou return usage, nil } +func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool { + if account == nil { + return false + } + if usage == nil { + return true + } + if usage.FiveHour == nil || usage.SevenDay == nil { + return true + } + if account.IsRateLimited() { + return true + } + return isOpenAICodexSnapshotStale(account, now) +} + +func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool { + if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() { + return false + } + if account.Extra == nil { + return true + } + raw, ok := account.Extra["codex_usage_updated_at"] + if !ok { + return true + } + ts, err := parseTime(fmt.Sprint(raw)) + if err != nil { + return true + } + return now.Sub(ts) >= openAIProbeCacheTTL +} + func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool { if s == nil || s.cache == nil || accountID <= 0 { return true @@ -478,20 +512,34 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return 
nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + updates, err := extractOpenAICodexProbeUpdates(resp) + if err != nil { + return nil, err + } + if len(updates) > 0 { + go func(accountID int64, updates map[string]any) { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + }(account.ID, updates) + return updates, nil + } + return nil, nil +} + +func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { + if resp == nil { + return nil, nil } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) if len(updates) > 0 { - go func(accountID int64, updates map[string]any) { - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - }(account.ID, updates) return updates, nil } } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } return nil, nil } diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go new file mode 100644 index 00000000..974d9029 --- /dev/null +++ b/backend/internal/service/account_usage_service_test.go @@ -0,0 +1,68 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { + t.Parallel() + + rateLimitedUntil := time.Now().Add(5 * time.Minute) + now := time.Now() + usage := &UsageInfo{ + FiveHour: &UsageProgress{Utilization: 0}, + SevenDay: &UsageProgress{Utilization: 0}, + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) { + t.Fatal("expected rate-limited account to force codex snapshot refresh") + } + + if 
shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) { + t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh") + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) { + t.Fatal("expected missing 5h snapshot to require refresh") + } + + staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339) + if !shouldRefreshOpenAICodexSnapshot(&Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "codex_usage_updated_at": staleAt, + }, + }, usage, now) { + t.Fatal("expected stale ws snapshot to trigger refresh") + } +} + +func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if got := updates["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index f8f3154b..60ad99d0 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -615,6 
+615,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { // 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded) if account.Platform == PlatformOpenAI { + s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -878,6 +879,23 @@ func pickSooner(a, b *time.Time) *time.Time { } } +func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) { + if s == nil || s.accountRepo == nil || account == nil || headers == nil { + return + } + snapshot := ParseCodexRateLimitHeaders(headers) + if snapshot == nil { + return + } + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return + } + if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil { + slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err) + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 00902068..94b9a170 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "net/http" "testing" "time" @@ -141,6 +142,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) { } } +type openAI429SnapshotRepo struct { + mockAccountRepoForGemini + rateLimitedID int64 + updatedExtra map[string]any +} + +func (r *openAI429SnapshotRepo) SetRateLimited(_ 
context.Context, id int64, _ time.Time) error { + r.rateLimitedID = id + return nil +} + +func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) { + repo := &openAI429SnapshotRepo{} + svc := NewRateLimitService(repo, nil, nil, nil, nil) + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + svc.handle429(context.Background(), account, headers, nil) + + if repo.rateLimitedID != account.ID { + t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID) + } + if len(repo.updatedExtra) == 0 { + t.Fatal("expected codex snapshot to be persisted on 429") + } + if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + func TestNormalizedCodexLimits(t *testing.T) { // Test the Normalize() method directly pUsed := 100.0 diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 20b4b629..44c8e209 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -69,9 +69,39 @@
-
- + +
@@ -285,6 +287,7 @@ import { useAppStore } from '@/stores/app' import { useAuthStore } from '@/stores/auth' import { adminAPI } from '@/api/admin' import { useTableLoader } from '@/composables/useTableLoader' +import { useSwipeSelect } from '@/composables/useSwipeSelect' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -319,6 +322,12 @@ const authStore = useAuthStore() const proxies = ref([]) const groups = ref([]) const selIds = ref([]) +const accountTableRef = ref(null) +useSwipeSelect(accountTableRef, { + isSelected: (id) => selIds.value.includes(id), + select: (id) => { if (!selIds.value.includes(id)) selIds.value.push(id) }, + deselect: (id) => { selIds.value = selIds.value.filter(x => x !== id) } +}) const selPlatforms = computed(() => { const platforms = new Set( accounts.value diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index 147b3205..c26aa233 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -88,6 +88,7 @@ @@ -154,7 +158,7 @@ const props = defineProps<{ const emit = defineEmits<{ close: [] - reset: [] + reset: [account: Account] }>() const { t } = useI18n() @@ -225,12 +229,12 @@ const handleReset = async () => { if (!props.account) return resetting.value = true try { - await adminAPI.accounts.resetTempUnschedulable(props.account.id) - appStore.showSuccess(t('admin.accounts.tempUnschedulable.resetSuccess')) - emit('reset') + const updated = await adminAPI.accounts.recoverState(props.account.id) + appStore.showSuccess(t('admin.accounts.recoverStateSuccess')) + emit('reset', updated) handleClose() } catch (error: any) { - appStore.showError(error?.message || t('admin.accounts.tempUnschedulable.resetFailed')) + appStore.showError(error?.message || t('admin.accounts.recoverStateFailed')) } finally { resetting.value = 
false } diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index 02596b9f..2765b9e2 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -32,14 +32,10 @@ {{ t('admin.accounts.refreshToken') }} -
- - + +
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+ +
@@ -2612,6 +2664,10 @@ const editQuotaLimit = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -3281,6 +3337,8 @@ const resetForm = () => { fetchAntigravityDefaultMappings().then(mappings => { antigravityModelMappings.value = [...mappings] }) + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] customErrorCodeInput.value = null @@ -3433,6 +3491,20 @@ const handleMixedChannelCancel = () => { clearMixedChannelDialog() } +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + const handleSubmit = async () => { // For OAuth-based type, handle OAuth flow (goes to step 2) if (isOAuthFlow.value) { @@ -3532,6 +3604,12 @@ const handleSubmit = async () => { } } + // Add pool mode if enabled + if (poolModeEnabled.value) { + credentials.pool_mode = true + credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { credentials.custom_error_codes_enabled = true diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 148f95b6..074677b2 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ 
b/frontend/src/components/account/EditAccountModal.vue @@ -251,6 +251,58 @@
+ +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
@@ -1483,6 +1535,10 @@ const editApiKey = ref('') const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -1641,6 +1697,20 @@ const expiresAtInput = computed({ }) // Watchers +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + watch( () => props.account, (newAccount) => { @@ -1782,6 +1852,12 @@ watch( allowedModels.value = [] } + // Load pool mode + poolModeEnabled.value = credentials.pool_mode === true + poolModeRetryCount.value = normalizePoolModeRetryCount( + Number(credentials.pool_mode_retry_count ?? 
DEFAULT_POOL_MODE_RETRY_COUNT) + ) + // Load custom error codes customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true const existingErrorCodes = credentials.custom_error_codes as number[] | undefined @@ -1828,6 +1904,8 @@ watch( modelMappings.value = [] allowedModels.value = [] } + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] } @@ -2288,6 +2366,15 @@ const handleSubmit = async () => { newCredentials.model_mapping = currentCredentials.model_mapping } + // Add pool mode if enabled + if (poolModeEnabled.value) { + newCredentials.pool_mode = true + newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } else { + delete newCredentials.pool_mode + delete newCredentials.pool_mode_retry_count + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { newCredentials.custom_error_codes_enabled = true diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 270d68c5..c5b11146 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1929,6 +1929,13 @@ export default { addModel: 'Add', modelExists: 'Model already exists', modelCount: '{count} models', + poolMode: 'Pool Mode', + poolModeHint: 'Enable when upstream is an account pool; errors won\'t mark local account status', + poolModeInfo: + 'When enabled, upstream 429/403/401 errors will auto-retry without marking the account as rate-limited or errored. Suitable for upstream pointing to another sub2api instance.', + poolModeRetryCount: 'Same-Account Retries', + poolModeRetryCountHint: + 'Only applies in pool mode. Use 0 to disable in-place retry. 
Default {default}, maximum {max}.', customErrorCodes: 'Custom Error Codes', customErrorCodesHint: 'Only stop scheduling for selected error codes', customErrorCodesWarning: diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 44fa5fbf..54e7210b 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2073,6 +2073,12 @@ export default { addModel: '填入', modelExists: '该模型已存在', modelCount: '{count} 个模型', + poolMode: '池模式', + poolModeHint: '上游为账号池时启用,错误不标记本地账号状态', + poolModeInfo: + '启用后,上游 429/403/401 错误将自动重试而不标记账号限流或错误,适用于上游指向另一个 sub2api 实例的场景。', + poolModeRetryCount: '同账号重试次数', + poolModeRetryCountHint: '仅在池模式下生效。0 表示不原地重试;默认 {default},最大 {max}。', customErrorCodes: '自定义错误码', customErrorCodesHint: '仅对选中的错误码停止调度', customErrorCodesWarning: '仅选中的错误码会停止调度,其他错误将返回 500。', From 785115c62b9c2e4ef903f3608f0e0c132ed67cc1 Mon Sep 17 00:00:00 2001 From: bayma888 Date: Mon, 9 Feb 2026 18:14:50 +0800 Subject: [PATCH 141/286] fix(ui): improve group selector dropdown width and visibility - Increase Select dropdown max-width from 320px to 480px for better content display - Change KeysView group selector from fixed 256px to adaptive 280-480px width - Make group switch icon always visible (60% opacity, 100% on hover) - Allow group description to wrap to 2 lines instead of truncating - Improve user experience for group selection in API keys page --- frontend/src/components/common/GroupOptionItem.vue | 2 +- frontend/src/components/common/Select.vue | 2 +- frontend/src/views/user/KeysView.vue | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/src/components/common/GroupOptionItem.vue b/frontend/src/components/common/GroupOptionItem.vue index 44750350..8673b9b1 100644 --- a/frontend/src/components/common/GroupOptionItem.vue +++ b/frontend/src/components/common/GroupOptionItem.vue @@ -13,7 +13,7 @@ /> {{ description }} diff --git a/frontend/src/components/common/Select.vue 
b/frontend/src/components/common/Select.vue index c90d0201..36b5e022 100644 --- a/frontend/src/components/common/Select.vue +++ b/frontend/src/components/common/Select.vue @@ -434,7 +434,7 @@ onUnmounted(() => { diff --git a/frontend/src/components/common/Select.vue b/frontend/src/components/common/Select.vue index 36b5e022..9a81344c 100644 --- a/frontend/src/components/common/Select.vue +++ b/frontend/src/components/common/Select.vue @@ -224,7 +224,13 @@ const filteredOptions = computed(() => { let opts = props.options as any[] if (props.searchable && searchQuery.value) { const query = searchQuery.value.toLowerCase() - opts = opts.filter((opt) => getOptionLabel(opt).toLowerCase().includes(query)) + opts = opts.filter((opt) => { + // Match label + if (getOptionLabel(opt).toLowerCase().includes(query)) return true + // Also match description if present + if (opt.description && String(opt.description).toLowerCase().includes(query)) return true + return false + }) } return opts }) @@ -434,7 +440,7 @@ onUnmounted(() => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 9fd0c006..c9eae3ab 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1372,7 +1372,11 @@ export default { accounts: 'Accounts', status: 'Status', actions: 'Actions', - billingType: 'Billing Type' + billingType: 'Billing Type', + userName: 'Username', + userEmail: 'Email', + userNotes: 'Notes', + userStatus: 'Status' }, rateAndAccounts: '{rate}x rate · {count} accounts', accountsCount: '{count} accounts', @@ -1411,6 +1415,26 @@ export default { failedToUpdate: 'Failed to update group', failedToDelete: 'Failed to delete group', nameRequired: 'Please enter group name', + rateMultipliers: 'Rate Multipliers', + rateMultipliersTitle: 'Group Rate Multipliers', + addUserRate: 'Add User Rate Multiplier', + searchUserPlaceholder: 'Search user email...', + noRateMultipliers: 'No user rate multipliers configured', + rateUpdated: 'Rate multiplier 
updated', + rateDeleted: 'Rate multiplier removed', + rateAdded: 'Rate multiplier added', + clearAll: 'Clear All', + confirmClearAll: 'Are you sure you want to clear all rate multiplier settings for this group? This cannot be undone.', + rateCleared: 'All rate multipliers cleared', + batchAdjust: 'Batch Adjust Rates', + multiplierFactor: 'Factor', + applyMultiplier: 'Apply', + rateAdjusted: 'Rates adjusted successfully', + rateSaved: 'Rate multipliers saved', + finalRate: 'Final Rate', + unsavedChanges: 'Unsaved changes', + revertChanges: 'Revert', + userInfo: 'User Info', platforms: { all: 'All Platforms', anthropic: 'Anthropic', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index d139cd34..4a663de1 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1428,7 +1428,11 @@ export default { accounts: '账号数', status: '状态', actions: '操作', - billingType: '计费类型' + billingType: '计费类型', + userName: '用户名', + userEmail: '邮箱', + userNotes: '备注', + userStatus: '状态' }, form: { name: '名称', @@ -1510,6 +1514,26 @@ export default { failedToCreate: '创建分组失败', failedToUpdate: '更新分组失败', nameRequired: '请输入分组名称', + rateMultipliers: '专属倍率', + rateMultipliersTitle: '分组专属倍率管理', + addUserRate: '添加用户专属倍率', + searchUserPlaceholder: '搜索用户邮箱...', + noRateMultipliers: '暂无用户设置了专属倍率', + rateUpdated: '专属倍率已更新', + rateDeleted: '专属倍率已删除', + rateAdded: '专属倍率已添加', + clearAll: '全部清空', + confirmClearAll: '确定要清空该分组所有用户的专属倍率设置吗?此操作不可撤销。', + rateCleared: '已清空所有专属倍率', + batchAdjust: '批量调整倍率', + multiplierFactor: '乘数', + applyMultiplier: '应用', + rateAdjusted: '倍率已批量调整', + rateSaved: '专属倍率已保存', + finalRate: '最终倍率', + unsavedChanges: '有未保存的修改', + revertChanges: '撤销修改', + userInfo: '用户信息', subscription: { title: '订阅设置', type: '计费类型', diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 01b98c0c..a78762d6 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue 
@@ -181,6 +181,13 @@ {{ t('common.edit') }} + - + + +
+
+ + +
-
+ +
-
+
@@ -77,6 +111,70 @@
+
+ {{ t('admin.dashboard.noDataAvailable') }} +
+ +
+ +
+
+ {{ t('admin.dashboard.failedToLoad') }} +
+
+
+ +
+
+ + + + + + + + + + + + + + + + + +
{{ t('admin.dashboard.spendingRankingUser') }}{{ t('admin.dashboard.spendingRankingRequests') }}{{ t('admin.dashboard.spendingRankingTokens') }}{{ t('admin.dashboard.spendingRankingSpend') }}
+
+ + #{{ index + 1 }} + + + {{ getRankingUserLabel(item) }} + +
+
+ {{ formatNumber(item.requests) }} + + {{ formatTokens(item.tokens) }} + + ${{ formatCost(item.actual_cost) }} +
+
+
diff --git a/frontend/src/views/admin/ops/components/OpsOpenAITokenStatsCard.vue b/frontend/src/views/admin/ops/components/OpsOpenAITokenStatsCard.vue index 5b53555f..7f68594b 100644 --- a/frontend/src/views/admin/ops/components/OpsOpenAITokenStatsCard.vue +++ b/frontend/src/views/admin/ops/components/OpsOpenAITokenStatsCard.vue @@ -208,35 +208,39 @@ function onNextPage() { :description="t('admin.ops.openaiTokenStats.empty')" /> -
- - - - - - - - - - - - - - - - - - - - - - - -
{{ t('admin.ops.openaiTokenStats.table.model') }}{{ t('admin.ops.openaiTokenStats.table.requestCount') }}{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}
{{ row.model }}{{ formatInt(row.request_count) }}{{ formatRate(row.avg_tokens_per_sec) }}{{ formatRate(row.avg_first_token_ms) }}{{ formatInt(row.total_output_tokens) }}{{ formatInt(row.avg_duration_ms) }}{{ formatInt(row.requests_with_first_token) }}
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + +
{{ t('admin.ops.openaiTokenStats.table.model') }}{{ t('admin.ops.openaiTokenStats.table.requestCount') }}{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}
{{ row.model }}{{ formatInt(row.request_count) }}{{ formatRate(row.avg_tokens_per_sec) }}{{ formatRate(row.avg_first_token_ms) }}{{ formatInt(row.total_output_tokens) }}{{ formatInt(row.avg_duration_ms) }}{{ formatInt(row.requests_with_first_token) }}
+
+
{{ t('admin.ops.openaiTokenStats.totalModels', { total }) }}
diff --git a/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue b/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue index 3bec6d0d..9a1d99e4 100644 --- a/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue +++ b/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue @@ -543,6 +543,21 @@ async function saveAllSettings() { />
+ + +
+
{{ t('admin.ops.settings.dashboardCards') }}
+ +
+
+ +

+ {{ t('admin.ops.settings.displayOpenAITokenStatsHint') }} +

+
+ +
+
diff --git a/frontend/src/views/admin/ops/components/__tests__/OpsOpenAITokenStatsCard.spec.ts b/frontend/src/views/admin/ops/components/__tests__/OpsOpenAITokenStatsCard.spec.ts index 3e95f460..5804e176 100644 --- a/frontend/src/views/admin/ops/components/__tests__/OpsOpenAITokenStatsCard.spec.ts +++ b/frontend/src/views/admin/ops/components/__tests__/OpsOpenAITokenStatsCard.spec.ts @@ -196,6 +196,23 @@ describe('OpsOpenAITokenStatsCard', () => { expect(wrapper.find('.empty-state').exists()).toBe(true) }) + it('数据表使用固定高度滚动容器,避免纵向无限增长', async () => { + mockGetOpenAITokenStats.mockResolvedValue(sampleResponse) + + const wrapper = mount(OpsOpenAITokenStatsCard, { + props: { refreshToken: 0 }, + global: { + stubs: { + Select: SelectStub, + EmptyState: EmptyStateStub, + }, + }, + }) + await flushPromises() + + expect(wrapper.find('.max-h-\\[420px\\]').exists()).toBe(true) + }) + it('接口异常时显示错误提示', async () => { mockGetOpenAITokenStats.mockRejectedValue(new Error('加载失败')) From 53ad1645cf2e246204100ad77261303471757ecb Mon Sep 17 00:00:00 2001 From: Rose Ding Date: Fri, 13 Mar 2026 10:38:19 +0800 Subject: [PATCH 220/286] =?UTF-8?q?feat:=20=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E5=AE=9A=E6=97=B6=E5=A4=87=E4=BB=BD=E4=B8=8E=E6=81=A2=E5=A4=8D?= =?UTF-8?q?=EF=BC=88S3=20=E5=85=BC=E5=AE=B9=E5=AD=98=E5=82=A8=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20Cloudflare=20R2=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增管理员专属的数据库备份与恢复功能: - 全量 PostgreSQL 备份(pg_dump),gzip 压缩后上传到 S3 兼容存储 - 支持手动备份和 cron 定时备份 - 支持从备份恢复(psql --single-transaction) - 备份文件自动过期清理(默认 14 天) - 前端完整管理页面(S3 配置、定时配置、备份列表、恢复/下载/删除) - 内置 Cloudflare R2 配置教程弹窗 - Dockerfile 从 postgres 镜像多阶段复制 pg_dump/psql,确保版本一致 Co-Authored-By: Claude Opus 4.6 --- Dockerfile | 20 +- backend/cmd/server/wire.go | 7 + backend/cmd/server/wire_gen.go | 13 +- .../internal/handler/admin/backup_handler.go | 170 ++++ backend/internal/handler/handler.go | 1 + backend/internal/handler/wire.go | 3 
+ backend/internal/server/routes/admin.go | 27 + backend/internal/service/backup_service.go | 814 ++++++++++++++++++ backend/internal/service/wire.go | 8 + deploy/docker-compose.dev.yml | 105 +++ frontend/src/api/admin/backup.ts | 114 +++ frontend/src/api/admin/index.ts | 7 +- frontend/src/components/layout/AppSidebar.vue | 17 + frontend/src/i18n/locales/en.ts | 105 +++ frontend/src/i18n/locales/zh.ts | 105 +++ frontend/src/router/index.ts | 12 + frontend/src/views/admin/BackupView.vue | 506 +++++++++++ 17 files changed, 2029 insertions(+), 5 deletions(-) create mode 100644 backend/internal/handler/admin/backup_handler.go create mode 100644 backend/internal/service/backup_service.go create mode 100644 deploy/docker-compose.dev.yml create mode 100644 frontend/src/api/admin/backup.ts create mode 100644 frontend/src/views/admin/BackupView.vue diff --git a/Dockerfile b/Dockerfile index 8517f2fa..8fd48cc2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.21 +ARG POSTGRES_IMAGE=postgres:18-alpine ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \ ./cmd/server # ----------------------------------------------------------------------------- -# Stage 3: Final Runtime Image +# Stage 3: PostgreSQL Client (version-matched with docker-compose) +# ----------------------------------------------------------------------------- +FROM ${POSTGRES_IMAGE} AS pg-client + +# ----------------------------------------------------------------------------- +# Stage 4: Final Runtime Image # ----------------------------------------------------------------------------- FROM ${ALPINE_IMAGE} @@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ + libpq \ + zstd-libs \ + lz4-libs \ + krb5-libs \ + libldap \ + libedit \ && rm -rf 
/var/cache/apk/* +# Copy pg_dump and psql from the same postgres image used in docker-compose +# This ensures version consistency between backup tools and the database server +COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump +COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql +COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/ + # Create non-root user RUN addgroup -g 1000 sub2api && \ adduser -u 1000 -G sub2api -s /bin/sh -D sub2api diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 89bdbdca..7fc648ac 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -94,6 +94,7 @@ func provideCleanup( antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, scheduledTestRunner *service.ScheduledTestRunnerService, + backupSvc *service.BackupService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -230,6 +231,12 @@ func provideCleanup( } return nil }}, + {"BackupService", func() error { + if backupSvc != nil { + backupSvc.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 4d4517d2..139d883a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -145,6 +145,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) + backupService := service.ProvideBackupService(settingRepository, configConfig) + backupHandler := admin.NewBackupHandler(backupService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := 
admin.NewGeminiOAuthHandler(geminiOAuthService) @@ -200,7 +202,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -231,7 +233,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, 
scheduledTestService, accountTestService, rateLimitService, configConfig) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) application := &Application{ Server: httpServer, Cleanup: v, @@ -284,6 +286,7 @@ func provideCleanup( antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, scheduledTestRunner *service.ScheduledTestRunnerService, + backupSvc *service.BackupService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -419,6 +422,12 @@ func provideCleanup( } return nil }}, + {"BackupService", func() error { + if backupSvc != nil { + backupSvc.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/internal/handler/admin/backup_handler.go b/backend/internal/handler/admin/backup_handler.go new file mode 100644 index 00000000..818928c6 --- /dev/null +++ 
b/backend/internal/handler/admin/backup_handler.go @@ -0,0 +1,170 @@ +package admin + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type BackupHandler struct { + backupService *service.BackupService +} + +func NewBackupHandler(backupService *service.BackupService) *BackupHandler { + return &BackupHandler{backupService: backupService} +} + +// ─── S3 配置 ─── + +func (h *BackupHandler) GetS3Config(c *gin.Context) { + cfg, err := h.backupService.GetS3Config(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) UpdateS3Config(c *gin.Context) { + var req service.BackupS3Config + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) TestS3Connection(c *gin.Context) { + var req service.BackupS3Config + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + err := h.backupService.TestS3Connection(c.Request.Context(), req) + if err != nil { + response.Success(c, gin.H{"ok": false, "message": err.Error()}) + return + } + response.Success(c, gin.H{"ok": true, "message": "connection successful"}) +} + +// ─── 定时备份 ─── + +func (h *BackupHandler) GetSchedule(c *gin.Context) { + cfg, err := h.backupService.GetSchedule(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) UpdateSchedule(c *gin.Context) { + var req service.BackupScheduleConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := 
h.backupService.UpdateSchedule(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +// ─── 备份操作 ─── + +type CreateBackupRequest struct { + ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期 +} + +func (h *BackupHandler) CreateBackup(c *gin.Context) { + var req CreateBackupRequest + _ = c.ShouldBindJSON(&req) // 允许空 body + + expireDays := 14 // 默认14天过期 + if req.ExpireDays != nil { + expireDays = *req.ExpireDays + } + + record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, record) +} + +func (h *BackupHandler) ListBackups(c *gin.Context) { + records, err := h.backupService.ListBackups(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if records == nil { + records = []service.BackupRecord{} + } + response.Success(c, gin.H{"items": records}) +} + +func (h *BackupHandler) GetBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, record) +} + +func (h *BackupHandler) DeleteBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *BackupHandler) GetDownloadURL(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, 
gin.H{"url": url}) +} + +// ─── 恢复操作 ─── + +func (h *BackupHandler) RestoreBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"restored": true}) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 3f1d73ca..89d556cc 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -12,6 +12,7 @@ type AdminHandlers struct { Account *admin.AccountHandler Announcement *admin.AnnouncementHandler DataManagement *admin.DataManagementHandler + Backup *admin.BackupHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index d1e12e03..f3aadcf3 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -15,6 +15,7 @@ func ProvideAdminHandlers( accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, dataManagementHandler *admin.DataManagementHandler, + backupHandler *admin.BackupHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, @@ -39,6 +40,7 @@ func ProvideAdminHandlers( Account: accountHandler, Announcement: announcementHandler, DataManagement: dataManagementHandler, + Backup: backupHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet( admin.NewAccountHandler, admin.NewAnnouncementHandler, admin.NewDataManagementHandler, + admin.NewBackupHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, diff --git a/backend/internal/server/routes/admin.go 
b/backend/internal/server/routes/admin.go index 9fdb233b..d52de15e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -58,6 +58,9 @@ func RegisterAdminRoutes( // 数据管理 registerDataManagementRoutes(admin, h) + // 数据库备份恢复 + registerBackupRoutes(admin, h) + // 运维监控(Ops) registerOpsRoutes(admin, h) @@ -436,6 +439,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + backup := admin.Group("/backups") + { + // S3 存储配置 + backup.GET("/s3-config", h.Admin.Backup.GetS3Config) + backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config) + backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection) + + // 定时备份配置 + backup.GET("/schedule", h.Admin.Backup.GetSchedule) + backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule) + + // 备份操作 + backup.POST("", h.Admin.Backup.CreateBackup) + backup.GET("", h.Admin.Backup.ListBackups) + backup.GET("/:id", h.Admin.Backup.GetBackup) + backup.DELETE("/:id", h.Admin.Backup.DeleteBackup) + backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL) + + // 恢复操作 + backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup) + } +} + func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) { system := admin.Group("/system") { diff --git a/backend/internal/service/backup_service.go b/backend/internal/service/backup_service.go new file mode 100644 index 00000000..8c07f3b8 --- /dev/null +++ b/backend/internal/service/backup_service.go @@ -0,0 +1,814 @@ +package service + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "os/exec" + "sort" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/google/uuid" + 
"github.com/robfig/cron/v3" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +const ( + settingKeyBackupS3Config = "backup_s3_config" + settingKeyBackupSchedule = "backup_schedule" + settingKeyBackupRecords = "backup_records" + + maxBackupRecords = 100 +) + +var ( + ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured") + ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found") + ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress") + ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress") +) + +// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2) +type BackupS3Config struct { + Endpoint string `json:"endpoint"` // e.g. https://.r2.cloudflarestorage.com + Region string `json:"region"` // R2 用 "auto" + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key,omitempty"` + Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/" + ForcePathStyle bool `json:"force_path_style"` +} + +// IsConfigured 检查必要字段是否已配置 +func (c *BackupS3Config) IsConfigured() bool { + return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != "" +} + +// BackupScheduleConfig 定时备份配置 +type BackupScheduleConfig struct { + Enabled bool `json:"enabled"` + CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点 + RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理 + RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制 +} + +// BackupRecord 备份记录 +type BackupRecord struct { + ID string `json:"id"` + Status string `json:"status"` // pending, running, completed, failed + BackupType string `json:"backup_type"` // postgres + FileName string `json:"file_name"` + S3Key string `json:"s3_key"` + SizeBytes int64 
`json:"size_bytes"` + TriggeredBy string `json:"triggered_by"` // manual, scheduled + ErrorMsg string `json:"error_message,omitempty"` + StartedAt string `json:"started_at"` + FinishedAt string `json:"finished_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` // 过期时间 +} + +// BackupService 数据库备份恢复服务 +type BackupService struct { + settingRepo SettingRepository + dbCfg *config.DatabaseConfig + + mu sync.Mutex + s3Client *s3.Client + s3Cfg *BackupS3Config + backingUp bool + restoring bool + + cronMu sync.Mutex + cronSched *cron.Cron + cronEntryID cron.EntryID +} + +func NewBackupService(settingRepo SettingRepository, cfg *config.Config) *BackupService { + svc := &BackupService{ + settingRepo: settingRepo, + dbCfg: &cfg.Database, + } + return svc +} + +// Start 启动定时备份调度器 +func (s *BackupService) Start() { + s.cronSched = cron.New() + s.cronSched.Start() + + // 加载已有的定时配置 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + schedule, err := s.GetSchedule(ctx) + if err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err) + return + } + if schedule.Enabled && schedule.CronExpr != "" { + if err := s.applyCronSchedule(schedule); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err) + } + } +} + +// Stop 停止定时备份 +func (s *BackupService) Stop() { + s.cronMu.Lock() + defer s.cronMu.Unlock() + if s.cronSched != nil { + s.cronSched.Stop() + } +} + +// ─── S3 配置管理 ─── + +func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config) + if err != nil || raw == "" { + return &BackupS3Config{}, nil + } + var cfg BackupS3Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return &BackupS3Config{}, nil + } + // 脱敏返回 + cfg.SecretAccessKey = "" + return &cfg, nil +} + +func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) { 
+ // 如果没提供 secret,保留原有值 + if cfg.SecretAccessKey == "" { + old, _ := s.loadS3Config(ctx) + if old != nil { + cfg.SecretAccessKey = old.SecretAccessKey + } + } + + data, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal s3 config: %w", err) + } + if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil { + return nil, fmt.Errorf("save s3 config: %w", err) + } + + // 清除缓存的 S3 客户端 + s.mu.Lock() + s.s3Client = nil + s.s3Cfg = nil + s.mu.Unlock() + + cfg.SecretAccessKey = "" + return &cfg, nil +} + +func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error { + // 如果没提供 secret,用已保存的 + if cfg.SecretAccessKey == "" { + old, _ := s.loadS3Config(ctx) + if old != nil { + cfg.SecretAccessKey = old.SecretAccessKey + } + } + + if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required") + } + + client, err := s.buildS3Client(ctx, &cfg) + if err != nil { + return err + } + _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &cfg.Bucket, + }) + if err != nil { + return fmt.Errorf("S3 HeadBucket failed: %w", err) + } + return nil +} + +// ─── 定时备份管理 ─── + +func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule) + if err != nil || raw == "" { + return &BackupScheduleConfig{}, nil + } + var cfg BackupScheduleConfig + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return &BackupScheduleConfig{}, nil + } + return &cfg, nil +} + +func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) { + if cfg.Enabled && cfg.CronExpr == "" { + return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled") + } + // 验证 cron 表达式 + if cfg.CronExpr != "" { + parser := 
cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + if _, err := parser.Parse(cfg.CronExpr); err != nil { + return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err)) + } + } + + data, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal schedule config: %w", err) + } + if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil { + return nil, fmt.Errorf("save schedule config: %w", err) + } + + // 应用或停止定时任务 + if cfg.Enabled { + if err := s.applyCronSchedule(&cfg); err != nil { + return nil, err + } + } else { + s.removeCronSchedule() + } + + return &cfg, nil +} + +func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error { + s.cronMu.Lock() + defer s.cronMu.Unlock() + + if s.cronSched == nil { + return fmt.Errorf("cron scheduler not initialized") + } + + // 移除旧任务 + if s.cronEntryID != 0 { + s.cronSched.Remove(s.cronEntryID) + s.cronEntryID = 0 + } + + entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() { + s.runScheduledBackup() + }) + if err != nil { + return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err)) + } + s.cronEntryID = entryID + logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr) + return nil +} + +func (s *BackupService) removeCronSchedule() { + s.cronMu.Lock() + defer s.cronMu.Unlock() + if s.cronSched != nil && s.cronEntryID != 0 { + s.cronSched.Remove(s.cronEntryID) + s.cronEntryID = 0 + logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用") + } +} + +func (s *BackupService) runScheduledBackup() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + // 读取定时备份配置中的过期天数 + schedule, _ := s.GetSchedule(ctx) + expireDays := 14 // 默认14天过期 + if schedule != nil && schedule.RetainDays > 0 { + expireDays = schedule.RetainDays + } + + logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays) + record, 
err := s.CreateBackup(ctx, "scheduled", expireDays) + if err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err) + return + } + logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes) + + // 清理过期备份(复用已加载的 schedule) + if schedule == nil { + return + } + if err := s.cleanupOldBackups(ctx, schedule); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err) + } +} + +// ─── 备份/恢复核心 ─── + +// CreateBackup 创建全量数据库备份并上传到 S3 +// expireDays: 备份过期天数,0=永不过期,默认14天 +func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) { + s.mu.Lock() + if s.backingUp { + s.mu.Unlock() + return nil, ErrBackupInProgress + } + s.backingUp = true + s.mu.Unlock() + defer func() { + s.mu.Lock() + s.backingUp = false + s.mu.Unlock() + }() + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err + } + if s3Cfg == nil || !s3Cfg.IsConfigured() { + return nil, ErrBackupS3NotConfigured + } + + client, err := s.getOrCreateS3Client(ctx, s3Cfg) + if err != nil { + return nil, fmt.Errorf("init S3 client: %w", err) + } + + now := time.Now() + backupID := uuid.New().String()[:8] + fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405")) + s3Key := s.buildS3Key(s3Cfg, fileName) + + var expiresAt string + if expireDays > 0 { + expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339) + } + + record := &BackupRecord{ + ID: backupID, + Status: "running", + BackupType: "postgres", + FileName: fileName, + S3Key: s3Key, + TriggeredBy: triggeredBy, + StartedAt: now.Format(time.RFC3339), + ExpiresAt: expiresAt, + } + + // 执行全量 pg_dump + dumpData, err := s.pgDump(ctx) + if err != nil { + record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err) + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, record) + return record, fmt.Errorf("pg_dump: %w", err) + } + + // 
gzip 压缩 + var compressed bytes.Buffer + gzWriter := gzip.NewWriter(&compressed) + if _, err := gzWriter.Write(dumpData); err != nil { + record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("gzip failed: %v", err) + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, record) + return record, fmt.Errorf("gzip: %w", err) + } + if err := gzWriter.Close(); err != nil { + return nil, fmt.Errorf("gzip close: %w", err) + } + + record.SizeBytes = int64(compressed.Len()) + + // 上传到 S3 + contentType := "application/gzip" + _, err = client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &s3Cfg.Bucket, + Key: &s3Key, + Body: bytes.NewReader(compressed.Bytes()), + ContentType: &contentType, + }) + if err != nil { + record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("S3 upload failed: %v", err) + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, record) + return record, fmt.Errorf("s3 upload: %w", err) + } + + record.Status = "completed" + record.FinishedAt = time.Now().Format(time.RFC3339) + if err := s.saveRecord(ctx, record); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err) + } + + return record, nil +} + +// RestoreBackup 从 S3 下载备份并恢复到数据库 +func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error { + s.mu.Lock() + if s.restoring { + s.mu.Unlock() + return ErrRestoreInProgress + } + s.restoring = true + s.mu.Unlock() + defer func() { + s.mu.Lock() + s.restoring = false + s.mu.Unlock() + }() + + record, err := s.GetBackupRecord(ctx, backupID) + if err != nil { + return err + } + if record.Status != "completed" { + return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup") + } + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return err + } + client, err := s.getOrCreateS3Client(ctx, s3Cfg) + if err != nil { + return fmt.Errorf("init S3 client: %w", err) + } + + // 从 S3 下载 + result, err := client.GetObject(ctx, 
&s3.GetObjectInput{ + Bucket: &s3Cfg.Bucket, + Key: &record.S3Key, + }) + if err != nil { + return fmt.Errorf("S3 download failed: %w", err) + } + defer result.Body.Close() + + // 解压 gzip + gzReader, err := gzip.NewReader(result.Body) + if err != nil { + return fmt.Errorf("gzip reader: %w", err) + } + defer gzReader.Close() + + sqlData, err := io.ReadAll(gzReader) + if err != nil { + return fmt.Errorf("read backup data: %w", err) + } + + // 执行 psql 恢复 + if err := s.pgRestore(ctx, sqlData); err != nil { + return fmt.Errorf("pg restore: %w", err) + } + + return nil +} + +// ─── 备份记录管理 ─── + +func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) { + records, err := s.loadRecords(ctx) + if err != nil { + return nil, err + } + // 倒序返回(最新在前) + sort.Slice(records, func(i, j int) bool { + return records[i].StartedAt > records[j].StartedAt + }) + return records, nil +} + +func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) { + records, err := s.loadRecords(ctx) + if err != nil { + return nil, err + } + for i := range records { + if records[i].ID == backupID { + return &records[i], nil + } + } + return nil, ErrBackupNotFound +} + +func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error { + records, err := s.loadRecords(ctx) + if err != nil { + return err + } + + var found *BackupRecord + var remaining []BackupRecord + for i := range records { + if records[i].ID == backupID { + found = &records[i] + } else { + remaining = append(remaining, records[i]) + } + } + if found == nil { + return ErrBackupNotFound + } + + // 从 S3 删除 + if found.S3Key != "" && found.Status == "completed" { + s3Cfg, err := s.loadS3Config(ctx) + if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() { + client, err := s.getOrCreateS3Client(ctx, s3Cfg) + if err == nil { + _, _ = client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &s3Cfg.Bucket, + Key: &found.S3Key, + }) + } + } + } + + return 
s.saveRecords(ctx, remaining) +} + +// GetBackupDownloadURL 获取备份文件预签名下载 URL +func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) { + record, err := s.GetBackupRecord(ctx, backupID) + if err != nil { + return "", err + } + if record.Status != "completed" { + return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed") + } + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return "", err + } + client, err := s.getOrCreateS3Client(ctx, s3Cfg) + if err != nil { + return "", err + } + + presignClient := s3.NewPresignClient(client) + result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: &s3Cfg.Bucket, + Key: &record.S3Key, + }, s3.WithPresignExpires(1*time.Hour)) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return result.URL, nil +} + +// ─── 内部方法 ─── + +func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config) + if err != nil || raw == "" { + return nil, nil + } + var cfg BackupS3Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return nil, nil + } + return &cfg, nil +} + +func (s *BackupService) buildS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) { + region := cfg.Region + if region == "" { + region = "auto" // Cloudflare R2 默认 region + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), + ), + ) + if err != nil { + return nil, fmt.Errorf("load aws config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + if cfg.Endpoint != "" { + o.BaseEndpoint = &cfg.Endpoint + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) + 
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + }) + return client, nil +} + +func (s *BackupService) getOrCreateS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.s3Client != nil && s.s3Cfg != nil { + return s.s3Client, nil + } + + if cfg == nil { + return nil, ErrBackupS3NotConfigured + } + + client, err := s.buildS3Client(ctx, cfg) + if err != nil { + return nil, err + } + s.s3Client = client + s.s3Cfg = cfg + return client, nil +} + +func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string { + prefix := strings.TrimRight(cfg.Prefix, "/") + if prefix == "" { + prefix = "backups" + } + return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName) +} + +func (s *BackupService) pgDump(ctx context.Context) ([]byte, error) { + args := []string{ + "-h", s.dbCfg.Host, + "-p", fmt.Sprintf("%d", s.dbCfg.Port), + "-U", s.dbCfg.User, + "-d", s.dbCfg.DBName, + "--no-owner", + "--no-acl", + "--clean", + "--if-exists", + } + + cmd := exec.CommandContext(ctx, "pg_dump", args...) + if s.dbCfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password) + } + if s.dbCfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode) + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("%v: %s", err, stderr.String()) + } + return stdout.Bytes(), nil +} + +func (s *BackupService) pgRestore(ctx context.Context, sqlData []byte) error { + args := []string{ + "-h", s.dbCfg.Host, + "-p", fmt.Sprintf("%d", s.dbCfg.Port), + "-U", s.dbCfg.User, + "-d", s.dbCfg.DBName, + "--single-transaction", + } + + cmd := exec.CommandContext(ctx, "psql", args...) 
+ if s.dbCfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password) + } + if s.dbCfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode) + } + + cmd.Stdin = bytes.NewReader(sqlData) + var stderr bytes.Buffer + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("%v: %s", err, stderr.String()) + } + return nil +} + +func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords) + if err != nil || raw == "" { + return nil, nil + } + var records []BackupRecord + if err := json.Unmarshal([]byte(raw), &records); err != nil { + return nil, nil + } + return records, nil +} + +func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord) error { + data, err := json.Marshal(records) + if err != nil { + return err + } + return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data)) +} + +func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error { + records, _ := s.loadRecords(ctx) + + // 更新已有记录或追加 + found := false + for i := range records { + if records[i].ID == record.ID { + records[i] = *record + found = true + break + } + } + if !found { + records = append(records, *record) + } + + // 限制记录数量 + if len(records) > maxBackupRecords { + records = records[len(records)-maxBackupRecords:] + } + + return s.saveRecords(ctx, records) +} + +func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error { + if schedule == nil { + return nil + } + + records, err := s.loadRecords(ctx) + if err != nil { + return err + } + + // 按时间倒序 + sort.Slice(records, func(i, j int) bool { + return records[i].StartedAt > records[j].StartedAt + }) + + var toDelete []BackupRecord + var toKeep []BackupRecord + + for i, r := range records { + shouldDelete := false + + // 按保留份数清理 + if schedule.RetainCount > 0 && i >= schedule.RetainCount { + 
shouldDelete = true + } + + // 按保留天数清理 + if schedule.RetainDays > 0 && r.StartedAt != "" { + startedAt, err := time.Parse(time.RFC3339, r.StartedAt) + if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour { + shouldDelete = true + } + } + + if shouldDelete && r.Status == "completed" { + toDelete = append(toDelete, r) + } else { + toKeep = append(toKeep, r) + } + } + + // 删除 S3 上的文件 + for _, r := range toDelete { + if r.S3Key != "" { + _ = s.deleteS3Object(ctx, r.S3Key) + } + } + + if len(toDelete) > 0 { + logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete)) + return s.saveRecords(ctx, toKeep) + } + return nil +} + +func (s *BackupService) deleteS3Object(ctx context.Context, key string) error { + s3Cfg, err := s.loadS3Config(ctx) + if err != nil || s3Cfg == nil { + return nil + } + client, err := s.getOrCreateS3Client(ctx, s3Cfg) + if err != nil { + return err + } + _, err = client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &s3Cfg.Bucket, + Key: &key, + }) + return err +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 4d0c2271..4ae06731 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -322,6 +322,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC return apiKeyService } +// ProvideBackupService creates and starts BackupService +func ProvideBackupService(settingRepo SettingRepository, cfg *config.Config) *BackupService { + svc := NewBackupService(settingRepo, cfg) + svc.Start() + return svc +} + // ProvideSettingService wires SettingService with group reader for default subscription validation. 
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { svc := NewSettingService(settingRepo, cfg) @@ -373,6 +380,7 @@ var ProviderSet = wire.NewSet( NewAccountTestService, ProvideSettingService, NewDataManagementService, + ProvideBackupService, ProvideOpsSystemLogSink, NewOpsService, ProvideOpsMetricsCollector, diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml new file mode 100644 index 00000000..7793e424 --- /dev/null +++ b/deploy/docker-compose.dev.yml @@ -0,0 +1,105 @@ +# ============================================================================= +# Sub2API Docker Compose - Local Development Build +# ============================================================================= +# Build from local source code for testing changes. +# +# Usage: +# cd deploy +# docker compose -f docker-compose.dev.yml up --build +# ============================================================================= + +services: + sub2api: + build: + context: .. 
+ dockerfile: Dockerfile + container_name: sub2api-dev + restart: unless-stopped + ports: + - "${BIND_HOST:-127.0.0.1}:${SERVER_PORT:-8080}:8080" + volumes: + - ./data:/app/data + environment: + - AUTO_SETUP=true + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=debug + - RUN_MODE=${RUN_MODE:-standard} + - DATABASE_HOST=postgres + - DATABASE_PORT=5432 + - DATABASE_USER=${POSTGRES_USER:-sub2api} + - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} + - DATABASE_SSLMODE=disable + - REDIS_HOST=redis + - REDIS_PORT=6379 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + - JWT_SECRET=${JWT_SECRET:-} + - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} + - TZ=${TZ:-Asia/Shanghai} + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + networks: + - sub2api-network + healthcheck: + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + postgres: + image: postgres:18-alpine + container_name: sub2api-postgres-dev + restart: unless-stopped + volumes: + - ./postgres_data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=${POSTGRES_USER:-sub2api} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - POSTGRES_DB=${POSTGRES_DB:-sub2api} + - PGDATA=/var/lib/postgresql/data + - TZ=${TZ:-Asia/Shanghai} + networks: + - sub2api-network + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + + redis: + image: redis:8-alpine + container_name: sub2api-redis-dev + restart: unless-stopped + volumes: + - ./redis_data:/data + command: > + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + 
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' + environment: + - TZ=${TZ:-Asia/Shanghai} + - REDISCLI_AUTH=${REDIS_PASSWORD:-} + networks: + - sub2api-network + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + +networks: + sub2api-network: + driver: bridge diff --git a/frontend/src/api/admin/backup.ts b/frontend/src/api/admin/backup.ts new file mode 100644 index 00000000..eff70492 --- /dev/null +++ b/frontend/src/api/admin/backup.ts @@ -0,0 +1,114 @@ +import { apiClient } from '../client' + +export interface BackupS3Config { + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean +} + +export interface BackupScheduleConfig { + enabled: boolean + cron_expr: string + retain_days: number + retain_count: number +} + +export interface BackupRecord { + id: string + status: 'pending' | 'running' | 'completed' | 'failed' + backup_type: string + file_name: string + s3_key: string + size_bytes: number + triggered_by: string + error_message?: string + started_at: string + finished_at?: string + expires_at?: string +} + +export interface CreateBackupRequest { + expire_days?: number +} + +export interface TestS3Response { + ok: boolean + message: string +} + +// S3 Config +export async function getS3Config(): Promise { + const { data } = await apiClient.get('/admin/backups/s3-config') + return data +} + +export async function updateS3Config(config: BackupS3Config): Promise { + const { data } = await apiClient.put('/admin/backups/s3-config', config) + return data +} + +export async function testS3Connection(config: BackupS3Config): Promise { + const { data } = await apiClient.post('/admin/backups/s3-config/test', config) + return data +} + +// Schedule +export async function getSchedule(): Promise { + const { data } = await apiClient.get('/admin/backups/schedule') + return data +} + +export async function 
updateSchedule(config: BackupScheduleConfig): Promise { + const { data } = await apiClient.put('/admin/backups/schedule', config) + return data +} + +// Backup operations +export async function createBackup(req?: CreateBackupRequest): Promise { + const { data } = await apiClient.post('/admin/backups', req || {}, { timeout: 600000 }) + return data +} + +export async function listBackups(): Promise<{ items: BackupRecord[] }> { + const { data } = await apiClient.get<{ items: BackupRecord[] }>('/admin/backups') + return data +} + +export async function getBackup(id: string): Promise { + const { data } = await apiClient.get(`/admin/backups/${id}`) + return data +} + +export async function deleteBackup(id: string): Promise { + await apiClient.delete(`/admin/backups/${id}`) +} + +export async function getDownloadURL(id: string): Promise<{ url: string }> { + const { data } = await apiClient.get<{ url: string }>(`/admin/backups/${id}/download-url`) + return data +} + +// Restore +export async function restoreBackup(id: string): Promise { + await apiClient.post(`/admin/backups/${id}/restore`, {}, { timeout: 600000 }) +} + +export const backupAPI = { + getS3Config, + updateS3Config, + testS3Connection, + getSchedule, + updateSchedule, + createBackup, + listBackups, + getBackup, + deleteBackup, + getDownloadURL, + restoreBackup, +} + +export default backupAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 135ca50b..a6ebfc2c 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -23,6 +23,7 @@ import errorPassthroughAPI from './errorPassthrough' import dataManagementAPI from './dataManagement' import apiKeysAPI from './apiKeys' import scheduledTestsAPI from './scheduledTests' +import backupAPI from './backup' /** * Unified admin API object for convenient access @@ -47,7 +48,8 @@ export const adminAPI = { errorPassthrough: errorPassthroughAPI, dataManagement: dataManagementAPI, apiKeys: apiKeysAPI, - 
scheduledTests: scheduledTestsAPI + scheduledTests: scheduledTestsAPI, + backup: backupAPI } export { @@ -70,7 +72,8 @@ export { errorPassthroughAPI, dataManagementAPI, apiKeysAPI, - scheduledTestsAPI + scheduledTestsAPI, + backupAPI } export default adminAPI diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 3a23e6e0..f3bb5a8a 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -387,6 +387,21 @@ const DatabaseIcon = { ) } +const CloudArrowUpIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M12 16.5V9.75m0 0l3 3m-3-3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z' + }) + ] + ) +} + const BellIcon = { render: () => h( @@ -611,6 +626,7 @@ const adminNavItems = computed((): NavItem[] => { if (authStore.isSimpleMode) { const filtered = baseItems.filter(item => !item.hideInSimpleMode) filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }) + filtered.push({ path: '/admin/backup', label: t('nav.backup'), icon: CloudArrowUpIcon }) filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) // Add admin custom menu items after settings @@ -620,6 +636,7 @@ const adminNavItems = computed((): NavItem[] => { return filtered } + baseItems.push({ path: '/admin/backup', label: t('nav.backup'), icon: CloudArrowUpIcon }) baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) // Add admin custom menu items after settings diff --git a/frontend/src/i18n/locales/en.ts 
b/frontend/src/i18n/locales/en.ts index 9f847eb6..40d1a8eb 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -340,6 +340,7 @@ export default { redeemCodes: 'Redeem Codes', ops: 'Ops', promoCodes: 'Promo Codes', + backup: 'DB Backup', dataManagement: 'Data Management', settings: 'Settings', myAccount: 'My Account', @@ -966,6 +967,110 @@ export default { failedToLoad: 'Failed to load dashboard statistics' }, + backup: { + title: 'Database Backup', + description: 'Full database backup to S3-compatible storage with scheduled backup and restore', + s3: { + title: 'S3 Storage Configuration', + description: 'Configure S3-compatible storage (supports Cloudflare R2)', + descriptionPrefix: 'Configure S3-compatible storage (supports', + descriptionSuffix: ')', + enabled: 'Enable S3 Storage', + endpoint: 'Endpoint', + region: 'Region', + bucket: 'Bucket', + prefix: 'Key Prefix', + accessKeyId: 'Access Key ID', + secretAccessKey: 'Secret Access Key', + secretConfigured: 'Already configured, leave empty to keep', + forcePathStyle: 'Force Path Style', + testConnection: 'Test Connection', + testSuccess: 'S3 connection test successful', + testFailed: 'S3 connection test failed', + saved: 'S3 configuration saved' + }, + schedule: { + title: 'Scheduled Backup', + description: 'Configure automatic scheduled backups', + enabled: 'Enable Scheduled Backup', + cronExpr: 'Cron Expression', + cronHint: 'e.g. 
"0 2 * * *" means every day at 2:00 AM', + retainDays: 'Backup Expire Days', + retainDaysHint: 'Backup files auto-delete after this many days, 0 = never expire', + retainCount: 'Max Retain Count', + retainCountHint: 'Maximum number of backups to keep, 0 = unlimited', + saved: 'Schedule configuration saved' + }, + operations: { + title: 'Backup Records', + description: 'Create manual backups and manage existing backup records', + createBackup: 'Create Backup', + backing: 'Backing up...', + backupCreated: 'Backup created successfully', + expireDays: 'Expire Days' + }, + columns: { + status: 'Status', + fileName: 'File Name', + size: 'Size', + expiresAt: 'Expires At', + triggeredBy: 'Triggered By', + startedAt: 'Started At', + actions: 'Actions' + }, + status: { + pending: 'Pending', + running: 'Running', + completed: 'Completed', + failed: 'Failed' + }, + trigger: { + manual: 'Manual', + scheduled: 'Scheduled' + }, + neverExpire: 'Never', + empty: 'No backup records', + actions: { + download: 'Download', + restore: 'Restore', + restoreConfirm: 'Are you sure you want to restore from this backup? This will overwrite the current database!', + restoreSuccess: 'Database restored successfully', + deleteConfirm: 'Are you sure you want to delete this backup?', + deleted: 'Backup deleted' + }, + r2Guide: { + title: 'Cloudflare R2 Setup Guide', + intro: 'Cloudflare R2 provides S3-compatible object storage with a free tier of 10GB storage + 1M Class A requests/month, ideal for database backups.', + step1: { + title: 'Create an R2 Bucket', + line1: 'Log in to the Cloudflare Dashboard (dash.cloudflare.com), select "R2 Object Storage" from the sidebar', + line2: 'Click "Create bucket", enter a name (e.g. 
sub2api-backups), choose a region', + line3: 'Click create to finish' + }, + step2: { + title: 'Create an API Token', + line1: 'On the R2 page, click "Manage R2 API Tokens" in the top right', + line2: 'Click "Create API token", set permission to "Object Read & Write"', + line3: 'Recommended: restrict to specific bucket for better security', + line4: 'After creation, you will see the Access Key ID and Secret Access Key', + warning: 'The Secret Access Key is only shown once — copy and save it immediately!' + }, + step3: { + title: 'Get the S3 Endpoint', + desc: 'Find your Account ID on the R2 overview page (in the URL or the right panel). The endpoint format is:', + accountId: 'your_account_id' + }, + step4: { + title: 'Fill in the Configuration', + checkEnabled: 'Checked', + bucketValue: 'Your bucket name', + fromStep2: 'Value from Step 2', + unchecked: 'Unchecked' + }, + freeTier: 'R2 Free Tier: 10GB storage + 1M Class A requests + 10M Class B requests per month — more than enough for database backups.' 
+ } + }, + dataManagement: { title: 'Data Management', description: 'Manage data management agent status, object storage settings, and backup jobs in one place', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ddaced42..b276f059 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -340,6 +340,7 @@ export default { redeemCodes: '兑换码', ops: '运维监控', promoCodes: '优惠码', + backup: '数据库备份', dataManagement: '数据管理', settings: '系统设置', myAccount: '我的账户', @@ -988,6 +989,110 @@ export default { failedToLoad: '加载仪表盘数据失败' }, + backup: { + title: '数据库备份', + description: '全量数据库备份到 S3 兼容存储,支持定时备份与恢复', + s3: { + title: 'S3 存储配置', + description: '配置 S3 兼容存储(支持 Cloudflare R2)', + descriptionPrefix: '配置 S3 兼容存储(支持', + descriptionSuffix: ')', + enabled: '启用 S3 存储', + endpoint: '端点地址', + region: '区域', + bucket: '存储桶', + prefix: 'Key 前缀', + accessKeyId: 'Access Key ID', + secretAccessKey: 'Secret Access Key', + secretConfigured: '已配置,留空保持不变', + forcePathStyle: '强制路径风格', + testConnection: '测试连接', + testSuccess: 'S3 连接测试成功', + testFailed: 'S3 连接测试失败', + saved: 'S3 配置已保存' + }, + schedule: { + title: '定时备份', + description: '配置自动定时备份', + enabled: '启用定时备份', + cronExpr: 'Cron 表达式', + cronHint: '例如 "0 2 * * *" 表示每天凌晨 2 点', + retainDays: '备份过期天数', + retainDaysHint: '备份文件超过此天数后自动删除,0 = 永不过期', + retainCount: '最大保留份数', + retainCountHint: '最多保留的备份数量,0 = 不限制', + saved: '定时备份配置已保存' + }, + operations: { + title: '备份记录', + description: '创建手动备份和管理已有备份记录', + createBackup: '创建备份', + backing: '备份中...', + backupCreated: '备份创建成功', + expireDays: '过期天数' + }, + columns: { + status: '状态', + fileName: '文件名', + size: '大小', + expiresAt: '过期时间', + triggeredBy: '触发方式', + startedAt: '开始时间', + actions: '操作' + }, + status: { + pending: '等待中', + running: '执行中', + completed: '已完成', + failed: '失败' + }, + trigger: { + manual: '手动', + scheduled: '定时' + }, + neverExpire: '永不过期', + empty: '暂无备份记录', + actions: { + download: '下载', + restore: '恢复', + 
restoreConfirm: '确定要从此备份恢复吗?这将覆盖当前数据库!', + restoreSuccess: '数据库恢复成功', + deleteConfirm: '确定要删除此备份吗?', + deleted: '备份已删除' + }, + r2Guide: { + title: 'Cloudflare R2 配置教程', + intro: 'Cloudflare R2 提供 S3 兼容的对象存储,免费额度为 10GB 存储 + 每月 100 万次 A 类请求,非常适合数据库备份。', + step1: { + title: '创建 R2 存储桶', + line1: '登录 Cloudflare Dashboard (dash.cloudflare.com),左侧菜单选择「R2 对象存储」', + line2: '点击「创建存储桶」,输入名称(如 sub2api-backups),选择区域', + line3: '点击创建完成' + }, + step2: { + title: '创建 API 令牌', + line1: '在 R2 页面,点击右上角「管理 R2 API 令牌」', + line2: '点击「创建 API 令牌」,权限选择「对象读和写」', + line3: '建议指定存储桶范围(仅允许访问备份桶,更安全)', + line4: '创建后会显示 Access Key ID 和 Secret Access Key', + warning: 'Secret Access Key 只会显示一次,请立即复制保存!' + }, + step3: { + title: '获取 S3 端点地址', + desc: '在 R2 概览页面找到你的账户 ID(在 URL 或右侧面板中),端点格式为:', + accountId: '你的账户 ID' + }, + step4: { + title: '填写以下配置', + checkEnabled: '勾选', + bucketValue: '你创建的存储桶名称', + fromStep2: '第 2 步获取的值', + unchecked: '不勾选' + }, + freeTier: 'R2 免费额度:10GB 存储 + 每月 100 万次 A 类请求 + 1000 万次 B 类请求,对数据库备份完全够用。' + } + }, + dataManagement: { title: '数据管理', description: '统一管理数据管理代理状态、对象存储配置和备份任务', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index d235de51..2a7f7a77 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -350,6 +350,18 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.promo.description' } }, + { + path: '/admin/backup', + name: 'AdminBackup', + component: () => import('@/views/admin/BackupView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Database Backup', + titleKey: 'admin.backup.title', + descriptionKey: 'admin.backup.description' + } + }, { path: '/admin/data-management', name: 'AdminDataManagement', diff --git a/frontend/src/views/admin/BackupView.vue b/frontend/src/views/admin/BackupView.vue new file mode 100644 index 00000000..2c54f365 --- /dev/null +++ b/frontend/src/views/admin/BackupView.vue @@ -0,0 +1,506 @@ + + + + + From e73531ce9b72f88dbece814d309db26bc8134a74 Mon Sep 17 
00:00:00 2001 From: haruka <1628615876@qq.com> Date: Fri, 13 Mar 2026 10:39:35 +0800 Subject: [PATCH 221/286] =?UTF-8?q?fix:=20=E7=AE=A1=E7=90=86=E5=91=98?= =?UTF-8?q?=E9=87=8D=E7=BD=AE=E9=85=8D=E9=A2=9D=E8=A1=A5=E5=85=A8=20monthl?= =?UTF-8?q?y=20=E5=AD=97=E6=AE=B5=E5=B9=B6=E4=BF=AE=E5=A4=8D=20ristretto?= =?UTF-8?q?=20=E7=BC=93=E5=AD=98=E5=BC=82=E6=AD=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端 handler:ResetSubscriptionQuotaRequest 新增 Monthly 字段, 验证逻辑扩展为 daily/weekly/monthly 至少一项为 true - 后端 service:AdminResetQuota 新增 resetMonthly 参数, 调用 ResetMonthlyUsage;重置后追加 subCacheL1.Wait(), 保证 ristretto Del() 的异步删除立即生效,消除重置后 /v1/usage 返回旧用量数据的竞态窗口 - 后端测试:更新存量测试用例匹配新签名,补充 TestAdminResetQuota_ResetMonthlyOnly / TestAdminResetQuota_ResetMonthlyUsageError 两个新用例 - 前端 API:resetQuota options 类型新增 monthly: boolean - 前端视图:confirmResetQuota 改为同时重置 daily/weekly/monthly - i18n:中英文确认提示文案更新,提及每月配额 Co-Authored-By: Claude Sonnet 4.6 --- .../handler/admin/subscription_handler.go | 13 ++-- .../service/subscription_reset_quota_test.go | 67 +++++++++++++++---- .../internal/service/subscription_service.go | 20 ++++-- frontend/src/api/admin/subscriptions.ts | 4 +- frontend/src/i18n/locales/en.ts | 2 +- frontend/src/i18n/locales/zh.ts | 2 +- .../src/views/admin/SubscriptionsView.vue | 2 +- 7 files changed, 81 insertions(+), 29 deletions(-) diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index d6073551..342964b6 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -218,11 +218,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { // ResetSubscriptionQuotaRequest represents the reset quota request type ResetSubscriptionQuotaRequest struct { - Daily bool `json:"daily"` - Weekly bool `json:"weekly"` + Daily bool `json:"daily"` + Weekly bool `json:"weekly"` + Monthly bool 
`json:"monthly"` } -// ResetQuota resets daily and/or weekly usage for a subscription. +// ResetQuota resets daily, weekly, and/or monthly usage for a subscription. // POST /api/v1/admin/subscriptions/:id/reset-quota func (h *SubscriptionHandler) ResetQuota(c *gin.Context) { subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) { response.BadRequest(c, "Invalid request: "+err.Error()) return } - if !req.Daily && !req.Weekly { - response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true") + if !req.Daily && !req.Weekly && !req.Monthly { + response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true") return } - sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly) + sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/service/subscription_reset_quota_test.go b/backend/internal/service/subscription_reset_quota_test.go index 36aa177f..3bbc2170 100644 --- a/backend/internal/service/subscription_reset_quota_test.go +++ b/backend/internal/service/subscription_reset_quota_test.go @@ -11,17 +11,19 @@ import ( "github.com/stretchr/testify/require" ) -// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage, +// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage、ResetMonthlyUsage, // 其余方法继承 userSubRepoNoop(panic)。 type resetQuotaUserSubRepoStub struct { userSubRepoNoop sub *UserSubscription - resetDailyCalled bool - resetWeeklyCalled bool - resetDailyErr error - resetWeeklyErr error + resetDailyCalled bool + resetWeeklyCalled bool + resetMonthlyCalled bool + resetDailyErr error + resetWeeklyErr error + resetMonthlyErr error } func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) 
(*UserSubscription, error) { @@ -46,6 +48,11 @@ func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, return r.resetWeeklyErr } +func (r *resetQuotaUserSubRepoStub) ResetMonthlyUsage(_ context.Context, _ int64, _ time.Time) error { + r.resetMonthlyCalled = true + return r.resetMonthlyErr +} + func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService { return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil) } @@ -56,12 +63,13 @@ func TestAdminResetQuota_ResetBoth(t *testing.T) { } svc := newResetQuotaSvc(stub) - result, err := svc.AdminResetQuota(context.Background(), 1, true, true) + result, err := svc.AdminResetQuota(context.Background(), 1, true, true, false) require.NoError(t, err) require.NotNil(t, result) require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") } func TestAdminResetQuota_ResetDailyOnly(t *testing.T) { @@ -70,12 +78,13 @@ func TestAdminResetQuota_ResetDailyOnly(t *testing.T) { } svc := newResetQuotaSvc(stub) - result, err := svc.AdminResetQuota(context.Background(), 2, true, false) + result, err := svc.AdminResetQuota(context.Background(), 2, true, false, false) require.NoError(t, err) require.NotNil(t, result) require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") } func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) { @@ -84,12 +93,13 @@ func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) { } svc := newResetQuotaSvc(stub) - result, err := svc.AdminResetQuota(context.Background(), 3, false, true) + result, err := svc.AdminResetQuota(context.Background(), 3, false, true, false) require.NoError(t, err) require.NotNil(t, result) require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") require.True(t, 
stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") } func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) { @@ -98,22 +108,24 @@ func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) { } svc := newResetQuotaSvc(stub) - _, err := svc.AdminResetQuota(context.Background(), 7, false, false) + _, err := svc.AdminResetQuota(context.Background(), 7, false, false, false) require.ErrorIs(t, err, ErrInvalidInput) require.False(t, stub.resetDailyCalled) require.False(t, stub.resetWeeklyCalled) + require.False(t, stub.resetMonthlyCalled) } func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) { stub := &resetQuotaUserSubRepoStub{sub: nil} svc := newResetQuotaSvc(stub) - _, err := svc.AdminResetQuota(context.Background(), 999, true, true) + _, err := svc.AdminResetQuota(context.Background(), 999, true, true, true) require.ErrorIs(t, err, ErrSubscriptionNotFound) require.False(t, stub.resetDailyCalled) require.False(t, stub.resetWeeklyCalled) + require.False(t, stub.resetMonthlyCalled) } func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) { @@ -124,7 +136,7 @@ func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) { } svc := newResetQuotaSvc(stub) - _, err := svc.AdminResetQuota(context.Background(), 4, true, true) + _, err := svc.AdminResetQuota(context.Background(), 4, true, true, false) require.ErrorIs(t, err, dbErr) require.True(t, stub.resetDailyCalled) @@ -139,12 +151,41 @@ func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) { } svc := newResetQuotaSvc(stub) - _, err := svc.AdminResetQuota(context.Background(), 5, false, true) + _, err := svc.AdminResetQuota(context.Background(), 5, false, true, false) require.ErrorIs(t, err, dbErr) require.True(t, stub.resetWeeklyCalled) } +func TestAdminResetQuota_ResetMonthlyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 8, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + 
result, err := svc.AdminResetQuota(context.Background(), 8, false, false, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") + require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") + require.True(t, stub.resetMonthlyCalled, "应调用 ResetMonthlyUsage") +} + +func TestAdminResetQuota_ResetMonthlyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 9, UserID: 10, GroupID: 20}, + resetMonthlyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 9, false, false, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetMonthlyCalled) +} + func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) { stub := &resetQuotaUserSubRepoStub{ sub: &UserSubscription{ @@ -156,7 +197,7 @@ func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) { } svc := newResetQuotaSvc(stub) - result, err := svc.AdminResetQuota(context.Background(), 6, true, false) + result, err := svc.AdminResetQuota(context.Background(), 6, true, false, false) require.NoError(t, err) // ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零, diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 55f029fa..af548509 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -31,7 +31,7 @@ var ( ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") - ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least 
one of resetDaily or resetWeekly must be true") + ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily, resetWeekly, or resetMonthly must be true") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") @@ -696,10 +696,10 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart) } -// AdminResetQuota manually resets the daily and/or weekly usage windows. +// AdminResetQuota manually resets the daily, weekly, and/or monthly usage windows. // Uses startOfDay(now) as the new window start, matching automatic resets. -func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) { - if !resetDaily && !resetWeekly { +func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly, resetMonthly bool) (*UserSubscription, error) { + if !resetDaily && !resetWeekly && !resetMonthly { return nil, ErrInvalidInput } sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) @@ -717,8 +717,18 @@ func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionI return nil, err } } - // Invalidate caches, same as CheckAndResetWindows + if resetMonthly { + if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + // Invalidate L1 ristretto cache. Ristretto's Del() is asynchronous by design, + // so call Wait() immediately after to flush pending operations and guarantee + // the deleted key is not returned on the very next Get() call. 
s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.subCacheL1 != nil { + s.subCacheL1.Wait() + } if s.billingCacheService != nil { _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) } diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index d06e0774..7557e3ad 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -121,14 +121,14 @@ export async function revoke(id: number): Promise<{ message: string }> { } /** - * Reset daily and/or weekly usage quota for a subscription + * Reset daily, weekly, and/or monthly usage quota for a subscription * @param id - Subscription ID * @param options - Which windows to reset * @returns Updated subscription */ export async function resetQuota( id: number, - options: { daily: boolean; weekly: boolean } + options: { daily: boolean; weekly: boolean; monthly: boolean } ): Promise { const { data } = await apiClient.post( `/admin/subscriptions/${id}/reset-quota`, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 9f847eb6..045964bf 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1574,7 +1574,7 @@ export default { revoke: 'Revoke', resetQuota: 'Reset Quota', resetQuotaTitle: 'Reset Usage Quota', - resetQuotaConfirm: "Reset the daily and weekly usage quota for '{user}'? Usage will be zeroed and windows restarted from today.", + resetQuotaConfirm: "Reset the daily, weekly, and monthly usage quota for '{user}'? 
Usage will be zeroed and windows restarted from today.", quotaResetSuccess: 'Quota reset successfully', failedToResetQuota: 'Failed to reset quota', noSubscriptionsYet: 'No subscriptions yet', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ddaced42..4307c314 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1662,7 +1662,7 @@ export default { revoke: '撤销', resetQuota: '重置配额', resetQuotaTitle: '重置用量配额', - resetQuotaConfirm: "确定要重置 '{user}' 的每日和每周用量配额吗?用量将归零并从今天开始重新计算。", + resetQuotaConfirm: "确定要重置 '{user}' 的每日、每周和每月用量配额吗?用量将归零并从今天开始重新计算。", quotaResetSuccess: '配额重置成功', failedToResetQuota: '重置配额失败', noSubscriptionsYet: '暂无订阅', diff --git a/frontend/src/views/admin/SubscriptionsView.vue b/frontend/src/views/admin/SubscriptionsView.vue index bb711b01..97282594 100644 --- a/frontend/src/views/admin/SubscriptionsView.vue +++ b/frontend/src/views/admin/SubscriptionsView.vue @@ -1154,7 +1154,7 @@ const confirmResetQuota = async () => { if (resettingQuota.value) return resettingQuota.value = true try { - await adminAPI.subscriptions.resetQuota(resettingSubscription.value.id, { daily: true, weekly: true }) + await adminAPI.subscriptions.resetQuota(resettingSubscription.value.id, { daily: true, weekly: true, monthly: true }) appStore.showSuccess(t('admin.subscriptions.quotaResetSuccess')) showResetQuotaConfirm.value = false resettingSubscription.value = null From 5b85005945b41e05210ddfdbfd50cfc9890f4f48 Mon Sep 17 00:00:00 2001 From: wucm667 Date: Fri, 13 Mar 2026 11:12:37 +0800 Subject: [PATCH 222/286] =?UTF-8?q?feat:=20=E8=B4=A6=E5=8F=B7=E9=85=8D?= =?UTF-8?q?=E9=A2=9D=E6=94=AF=E6=8C=81=E5=9B=BA=E5=AE=9A=E6=97=B6=E9=97=B4?= =?UTF-8?q?=E9=87=8D=E7=BD=AE=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端新增 rolling/fixed 两种配额重置模式,支持日配额和周配额 - fixed 模式下可配置重置时刻(小时)、重置星期几(周配额)及时区(IANA) - 在 account_repo.go 中使用 SQL 表达式适配两种模式的过期判断与重置时间推进 - 新增 
ComputeQuotaResetAt / ValidateQuotaResetConfig 等辅助函数 - DTO 层新增相关字段并在 mappers 中完整映射 - 前端 QuotaLimitCard 新增 rolling/fixed 切换 UI、时区选择器 - CreateAccountModal / EditAccountModal 透传新配置字段 - i18n(zh/en)同步新增相关翻译词条 --- backend/internal/handler/dto/mappers.go | 25 ++ backend/internal/handler/dto/types.go | 10 + backend/internal/repository/account_repo.go | 62 ++++- backend/internal/service/account.go | 251 +++++++++++++++++- backend/internal/service/admin_service.go | 12 + .../components/account/AccountUsageCell.vue | 22 +- .../components/account/CreateAccountModal.vue | 37 +++ .../components/account/EditAccountModal.vue | 53 ++++ .../src/components/account/QuotaLimitCard.vue | 162 ++++++++++- frontend/src/i18n/locales/en.ts | 17 ++ frontend/src/i18n/locales/zh.ts | 17 ++ frontend/src/types/index.ts | 10 + 12 files changed, 660 insertions(+), 18 deletions(-) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 205ccd65..cef9913e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account { used := a.GetQuotaWeeklyUsed() out.QuotaWeeklyUsed = &used } + // 固定时间重置配置 + if mode := a.GetQuotaDailyResetMode(); mode == "fixed" { + out.QuotaDailyResetMode = &mode + hour := a.GetQuotaDailyResetHour() + out.QuotaDailyResetHour = &hour + } + if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" { + out.QuotaWeeklyResetMode = &mode + day := a.GetQuotaWeeklyResetDay() + out.QuotaWeeklyResetDay = &day + hour := a.GetQuotaWeeklyResetHour() + out.QuotaWeeklyResetHour = &hour + } + if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" { + tz := a.GetQuotaResetTimezone() + out.QuotaResetTimezone = &tz + } + if a.Extra != nil { + if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" { + out.QuotaDailyResetAt = &v + } + if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" 
{ + out.QuotaWeeklyResetAt = &v + } + } } return out diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index d9ccda2d..3708eed5 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -203,6 +203,16 @@ type Account struct { QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"` QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"` + // 配额固定时间重置配置 + QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"` + QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"` + QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"` + QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"` + QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"` + QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"` + QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` + QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index a9cb2cba..884cc120 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1727,8 +1727,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va // nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string. const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')` +// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired. +// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes. 
+const dailyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + END +)` + +// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired. +const weeklyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + END +)` + +// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs. +// For fixed mode: advances current reset_at by 1 day. For rolling mode: not used (NULL). +const nextDailyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN to_char( + COALESCE((extra->>'quota_daily_reset_at')::timestamptz, NOW()) + '1 day'::interval + AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + +// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs. 
+const nextWeeklyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN to_char( + COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, NOW()) + '7 days'::interval + AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + // IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度) // 日/周额度在周期过期时自动重置为 0 再递增。 +// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。 func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { rows, err := r.sql.QueryContext(ctx, `UPDATE accounts SET extra = ( @@ -1739,31 +1778,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_daily_used', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, 'quota_daily_start', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END -- 周额度:仅在 quota_weekly_limit > 0 时处理 || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_weekly_used', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, 'quota_weekly_start', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, 
'1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END ), updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL @@ -1796,12 +1839,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am } // ResetQuotaUsed 重置账号所有维度的配额用量为 0 +// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间 func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { _, err := r.sql.ExecContext(ctx, `UPDATE accounts SET extra = ( COALESCE(extra, '{}'::jsonb) || '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb - ) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW() + ) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL`, id) if err != nil { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 9d4f73d4..70643de1 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,7 @@ package service import ( "encoding/json" + "errors" "hash/fnv" "reflect" "sort" @@ -1260,6 +1261,240 @@ func (a *Account) getExtraTime(key string) time.Time { return time.Time{} } +// getExtraString 从 Extra 中读取指定 key 的字符串值 +func (a *Account) getExtraString(key string) string { + if a.Extra == nil { + return "" + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// getExtraInt 从 Extra 中读取指定 key 的 int 值 +func (a *Account) getExtraInt(key string) int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra[key]; ok { + return int(parseExtraFloat64(v)) + } + return 0 
+} + +// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaDailyResetMode() string { + if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0 +func (a *Account) GetQuotaDailyResetHour() int { + return a.getExtraInt("quota_daily_reset_hour") +} + +// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaWeeklyResetMode() string { + if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一) +func (a *Account) GetQuotaWeeklyResetDay() int { + if a.Extra == nil { + return 1 + } + if _, ok := a.Extra["quota_weekly_reset_day"]; !ok { + return 1 + } + return a.getExtraInt("quota_weekly_reset_day") +} + +// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0 +func (a *Account) GetQuotaWeeklyResetHour() int { + return a.getExtraInt("quota_weekly_reset_hour") +} + +// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC" +func (a *Account) GetQuotaResetTimezone() string { + if tz := a.getExtraString("quota_reset_timezone"); tz != "" { + return tz + } + return "UTC" +} + +// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 +func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if !after.Before(today) { + return today.AddDate(0, 0, 1) + } + return today +} + +// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点 +func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if now.Before(today) { + return today.AddDate(0, 0, -1) + } + return today +} + +// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点 +// day: 0=Sunday, 1=Monday, ..., 6=Saturday +func nextFixedWeeklyReset(day, hour 
int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysForward := (day - currentDay + 7) % 7 + if daysForward == 0 && !after.Before(todayReset) { + daysForward = 7 + } + return todayReset.AddDate(0, 0, daysForward) +} + +// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点 +func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysBack := (currentDay - day + 7) % 7 + if daysBack == 0 && now.Before(todayReset) { + daysBack = 7 + } + return todayReset.AddDate(0, 0, -daysBack) +} + +// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期 +func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期 +func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at +// 在保存账号配置时调用 +func ComputeQuotaResetAt(extra map[string]interface{}) { + now := time.Now() + tzName, _ := extra["quota_reset_timezone"].(string) + if tzName == "" { + tzName = "UTC" + } + tz, err := time.LoadLocation(tzName) + if err != nil { + tz = time.UTC + } + + // 日配额固定重置时间 + if mode, _ := 
extra["quota_daily_reset_mode"].(string); mode == "fixed" { + hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedDailyReset(hour, tz, now) + extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_daily_reset_at") + } + + // 周配额固定重置时间 + if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" { + day := 1 // 默认周一 + if d, ok := extra["quota_weekly_reset_day"]; ok { + day = int(parseExtraFloat64(d)) + } + if day < 0 || day > 6 { + day = 1 + } + hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedWeeklyReset(day, hour, tz, now) + extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_weekly_reset_at") + } +} + +// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性 +func ValidateQuotaResetConfig(extra map[string]interface{}) error { + if extra == nil { + return nil + } + // 校验时区 + if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" { + if _, err := time.LoadLocation(tz); err != nil { + return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name") + } + } + // 日配额重置模式 + if mode, ok := extra["quota_daily_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'") + } + } + // 日配额重置小时 + if v, ok := extra["quota_daily_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return errors.New("quota_daily_reset_hour must be between 0 and 23") + } + } + // 周配额重置模式 + if mode, ok := extra["quota_weekly_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'") + } + } + // 周配额重置星期几 + if v, ok := extra["quota_weekly_reset_day"]; ok { + day := int(parseExtraFloat64(v)) + if day < 0 || 
day > 6 { + return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)") + } + } + // 周配额重置小时 + if v, ok := extra["quota_weekly_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return errors.New("quota_weekly_reset_hour must be between 0 and 23") + } + } + return nil +} + // HasAnyQuotaLimit 检查是否配置了任一维度的配额限制 func (a *Account) HasAnyQuotaLimit() bool { return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0 @@ -1282,14 +1517,26 @@ func (a *Account) IsQuotaExceeded() bool { // 日额度(周期过期视为未超限,下次 increment 会重置) if limit := a.GetQuotaDailyLimit(); limit > 0 { start := a.getExtraTime("quota_daily_start") - if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit { + var expired bool + if a.GetQuotaDailyResetMode() == "fixed" { + expired = a.isFixedDailyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 24*time.Hour) + } + if !expired && a.GetQuotaDailyUsed() >= limit { return true } } // 周额度 if limit := a.GetQuotaWeeklyLimit(); limit > 0 { start := a.getExtraTime("quota_weekly_start") - if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit { + var expired bool + if a.GetQuotaWeeklyResetMode() == "fixed" { + expired = a.isFixedWeeklyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 7*24*time.Hour) + } + if !expired && a.GetQuotaWeeklyUsed() >= limit { return true } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 5aadda47..a62c6278 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1438,6 +1438,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Status: StatusActive, Schedulable: true, } + // 预计算固定时间重置的下次重置时间 + if account.Extra != nil { + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + 
ComputeQuotaResetAt(account.Extra) + } if input.ExpiresAt != nil && *input.ExpiresAt > 0 { expiresAt := time.Unix(*input.ExpiresAt, 0) account.ExpiresAt = &expiresAt @@ -1511,6 +1518,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } } account.Extra = input.Extra + // 校验并预计算固定时间重置的下次重置时间 + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) } if input.ProxyID != nil { // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index e83eaead..8154a66d 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -848,11 +848,23 @@ const makeQuotaBar = ( let resetsAt: string | null = null if (startKey) { const extra = props.account.extra as Record | undefined - const startStr = extra?.[startKey] as string | undefined - if (startStr) { - const startDate = new Date(startStr) - const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000 - resetsAt = new Date(startDate.getTime() + periodMs).toISOString() + const isDaily = startKey.includes('daily') + const mode = isDaily + ? (extra?.quota_daily_reset_mode as string) || 'rolling' + : (extra?.quota_weekly_reset_mode as string) || 'rolling' + + if (mode === 'fixed') { + // Use pre-computed next reset time for fixed mode + const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at' + resetsAt = (extra?.[resetAtKey] as string) || null + } else { + // Rolling mode: compute from start + period + const startStr = extra?.[startKey] as string | undefined + if (startStr) { + const startDate = new Date(startStr) + const periodMs = isDaily ? 
24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000 + resetsAt = new Date(startDate.getTime() + periodMs).toISOString() + } } } return { utilization, resetsAt } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 8423c1b9..5b6a27da 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1291,9 +1291,21 @@ :totalLimit="editQuotaLimit" :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" + :dailyResetMode="editDailyResetMode" + :dailyResetHour="editDailyResetHour" + :weeklyResetMode="editWeeklyResetMode" + :weeklyResetDay="editWeeklyResetDay" + :weeklyResetHour="editWeeklyResetHour" + :resetTimezone="editResetTimezone" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:dailyResetMode="editDailyResetMode = $event" + @update:dailyResetHour="editDailyResetHour = $event" + @update:weeklyResetMode="editWeeklyResetMode = $event" + @update:weeklyResetDay="editWeeklyResetDay = $event" + @update:weeklyResetHour="editWeeklyResetHour = $event" + @update:resetTimezone="editResetTimezone = $event" /> @@ -2678,6 +2690,12 @@ const apiKeyValue = ref('') const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -3347,6 +3365,12 @@ const resetForm = () => { editQuotaLimit.value = null editQuotaDailyLimit.value = null editQuotaWeeklyLimit.value = null + 
editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3796,6 +3820,19 @@ const createAccountAndFinish = async ( if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + quotaExtra.quota_daily_reset_mode = 'fixed' + quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } + if (editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_weekly_reset_mode = 'fixed' + quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0 + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } if (Object.keys(quotaExtra).length > 0) { finalExtra = quotaExtra } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1f2e988c..9debe283 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -967,9 +967,21 @@ :totalLimit="editQuotaLimit" :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" + :dailyResetMode="editDailyResetMode" + :dailyResetHour="editDailyResetHour" + :weeklyResetMode="editWeeklyResetMode" + :weeklyResetDay="editWeeklyResetDay" + :weeklyResetHour="editWeeklyResetHour" + :resetTimezone="editResetTimezone" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + 
@update:dailyResetMode="editDailyResetMode = $event" + @update:dailyResetHour="editDailyResetHour = $event" + @update:weeklyResetMode="editWeeklyResetMode = $event" + @update:weeklyResetDay="editWeeklyResetDay = $event" + @update:weeklyResetHour="editWeeklyResetHour = $event" + @update:resetTimezone="editResetTimezone = $event" /> @@ -1608,6 +1620,12 @@ const anthropicPassthroughEnabled = ref(false) const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 @@ -1795,10 +1813,23 @@ watch( editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null const weeklyVal = extra?.quota_weekly_limit as number | undefined editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null + // Load quota reset mode config + editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null + editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null + editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null + editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null + editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? 
null + editResetTimezone.value = (extra?.quota_reset_timezone as string) || null } else { editQuotaLimit.value = null editQuotaDailyLimit.value = null editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null } // Load antigravity model mapping (Antigravity 只支持映射模式) @@ -2645,6 +2676,28 @@ const handleSubmit = async () => { } else { delete newExtra.quota_weekly_limit } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + newExtra.quota_daily_reset_mode = 'fixed' + newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } else { + delete newExtra.quota_daily_reset_mode + delete newExtra.quota_daily_reset_hour + } + if (editWeeklyResetMode.value === 'fixed') { + newExtra.quota_weekly_reset_mode = 'fixed' + newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 
0 + } else { + delete newExtra.quota_weekly_reset_mode + delete newExtra.quota_weekly_reset_day + delete newExtra.quota_weekly_reset_hour + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } else { + delete newExtra.quota_reset_timezone + } updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 505118ba..fdc19ad9 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -8,12 +8,24 @@ const props = defineProps<{ totalLimit: number | null dailyLimit: number | null weeklyLimit: number | null + dailyResetMode: 'rolling' | 'fixed' | null + dailyResetHour: number | null + weeklyResetMode: 'rolling' | 'fixed' | null + weeklyResetDay: number | null + weeklyResetHour: number | null + resetTimezone: string | null }>() const emit = defineEmits<{ 'update:totalLimit': [value: number | null] 'update:dailyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null] + 'update:dailyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:dailyResetHour': [value: number | null] + 'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:weeklyResetDay': [value: number | null] + 'update:weeklyResetHour': [value: number | null] + 'update:resetTimezone': [value: string | null] }>() const enabled = computed(() => @@ -35,9 +47,56 @@ watch(localEnabled, (val) => { emit('update:totalLimit', null) emit('update:dailyLimit', null) emit('update:weeklyLimit', null) + emit('update:dailyResetMode', null) + emit('update:dailyResetHour', null) + emit('update:weeklyResetMode', null) + emit('update:weeklyResetDay', null) + emit('update:weeklyResetHour', null) + emit('update:resetTimezone', null) } }) +// Whether any fixed mode is active (to show timezone selector) +const hasFixedMode = computed(() 
=> + props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed' +) + +// Common timezone options +const timezoneOptions = [ + 'UTC', + 'Asia/Shanghai', + 'Asia/Tokyo', + 'Asia/Seoul', + 'Asia/Singapore', + 'Asia/Kolkata', + 'Asia/Dubai', + 'Europe/London', + 'Europe/Paris', + 'Europe/Berlin', + 'Europe/Moscow', + 'America/New_York', + 'America/Chicago', + 'America/Denver', + 'America/Los_Angeles', + 'America/Sao_Paulo', + 'Australia/Sydney', + 'Pacific/Auckland', +] + +// Hours for dropdown (0-23) +const hourOptions = Array.from({ length: 24 }, (_, i) => i) + +// Day of week options +const dayOptions = [ + { value: 1, key: 'monday' }, + { value: 2, key: 'tuesday' }, + { value: 3, key: 'wednesday' }, + { value: 4, key: 'thursday' }, + { value: 5, key: 'friday' }, + { value: 6, key: 'saturday' }, + { value: 0, key: 'sunday' }, +] + const onTotalInput = (e: Event) => { const raw = (e.target as HTMLInputElement).valueAsNumber emit('update:totalLimit', Number.isNaN(raw) ? null : raw) @@ -50,6 +109,25 @@ const onWeeklyInput = (e: Event) => { const raw = (e.target as HTMLInputElement).valueAsNumber emit('update:weeklyLimit', Number.isNaN(raw) ? 
null : raw) } + +const onDailyModeChange = (e: Event) => { + const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed' + emit('update:dailyResetMode', val) + if (val === 'fixed') { + if (props.dailyResetHour == null) emit('update:dailyResetHour', 0) + if (!props.resetTimezone) emit('update:resetTimezone', 'UTC') + } +} + +const onWeeklyModeChange = (e: Event) => { + const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed' + emit('update:weeklyResetMode', val) + if (val === 'fixed') { + if (props.weeklyResetDay == null) emit('update:weeklyResetDay', 1) + if (props.weeklyResetHour == null) emit('update:weeklyResetHour', 0) + if (!props.resetTimezone) emit('update:resetTimezone', 'UTC') + } +} @@ -489,6 +501,7 @@ import { Line } from 'vue-chartjs' import BaseDialog from '@/components/common/BaseDialog.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue' +import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue' import Icon from '@/components/icons/Icon.vue' import { adminAPI } from '@/api/admin' import type { Account, AccountUsageStatsResponse } from '@/types' diff --git a/frontend/src/components/admin/account/AccountStatsModal.vue b/frontend/src/components/admin/account/AccountStatsModal.vue index 72a71d36..4dc84d5e 100644 --- a/frontend/src/components/admin/account/AccountStatsModal.vue +++ b/frontend/src/components/admin/account/AccountStatsModal.vue @@ -410,6 +410,18 @@ + + + + @@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs' import BaseDialog from '@/components/common/BaseDialog.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue' +import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue' import Icon from '@/components/icons/Icon.vue' import { adminAPI } from 
'@/api/admin' import type { Account, AccountUsageStatsResponse } from '@/types' diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 72f7c010..aa6c2bbd 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -35,6 +35,19 @@ + +