From f3ed95d4dea643e54417d0b4e6b8ccd318e0631d Mon Sep 17 00:00:00 2001
From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com>
Date: Fri, 9 Jan 2026 20:54:26 +0800
Subject: [PATCH] feat(handler): implement ops monitoring API handlers and middleware
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add ops error logger (ops_error_logger.go)
- Add main ops handler (ops_handler.go)
- Add alert management handler (ops_alerts_handler.go)
- Add dashboard handler (ops_dashboard_handler.go)
- Add realtime monitoring handler (ops_realtime_handler.go)
- Add settings management handler (ops_settings_handler.go)
- Add WebSocket handler (ops_ws_handler.go)
- Extend settings DTO to support ops configuration
- Add client request ID middleware (client_request_id.go)
- Add WebSocket query-token auth middleware (ws_query_token_auth.go)
- Update admin auth middleware to support ops routes
- Register handlers for dependency injection
---
 .../handler/admin/ops_alerts_handler.go       | 433 ++++++++++
 .../handler/admin/ops_dashboard_handler.go    | 243 ++++++
 backend/internal/handler/admin/ops_handler.go | 364 +++++++++
 .../handler/admin/ops_realtime_handler.go     | 120 +++
 .../handler/admin/ops_settings_handler.go     | 103 +++
 .../internal/handler/admin/ops_ws_handler.go  | 765 ++++++++++++++++++
 backend/internal/handler/dto/settings.go      |   5 +
 backend/internal/handler/ops_error_logger.go  | 681 ++++++++++++++++
 backend/internal/handler/wire.go              |   3 +
 .../internal/server/middleware/admin_auth.go  |  52 ++
 .../server/middleware/client_request_id.go    |  31 +
 .../server/middleware/ws_query_token_auth.go  |  54 ++
 12 files changed, 2854 insertions(+)
 create mode 100644 backend/internal/handler/admin/ops_alerts_handler.go
 create mode 100644 backend/internal/handler/admin/ops_dashboard_handler.go
 create mode 100644 backend/internal/handler/admin/ops_handler.go
 create mode 100644 backend/internal/handler/admin/ops_realtime_handler.go
 create mode 100644 backend/internal/handler/admin/ops_settings_handler.go
 create mode 100644 backend/internal/handler/admin/ops_ws_handler.go
 create mode 100644 backend/internal/handler/ops_error_logger.go
 create mode 100644 backend/internal/server/middleware/client_request_id.go
 create mode 100644 backend/internal/server/middleware/ws_query_token_auth.go

diff --git a/backend/internal/handler/admin/ops_alerts_handler.go b/backend/internal/handler/admin/ops_alerts_handler.go
new file mode 100644
index 00000000..19d9d870
--- /dev/null
+++ b/backend/internal/handler/admin/ops_alerts_handler.go
@@ -0,0 +1,433 @@
+package admin
+
+import (
+	"encoding/json"
+	"fmt"
+	"math"
+	"net/http"
+	"strconv"
+	"strings"
+
+	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+	"github.com/Wei-Shaw/sub2api/internal/service"
+	"github.com/gin-gonic/gin"
+	"github.com/gin-gonic/gin/binding"
+)
+
+var validOpsAlertMetricTypes = []string{
+	"success_rate",
+	"error_rate",
+	"upstream_error_rate",
+	"p95_latency_ms",
+	"p99_latency_ms",
+	"cpu_usage_percent",
+	"memory_usage_percent",
+	"concurrency_queue_depth",
+}
+
+var validOpsAlertMetricTypeSet = func() map[string]struct{} {
+	set := make(map[string]struct{}, len(validOpsAlertMetricTypes))
+	for _, v := range validOpsAlertMetricTypes {
+		set[v] = struct{}{}
+	}
+	return set
+}()
+
+var validOpsAlertOperators = []string{">", "<", ">=", "<=", "==", "!="}
+
+var validOpsAlertOperatorSet = func() map[string]struct{} {
+	set := make(map[string]struct{}, len(validOpsAlertOperators))
+	for _, v := range validOpsAlertOperators {
+		set[v] = struct{}{}
+	}
+	return set
+}()
+
+var validOpsAlertSeverities = []string{"P0", "P1", "P2", "P3"}
+
+var 
validOpsAlertSeveritySet = func() map[string]struct{} { + set := make(map[string]struct{}, len(validOpsAlertSeverities)) + for _, v := range validOpsAlertSeverities { + set[v] = struct{}{} + } + return set +}() + +type opsAlertRuleValidatedInput struct { + Name string + MetricType string + Operator string + Threshold float64 + + Severity string + + WindowMinutes int + SustainedMinutes int + CooldownMinutes int + + Enabled bool + NotifyEmail bool + + WindowProvided bool + SustainedProvided bool + CooldownProvided bool + SeverityProvided bool + EnabledProvided bool + NotifyProvided bool +} + +func isPercentOrRateMetric(metricType string) bool { + switch metricType { + case "success_rate", + "error_rate", + "upstream_error_rate", + "cpu_usage_percent", + "memory_usage_percent": + return true + default: + return false + } +} + +func validateOpsAlertRulePayload(raw map[string]json.RawMessage) (*opsAlertRuleValidatedInput, error) { + if raw == nil { + return nil, fmt.Errorf("invalid request body") + } + + requiredFields := []string{"name", "metric_type", "operator", "threshold"} + for _, field := range requiredFields { + if _, ok := raw[field]; !ok { + return nil, fmt.Errorf("%s is required", field) + } + } + + var name string + if err := json.Unmarshal(raw["name"], &name); err != nil || strings.TrimSpace(name) == "" { + return nil, fmt.Errorf("name is required") + } + name = strings.TrimSpace(name) + + var metricType string + if err := json.Unmarshal(raw["metric_type"], &metricType); err != nil || strings.TrimSpace(metricType) == "" { + return nil, fmt.Errorf("metric_type is required") + } + metricType = strings.TrimSpace(metricType) + if _, ok := validOpsAlertMetricTypeSet[metricType]; !ok { + return nil, fmt.Errorf("metric_type must be one of: %s", strings.Join(validOpsAlertMetricTypes, ", ")) + } + + var operator string + if err := json.Unmarshal(raw["operator"], &operator); err != nil || strings.TrimSpace(operator) == "" { + return nil, fmt.Errorf("operator is required") + } + operator = strings.TrimSpace(operator) + if _, ok := validOpsAlertOperatorSet[operator]; !ok { + return nil, fmt.Errorf("operator must be one of: %s", strings.Join(validOpsAlertOperators, ", ")) + } + + var threshold float64 + if err := json.Unmarshal(raw["threshold"], &threshold); err != nil { + return nil, fmt.Errorf("threshold must be a number") + } + if math.IsNaN(threshold) || math.IsInf(threshold, 0) { + return nil, fmt.Errorf("threshold must be a finite number") + } + if isPercentOrRateMetric(metricType) { + if threshold < 0 || threshold > 100 { + return nil, fmt.Errorf("threshold must be between 0 and 100 for metric_type %s", metricType) + } + } else if threshold < 0 { + return nil, fmt.Errorf("threshold must be >= 0") + } + + validated := &opsAlertRuleValidatedInput{ + Name: name, + MetricType: metricType, + Operator: operator, + Threshold: threshold, + } + + if v, ok := raw["severity"]; ok { + validated.SeverityProvided = true + var sev string + if err := json.Unmarshal(v, &sev); err != nil { + return nil, fmt.Errorf("severity must be a string") + } + sev = strings.ToUpper(strings.TrimSpace(sev)) + if sev != "" { + if _, ok := validOpsAlertSeveritySet[sev]; !ok { + return nil, fmt.Errorf("severity must be one of: %s", strings.Join(validOpsAlertSeverities, ", ")) + } + validated.Severity = sev + } + } + if validated.Severity == "" { + validated.Severity = "P2" + } + + if v, ok := raw["enabled"]; ok { + validated.EnabledProvided = true + if err := json.Unmarshal(v, &validated.Enabled); err != nil { + return 
nil, fmt.Errorf("enabled must be a boolean") + } + } else { + validated.Enabled = true + } + + if v, ok := raw["notify_email"]; ok { + validated.NotifyProvided = true + if err := json.Unmarshal(v, &validated.NotifyEmail); err != nil { + return nil, fmt.Errorf("notify_email must be a boolean") + } + } else { + validated.NotifyEmail = true + } + + if v, ok := raw["window_minutes"]; ok { + validated.WindowProvided = true + if err := json.Unmarshal(v, &validated.WindowMinutes); err != nil { + return nil, fmt.Errorf("window_minutes must be an integer") + } + switch validated.WindowMinutes { + case 1, 5, 60: + default: + return nil, fmt.Errorf("window_minutes must be one of: 1, 5, 60") + } + } else { + validated.WindowMinutes = 1 + } + + if v, ok := raw["sustained_minutes"]; ok { + validated.SustainedProvided = true + if err := json.Unmarshal(v, &validated.SustainedMinutes); err != nil { + return nil, fmt.Errorf("sustained_minutes must be an integer") + } + if validated.SustainedMinutes < 1 || validated.SustainedMinutes > 1440 { + return nil, fmt.Errorf("sustained_minutes must be between 1 and 1440") + } + } else { + validated.SustainedMinutes = 1 + } + + if v, ok := raw["cooldown_minutes"]; ok { + validated.CooldownProvided = true + if err := json.Unmarshal(v, &validated.CooldownMinutes); err != nil { + return nil, fmt.Errorf("cooldown_minutes must be an integer") + } + if validated.CooldownMinutes < 0 || validated.CooldownMinutes > 1440 { + return nil, fmt.Errorf("cooldown_minutes must be between 0 and 1440") + } + } else { + validated.CooldownMinutes = 0 + } + + return validated, nil +} + +// ListAlertRules returns all ops alert rules. +// GET /api/v1/admin/ops/alert-rules +func (h *OpsHandler) ListAlertRules(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + rules, err := h.opsService.ListAlertRules(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, rules) +} + +// CreateAlertRule creates an ops alert rule. 
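+// Optional fields omitted from the payload fall back to the defaults applied in
+// validateOpsAlertRulePayload: severity=P2, enabled=true, notify_email=true,
+// window_minutes=1, sustained_minutes=1, cooldown_minutes=0.
+//
+// Illustrative request body (values are examples only):
+//
+//	{
+//	  "name": "High upstream error rate",
+//	  "metric_type": "upstream_error_rate",
+//	  "operator": ">=",
+//	  "threshold": 5,
+//	  "window_minutes": 5,
+//	  "sustained_minutes": 3,
+//	  "severity": "P1"
+//	}
+//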
+// POST /api/v1/admin/ops/alert-rules +func (h *OpsHandler) CreateAlertRule(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var raw map[string]json.RawMessage + if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + validated, err := validateOpsAlertRulePayload(raw) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var rule service.OpsAlertRule + if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + rule.Name = validated.Name + rule.MetricType = validated.MetricType + rule.Operator = validated.Operator + rule.Threshold = validated.Threshold + rule.WindowMinutes = validated.WindowMinutes + rule.SustainedMinutes = validated.SustainedMinutes + rule.CooldownMinutes = validated.CooldownMinutes + rule.Severity = validated.Severity + rule.Enabled = validated.Enabled + rule.NotifyEmail = validated.NotifyEmail + + created, err := h.opsService.CreateAlertRule(c.Request.Context(), &rule) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, created) +} + +// UpdateAlertRule updates an existing ops alert rule. +// PUT /api/v1/admin/ops/alert-rules/:id +func (h *OpsHandler) UpdateAlertRule(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid rule ID") + return + } + + var raw map[string]json.RawMessage + if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + validated, err := validateOpsAlertRulePayload(raw) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var rule service.OpsAlertRule + if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + rule.ID = id + rule.Name = validated.Name + rule.MetricType = validated.MetricType + rule.Operator = validated.Operator + rule.Threshold = validated.Threshold + rule.WindowMinutes = validated.WindowMinutes + rule.SustainedMinutes = validated.SustainedMinutes + rule.CooldownMinutes = validated.CooldownMinutes + rule.Severity = validated.Severity + rule.Enabled = validated.Enabled + rule.NotifyEmail = validated.NotifyEmail + + updated, err := h.opsService.UpdateAlertRule(c.Request.Context(), &rule) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, updated) +} + +// DeleteAlertRule deletes an ops alert rule. 
+// DELETE /api/v1/admin/ops/alert-rules/:id +func (h *OpsHandler) DeleteAlertRule(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid rule ID") + return + } + + if err := h.opsService.DeleteAlertRule(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +// ListAlertEvents lists recent ops alert events. +// GET /api/v1/admin/ops/alert-events +func (h *OpsHandler) ListAlertEvents(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + limit := 100 + if raw := strings.TrimSpace(c.Query("limit")); raw != "" { + n, err := strconv.Atoi(raw) + if err != nil || n <= 0 { + response.BadRequest(c, "Invalid limit") + return + } + limit = n + } + + filter := &service.OpsAlertEventFilter{ + Limit: limit, + Status: strings.TrimSpace(c.Query("status")), + Severity: strings.TrimSpace(c.Query("severity")), + } + + // Optional global filter support (platform/group/time range). + if platform := strings.TrimSpace(c.Query("platform")); platform != "" { + filter.Platform = platform + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + if startTime, endTime, err := parseOpsTimeRange(c, "24h"); err == nil { + // Only apply when explicitly provided to avoid surprising default narrowing. + if strings.TrimSpace(c.Query("start_time")) != "" || strings.TrimSpace(c.Query("end_time")) != "" || strings.TrimSpace(c.Query("time_range")) != "" { + filter.StartTime = &startTime + filter.EndTime = &endTime + } + } else { + response.BadRequest(c, err.Error()) + return + } + + events, err := h.opsService.ListAlertEvents(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, events) +} + diff --git a/backend/internal/handler/admin/ops_dashboard_handler.go b/backend/internal/handler/admin/ops_dashboard_handler.go new file mode 100644 index 00000000..2c87f734 --- /dev/null +++ b/backend/internal/handler/admin/ops_dashboard_handler.go @@ -0,0 +1,243 @@ +package admin + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// GetDashboardOverview returns vNext ops dashboard overview (raw path). 
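+// Supported query params: start_time/end_time (RFC3339) or time_range
+// (5m/30m/1h/6h/24h, default 1h), plus optional platform, group_id and mode;
+// parsing is shared with the other dashboard endpoints via parseOpsTimeRange
+// and parseOpsQueryMode.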
+// GET /api/v1/admin/ops/dashboard/overview +func (h *OpsHandler) GetDashboardOverview(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + data, err := h.opsService.GetDashboardOverview(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +// GetDashboardThroughputTrend returns throughput time series (raw path). +// GET /api/v1/admin/ops/dashboard/throughput-trend +func (h *OpsHandler) GetDashboardThroughputTrend(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime)) + data, err := h.opsService.GetThroughputTrend(c.Request.Context(), filter, bucketSeconds) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +// GetDashboardLatencyHistogram returns the latency distribution histogram (success requests). +// GET /api/v1/admin/ops/dashboard/latency-histogram +func (h *OpsHandler) GetDashboardLatencyHistogram(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + data, err := h.opsService.GetLatencyHistogram(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +// GetDashboardErrorTrend returns error counts time series (raw path). 
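+// Points are bucketed by pickThroughputBucketSeconds below: 60s buckets for
+// windows up to 2h, 300s up to 24h, and 3600s for anything longer.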
+// GET /api/v1/admin/ops/dashboard/error-trend +func (h *OpsHandler) GetDashboardErrorTrend(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime)) + data, err := h.opsService.GetErrorTrend(c.Request.Context(), filter, bucketSeconds) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +// GetDashboardErrorDistribution returns error distribution by status code (raw path). +// GET /api/v1/admin/ops/dashboard/error-distribution +func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + data, err := h.opsService.GetErrorDistribution(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +func pickThroughputBucketSeconds(window time.Duration) int { + // Keep buckets predictable and avoid huge responses. + switch { + case window <= 2*time.Hour: + return 60 + case window <= 24*time.Hour: + return 300 + default: + return 3600 + } +} + +func parseOpsQueryMode(c *gin.Context) service.OpsQueryMode { + if c == nil { + return "" + } + raw := strings.TrimSpace(c.Query("mode")) + if raw == "" { + // Empty means "use server default" (DB setting ops_query_mode_default). 
+ return "" + } + return service.ParseOpsQueryMode(raw) +} diff --git a/backend/internal/handler/admin/ops_handler.go b/backend/internal/handler/admin/ops_handler.go new file mode 100644 index 00000000..bff7426a --- /dev/null +++ b/backend/internal/handler/admin/ops_handler.go @@ -0,0 +1,364 @@ +package admin + +import ( + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type OpsHandler struct { + opsService *service.OpsService +} + +func NewOpsHandler(opsService *service.OpsService) *OpsHandler { + return &OpsHandler{opsService: opsService} +} + +// GetErrorLogs lists ops error logs. +// GET /api/v1/admin/ops/errors +func (h *OpsHandler) GetErrorLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + page, pageSize := response.ParsePagination(c) + // Ops list can be larger than standard admin tables. + if pageSize > 500 { + pageSize = 500 + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsErrorLogFilter{ + Page: page, + PageSize: pageSize, + } + if !startTime.IsZero() { + filter.StartTime = &startTime + } + if !endTime.IsZero() { + filter.EndTime = &endTime + } + + if platform := strings.TrimSpace(c.Query("platform")); platform != "" { + filter.Platform = platform + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid account_id") + return + } + filter.AccountID = &id + } + if phase := strings.TrimSpace(c.Query("phase")); phase != "" { + filter.Phase = phase + } + if q := strings.TrimSpace(c.Query("q")); q != "" { + filter.Query = q + } + if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" { + parts := strings.Split(statusCodesStr, ",") + out := make([]int, 0, len(parts)) + for _, part := range parts { + p := strings.TrimSpace(part) + if p == "" { + continue + } + n, err := strconv.Atoi(p) + if err != nil || n < 0 { + response.BadRequest(c, "Invalid status_codes") + return + } + out = append(out, n) + } + filter.StatusCodes = out + } + + result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize) +} + +// GetErrorLogByID returns a single error log detail. 
+// GET /api/v1/admin/ops/errors/:id +func (h *OpsHandler) GetErrorLogByID(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + idStr := strings.TrimSpace(c.Param("id")) + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid error id") + return + } + + detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, detail) +} + +// ListRequestDetails returns a request-level list (success + error) for drill-down. +// GET /api/v1/admin/ops/requests +func (h *OpsHandler) ListRequestDetails(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + page, pageSize := response.ParsePagination(c) + if pageSize > 100 { + pageSize = 100 + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsRequestDetailFilter{ + Page: page, + PageSize: pageSize, + StartTime: &startTime, + EndTime: &endTime, + } + + filter.Kind = strings.TrimSpace(c.Query("kind")) + filter.Platform = strings.TrimSpace(c.Query("platform")) + filter.Model = strings.TrimSpace(c.Query("model")) + filter.RequestID = strings.TrimSpace(c.Query("request_id")) + filter.Query = strings.TrimSpace(c.Query("q")) + filter.Sort = strings.TrimSpace(c.Query("sort")) + + if v := strings.TrimSpace(c.Query("user_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + filter.UserID = &id + } + if v := strings.TrimSpace(c.Query("api_key_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid api_key_id") + return + } + filter.APIKeyID = &id + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid account_id") + return + } + filter.AccountID = &id + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + + if v := strings.TrimSpace(c.Query("min_duration_ms")); v != "" { + parsed, err := strconv.Atoi(v) + if err != nil || parsed < 0 { + response.BadRequest(c, "Invalid min_duration_ms") + return + } + filter.MinDurationMs = &parsed + } + if v := strings.TrimSpace(c.Query("max_duration_ms")); v != "" { + parsed, err := strconv.Atoi(v) + if err != nil || parsed < 0 { + response.BadRequest(c, "Invalid max_duration_ms") + return + } + filter.MaxDurationMs = &parsed + } + + out, err := h.opsService.ListRequestDetails(c.Request.Context(), filter) + if err != nil { + // Invalid sort/kind/platform etc should be a bad request; keep it simple. 
+ if strings.Contains(strings.ToLower(err.Error()), "invalid") { + response.BadRequest(c, err.Error()) + return + } + response.Error(c, http.StatusInternalServerError, "Failed to list request details") + return + } + + response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize) +} + +type opsRetryRequest struct { + Mode string `json:"mode"` + PinnedAccountID *int64 `json:"pinned_account_id"` +} + +// RetryErrorRequest retries a failed request using stored request_body. +// POST /api/v1/admin/ops/errors/:id/retry +func (h *OpsHandler) RetryErrorRequest(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + idStr := strings.TrimSpace(c.Param("id")) + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid error id") + return + } + + req := opsRetryRequest{Mode: service.OpsRetryModeClient} + if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if strings.TrimSpace(req.Mode) == "" { + req.Mode = service.OpsRetryModeClient + } + + result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + +func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) { + startStr := strings.TrimSpace(c.Query("start_time")) + endStr := strings.TrimSpace(c.Query("end_time")) + + parseTS := func(s string) (time.Time, error) { + if s == "" { + return time.Time{}, nil + } + if t, err := time.Parse(time.RFC3339Nano, s); err == nil { + return t, nil + } + return time.Parse(time.RFC3339, s) + } + + start, err := parseTS(startStr) + if err != nil { + return time.Time{}, time.Time{}, err + } + end, err := parseTS(endStr) + if err != nil { + return time.Time{}, time.Time{}, err + } + + // start/end explicitly provided (even partially) + if startStr != "" || endStr != "" { + if end.IsZero() { + end = time.Now() + } + if start.IsZero() { + dur, _ := parseOpsDuration(defaultRange) + start = end.Add(-dur) + } + if start.After(end) { + return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: start_time must be <= end_time") + } + if end.Sub(start) > 30*24*time.Hour { + return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days") + } + return start, end, nil + } + + // time_range fallback + tr := strings.TrimSpace(c.Query("time_range")) + if tr == "" { + tr = defaultRange + } + dur, ok := parseOpsDuration(tr) + if !ok { + dur, _ = parseOpsDuration(defaultRange) + } + + end = time.Now() + start = end.Add(-dur) + if end.Sub(start) > 30*24*time.Hour { + return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days") + } + return start, end, nil +} + +func parseOpsDuration(v string) (time.Duration, bool) { + switch strings.TrimSpace(v) { + case "5m": + return 5 * time.Minute, true + case "30m": + return 30 * time.Minute, true + case "1h": + return time.Hour, true + case "6h": + return 6 * time.Hour, true + case "24h": + return 24 * time.Hour, true + default: + return 0, false 
+ } +} diff --git a/backend/internal/handler/admin/ops_realtime_handler.go b/backend/internal/handler/admin/ops_realtime_handler.go new file mode 100644 index 00000000..0c23c13b --- /dev/null +++ b/backend/internal/handler/admin/ops_realtime_handler.go @@ -0,0 +1,120 @@ +package admin + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account. +// GET /api/v1/admin/ops/concurrency +func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) { + response.Success(c, gin.H{ + "enabled": false, + "platform": map[string]*service.PlatformConcurrencyInfo{}, + "group": map[int64]*service.GroupConcurrencyInfo{}, + "account": map[int64]*service.AccountConcurrencyInfo{}, + "timestamp": time.Now().UTC(), + }) + return + } + + platformFilter := strings.TrimSpace(c.Query("platform")) + var groupID *int64 + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + groupID = &id + } + + platform, group, account, collectedAt, err := h.opsService.GetConcurrencyStats(c.Request.Context(), platformFilter, groupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + payload := gin.H{ + "enabled": true, + "platform": platform, + "group": group, + "account": account, + } + if collectedAt != nil { + payload["timestamp"] = collectedAt.UTC() + } + response.Success(c, payload) +} + +// GetAccountAvailability returns account availability statistics. 
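+// Mirrors GetConcurrencyStats: when realtime monitoring is disabled the handler
+// still returns 200 with enabled=false and empty platform/group/account maps
+// instead of an error.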
+// GET /api/v1/admin/ops/account-availability +// +// Query params: +// - platform: optional +// - group_id: optional +func (h *OpsHandler) GetAccountAvailability(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) { + response.Success(c, gin.H{ + "enabled": false, + "platform": map[string]*service.PlatformAvailability{}, + "group": map[int64]*service.GroupAvailability{}, + "account": map[int64]*service.AccountAvailability{}, + "timestamp": time.Now().UTC(), + }) + return + } + + platform := strings.TrimSpace(c.Query("platform")) + var groupID *int64 + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + groupID = &id + } + + platformStats, groupStats, accountStats, collectedAt, err := h.opsService.GetAccountAvailabilityStats(c.Request.Context(), platform, groupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + payload := gin.H{ + "enabled": true, + "platform": platformStats, + "group": groupStats, + "account": accountStats, + } + if collectedAt != nil { + payload["timestamp"] = collectedAt.UTC() + } + response.Success(c, payload) +} diff --git a/backend/internal/handler/admin/ops_settings_handler.go b/backend/internal/handler/admin/ops_settings_handler.go new file mode 100644 index 00000000..e76c1b20 --- /dev/null +++ b/backend/internal/handler/admin/ops_settings_handler.go @@ -0,0 +1,103 @@ +package admin + +import ( + "net/http" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// GetEmailNotificationConfig returns Ops email notification config (DB-backed). +// GET /api/v1/admin/ops/email-notification/config +func (h *OpsHandler) GetEmailNotificationConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + cfg, err := h.opsService.GetEmailNotificationConfig(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to get email notification config") + return + } + response.Success(c, cfg) +} + +// UpdateEmailNotificationConfig updates Ops email notification config (DB-backed). +// PUT /api/v1/admin/ops/email-notification/config +func (h *OpsHandler) UpdateEmailNotificationConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var req service.OpsEmailNotificationConfigUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + updated, err := h.opsService.UpdateEmailNotificationConfig(c.Request.Context(), &req) + if err != nil { + // Most failures here are validation errors from request payload; treat as 400. 
+ response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + +// GetAlertRuntimeSettings returns Ops alert evaluator runtime settings (DB-backed). +// GET /api/v1/admin/ops/runtime/alert +func (h *OpsHandler) GetAlertRuntimeSettings(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + cfg, err := h.opsService.GetOpsAlertRuntimeSettings(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to get alert runtime settings") + return + } + response.Success(c, cfg) +} + +// UpdateAlertRuntimeSettings updates Ops alert evaluator runtime settings (DB-backed). +// PUT /api/v1/admin/ops/runtime/alert +func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var req service.OpsAlertRuntimeSettings + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + updated, err := h.opsService.UpdateOpsAlertRuntimeSettings(c.Request.Context(), &req) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go new file mode 100644 index 00000000..4bbd9055 --- /dev/null +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -0,0 +1,765 @@ +package admin + +import ( + "context" + "encoding/json" + "log" + "math" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +type OpsWSProxyConfig struct { + TrustProxy bool + TrustedProxies []netip.Prefix + OriginPolicy string +} + +const ( + envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY" + envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES" + envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY" + envOpsWSMaxConns = "OPS_WS_MAX_CONNS" + envOpsWSMaxConnsPerIP = "OPS_WS_MAX_CONNS_PER_IP" +) + +const ( + OriginPolicyStrict = "strict" + OriginPolicyPermissive = "permissive" +) + +var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv() + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return isAllowedOpsWSOrigin(r) + }, + // Subprotocol negotiation: + // - The frontend passes ["sub2api-admin", "jwt."]. + // - We always select "sub2api-admin" so the token is never echoed back in the handshake response. 
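+	// - Illustrative browser-side handshake (exact token placement is an
+	//   assumption): new WebSocket(url, ["sub2api-admin", "jwt." + token]).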
+ Subprotocols: []string{"sub2api-admin"}, +} + +const ( + qpsWSPushInterval = 2 * time.Second + qpsWSRefreshInterval = 5 * time.Second + qpsWSRequestCountWindow = 1 * time.Minute + + defaultMaxWSConns = 100 + defaultMaxWSConnsPerIP = 20 +) + +var wsConnCount atomic.Int32 +var wsConnCountByIP sync.Map // map[string]*atomic.Int32 + +const qpsWSIdleStopDelay = 30 * time.Second + +const ( + opsWSCloseRealtimeDisabled = 4001 +) + +var qpsWSIdleStopMu sync.Mutex +var qpsWSIdleStopTimer *time.Timer + +func cancelQPSWSIdleStop() { + qpsWSIdleStopMu.Lock() + if qpsWSIdleStopTimer != nil { + qpsWSIdleStopTimer.Stop() + qpsWSIdleStopTimer = nil + } + qpsWSIdleStopMu.Unlock() +} + +func scheduleQPSWSIdleStop() { + qpsWSIdleStopMu.Lock() + if qpsWSIdleStopTimer != nil { + qpsWSIdleStopMu.Unlock() + return + } + qpsWSIdleStopTimer = time.AfterFunc(qpsWSIdleStopDelay, func() { + // Only stop if truly idle at fire time. + if wsConnCount.Load() == 0 { + qpsWSCache.Stop() + } + qpsWSIdleStopMu.Lock() + qpsWSIdleStopTimer = nil + qpsWSIdleStopMu.Unlock() + }) + qpsWSIdleStopMu.Unlock() +} + +type opsWSRuntimeLimits struct { + MaxConns int32 + MaxConnsPerIP int32 +} + +var opsWSLimits = loadOpsWSRuntimeLimitsFromEnv() + +const ( + qpsWSWriteTimeout = 10 * time.Second + qpsWSPongWait = 60 * time.Second + qpsWSPingInterval = 30 * time.Second + + // We don't expect clients to send application messages; we only read to process control frames (Pong/Close). + qpsWSMaxReadBytes = 1024 +) + +type opsWSQPSCache struct { + refreshInterval time.Duration + requestCountWindow time.Duration + + lastUpdatedUnixNano atomic.Int64 + payload atomic.Value // []byte + + opsService *service.OpsService + cancel context.CancelFunc + done chan struct{} + + mu sync.Mutex + running bool +} + +var qpsWSCache = &opsWSQPSCache{ + refreshInterval: qpsWSRefreshInterval, + requestCountWindow: qpsWSRequestCountWindow, +} + +func (c *opsWSQPSCache) start(opsService *service.OpsService) { + if c == nil || opsService == nil { + return + } + + for { + c.mu.Lock() + if c.running { + c.mu.Unlock() + return + } + + // If a previous refresh loop is currently stopping, wait for it to fully exit. + done := c.done + if done != nil { + c.mu.Unlock() + <-done + + c.mu.Lock() + if c.done == done && !c.running { + c.done = nil + } + c.mu.Unlock() + continue + } + + c.opsService = opsService + ctx, cancel := context.WithCancel(context.Background()) + c.cancel = cancel + c.done = make(chan struct{}) + done = c.done + c.running = true + c.mu.Unlock() + + go func() { + defer close(done) + c.refreshLoop(ctx) + }() + return + } +} + +// Stop stops the background refresh loop. +// It is safe to call multiple times. 
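+// Stop blocks until the refresh goroutine has fully exited; if another caller
+// is already stopping the loop it waits for that shutdown to complete.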
+func (c *opsWSQPSCache) Stop() { + if c == nil { + return + } + + c.mu.Lock() + if !c.running { + done := c.done + c.mu.Unlock() + if done != nil { + <-done + } + return + } + cancel := c.cancel + c.cancel = nil + c.running = false + c.opsService = nil + done := c.done + c.mu.Unlock() + + if cancel != nil { + cancel() + } + if done != nil { + <-done + } + + c.mu.Lock() + if c.done == done && !c.running { + c.done = nil + } + c.mu.Unlock() +} + +func (c *opsWSQPSCache) refreshLoop(ctx context.Context) { + ticker := time.NewTicker(c.refreshInterval) + defer ticker.Stop() + + c.refresh(ctx) + for { + select { + case <-ticker.C: + c.refresh(ctx) + case <-ctx.Done(): + return + } + } +} + +func (c *opsWSQPSCache) refresh(parentCtx context.Context) { + if c == nil { + return + } + + c.mu.Lock() + opsService := c.opsService + c.mu.Unlock() + if opsService == nil { + return + } + + if parentCtx == nil { + parentCtx = context.Background() + } + ctx, cancel := context.WithTimeout(parentCtx, 10*time.Second) + defer cancel() + + now := time.Now().UTC() + stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) + if err != nil || stats == nil { + if err != nil { + log.Printf("[OpsWS] refresh: get window stats failed: %v", err) + } + return + } + + requestCount := stats.SuccessCount + stats.ErrorCountTotal + qps := 0.0 + tps := 0.0 + if c.requestCountWindow > 0 { + seconds := c.requestCountWindow.Seconds() + qps = roundTo1DP(float64(requestCount) / seconds) + tps = roundTo1DP(float64(stats.TokenConsumed) / seconds) + } + + payload := gin.H{ + "type": "qps_update", + "timestamp": now.Format(time.RFC3339), + "data": gin.H{ + "qps": qps, + "tps": tps, + "request_count": requestCount, + }, + } + + msg, err := json.Marshal(payload) + if err != nil { + log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) + return + } + + c.payload.Store(msg) + c.lastUpdatedUnixNano.Store(now.UnixNano()) +} + +func roundTo1DP(v float64) float64 { + return math.Round(v*10) / 10 +} + +func (c *opsWSQPSCache) getPayload() []byte { + if c == nil { + return nil + } + if cached, ok := c.payload.Load().([]byte); ok && cached != nil { + return cached + } + return nil +} + +func closeWS(conn *websocket.Conn, code int, reason string) { + if conn == nil { + return + } + msg := websocket.FormatCloseMessage(code, reason) + _ = conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(qpsWSWriteTimeout)) + _ = conn.Close() +} + +// QPSWSHandler handles realtime QPS push via WebSocket. +// GET /api/v1/admin/ops/ws/qps +func (h *OpsHandler) QPSWSHandler(c *gin.Context) { + clientIP := requestClientIP(c.Request) + + if h == nil || h.opsService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"}) + return + } + + // If realtime monitoring is disabled, prefer a successful WS upgrade followed by a clean close + // with a deterministic close code. This prevents clients from spinning on 404/1006 reconnect loops. + if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"}) + return + } + closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled") + return + } + + cancelQPSWSIdleStop() + // Lazily start the background refresh loop so unit tests that never hit the + // websocket route don't spawn goroutines that depend on DB/Redis stubs. 
+ qpsWSCache.start(h.opsService) + + // Reserve a global slot before upgrading the connection to keep the limit strict. + if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { + log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) + return + } + defer func() { + if wsConnCount.Add(-1) == 0 { + scheduleQPSWSIdleStop() + } + }() + + if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { + if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { + log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) + return + } + defer releaseOpsWSIPSlot(clientIP) + } + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Printf("[OpsWS] upgrade failed: %v", err) + return + } + + defer func() { + _ = conn.Close() + }() + + handleQPSWebSocket(c.Request.Context(), conn) +} + +func tryAcquireOpsWSTotalSlot(limit int32) bool { + if limit <= 0 { + return true + } + for { + current := wsConnCount.Load() + if current >= limit { + return false + } + if wsConnCount.CompareAndSwap(current, current+1) { + return true + } + } +} + +func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { + if strings.TrimSpace(clientIP) == "" || limit <= 0 { + return true + } + + v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) + counter := v.(*atomic.Int32) + + for { + current := counter.Load() + if current >= limit { + return false + } + if counter.CompareAndSwap(current, current+1) { + return true + } + } +} + +func releaseOpsWSIPSlot(clientIP string) { + if strings.TrimSpace(clientIP) == "" { + return + } + + v, ok := wsConnCountByIP.Load(clientIP) + if !ok { + return + } + counter := v.(*atomic.Int32) + next := counter.Add(-1) + if next <= 0 { + // Best-effort cleanup; safe even if a new slot was acquired concurrently. + wsConnCountByIP.Delete(clientIP) + } +} + +func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { + if conn == nil { + return + } + + ctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + var closeOnce sync.Once + closeConn := func() { + closeOnce.Do(func() { + _ = conn.Close() + }) + } + + closeFrameCh := make(chan []byte, 1) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + conn.SetReadLimit(qpsWSMaxReadBytes) + if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { + log.Printf("[OpsWS] set read deadline failed: %v", err) + return + } + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)) + }) + conn.SetCloseHandler(func(code int, text string) error { + select { + case closeFrameCh <- websocket.FormatCloseMessage(code, text): + default: + } + cancel() + return nil + }) + + for { + _, _, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { + log.Printf("[OpsWS] read failed: %v", err) + } + return + } + } + }() + + // Push QPS data every 2 seconds (values are globally cached and refreshed at most once per qpsWSRefreshInterval). + pushTicker := time.NewTicker(qpsWSPushInterval) + defer pushTicker.Stop() + + // Heartbeat ping every 30 seconds. 
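+	// If the peer stops answering pings, the read deadline (qpsWSPongWait)
+	// expires in the reader goroutine, which cancels ctx and ends this loop.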
+ pingTicker := time.NewTicker(qpsWSPingInterval) + defer pingTicker.Stop() + + writeWithTimeout := func(messageType int, data []byte) error { + if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil { + return err + } + return conn.WriteMessage(messageType, data) + } + + sendClose := func(closeFrame []byte) { + if closeFrame == nil { + closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + } + _ = writeWithTimeout(websocket.CloseMessage, closeFrame) + } + + for { + select { + case <-pushTicker.C: + msg := qpsWSCache.getPayload() + if msg == nil { + continue + } + if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { + log.Printf("[OpsWS] write failed: %v", err) + cancel() + closeConn() + wg.Wait() + return + } + + case <-pingTicker.C: + if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { + log.Printf("[OpsWS] ping failed: %v", err) + cancel() + closeConn() + wg.Wait() + return + } + + case closeFrame := <-closeFrameCh: + sendClose(closeFrame) + closeConn() + wg.Wait() + return + + case <-ctx.Done(): + var closeFrame []byte + select { + case closeFrame = <-closeFrameCh: + default: + } + sendClose(closeFrame) + + closeConn() + wg.Wait() + return + } + } +} + +func isAllowedOpsWSOrigin(r *http.Request) bool { + if r == nil { + return false + } + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) { + case OriginPolicyStrict: + return false + case OriginPolicyPermissive, "": + return true + default: + return true + } + } + parsed, err := url.Parse(origin) + if err != nil || parsed.Hostname() == "" { + return false + } + originHost := strings.ToLower(parsed.Hostname()) + + trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r) + reqHost := hostWithoutPort(r.Host) + if trustProxyHeaders { + xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host")) + if xfHost != "" { + xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0]) + if xfHost != "" { + reqHost = hostWithoutPort(xfHost) + } + } + } + reqHost = strings.ToLower(reqHost) + if reqHost == "" { + return false + } + return originHost == reqHost +} + +func shouldTrustOpsWSProxyHeaders(r *http.Request) bool { + if r == nil { + return false + } + if !opsWSProxyConfig.TrustProxy { + return false + } + peerIP, ok := requestPeerIP(r) + if !ok { + return false + } + return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies) +} + +func requestPeerIP(r *http.Request) (netip.Addr, bool) { + if r == nil { + return netip.Addr{}, false + } + host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) + if err != nil { + host = strings.TrimSpace(r.RemoteAddr) + } + host = strings.TrimPrefix(host, "[") + host = strings.TrimSuffix(host, "]") + if host == "" { + return netip.Addr{}, false + } + addr, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{}, false + } + return addr.Unmap(), true +} + +func requestClientIP(r *http.Request) string { + if r == nil { + return "" + } + + trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r) + if trustProxyHeaders { + xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")) + if xff != "" { + // Use the left-most entry (original client). If multiple proxies add values, they are comma-separated. 
+ xff = strings.TrimSpace(strings.Split(xff, ",")[0]) + xff = strings.TrimPrefix(xff, "[") + xff = strings.TrimSuffix(xff, "]") + if addr, err := netip.ParseAddr(xff); err == nil && addr.IsValid() { + return addr.Unmap().String() + } + } + } + + if peer, ok := requestPeerIP(r); ok && peer.IsValid() { + return peer.String() + } + return "" +} + +func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool { + if !addr.IsValid() { + return false + } + for _, p := range trusted { + if p.Contains(addr) { + return true + } + } + return false +} + +func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { + cfg := OpsWSProxyConfig{ + TrustProxy: true, + TrustedProxies: defaultTrustedProxies(), + OriginPolicy: OriginPolicyPermissive, + } + + if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" { + if parsed, err := strconv.ParseBool(v); err == nil { + cfg.TrustProxy = parsed + } else { + log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) + } + } + + if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { + prefixes, invalid := parseTrustedProxyList(raw) + if len(invalid) > 0 { + log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) + } + cfg.TrustedProxies = prefixes + } + + if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" { + normalized := strings.ToLower(v) + switch normalized { + case OriginPolicyStrict, OriginPolicyPermissive: + cfg.OriginPolicy = normalized + default: + log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) + } + } + + return cfg +} + +func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { + cfg := opsWSRuntimeLimits{ + MaxConns: defaultMaxWSConns, + MaxConnsPerIP: defaultMaxWSConnsPerIP, + } + + if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConns)); v != "" { + if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { + cfg.MaxConns = int32(parsed) + } else { + log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) + } + } + if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { + if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { + cfg.MaxConnsPerIP = int32(parsed) + } else { + log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) + } + } + return cfg +} + +func defaultTrustedProxies() []netip.Prefix { + prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128") + return prefixes +} + +func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) { + for _, token := range strings.Split(raw, ",") { + item := strings.TrimSpace(token) + if item == "" { + continue + } + + var ( + p netip.Prefix + err error + ) + if strings.Contains(item, "/") { + p, err = netip.ParsePrefix(item) + } else { + var addr netip.Addr + addr, err = netip.ParseAddr(item) + if err == nil { + addr = addr.Unmap() + bits := 128 + if addr.Is4() { + bits = 32 + } + p = netip.PrefixFrom(addr, bits) + } + } + + if err != nil || !p.IsValid() { + invalid = append(invalid, item) + continue + } + + prefixes = append(prefixes, p.Masked()) + } + return prefixes, invalid +} + +func hostWithoutPort(hostport string) string { + hostport = strings.TrimSpace(hostport) + if hostport == "" { + return "" + } + if host, _, err := net.SplitHostPort(hostport); err == nil { + return host 
+ } + if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") { + return strings.Trim(hostport, "[]") + } + parts := strings.Split(hostport, ":") + return parts[0] +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..6fd53b26 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -37,6 +37,11 @@ type SystemSettings struct { // Identity patch configuration (Claude -> Gemini) EnableIdentityPatch bool `json:"enable_identity_patch"` IdentityPatchPrompt string `json:"identity_patch_prompt"` + + // Ops monitoring (vNext) + OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"` + OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"` + OpsQueryModeDefault string `json:"ops_query_mode_default"` } type PublicSettings struct { diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go new file mode 100644 index 00000000..b3a90c2f --- /dev/null +++ b/backend/internal/handler/ops_error_logger.go @@ -0,0 +1,681 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "log" + "runtime" + "runtime/debug" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + opsModelKey = "ops_model" + opsStreamKey = "ops_stream" + opsRequestBodyKey = "ops_request_body" + opsAccountIDKey = "ops_account_id" +) + +const ( + opsErrorLogTimeout = 5 * time.Second + opsErrorLogDrainTimeout = 10 * time.Second + + opsErrorLogMinWorkerCount = 4 + opsErrorLogMaxWorkerCount = 32 + + opsErrorLogQueueSizePerWorker = 128 + opsErrorLogMinQueueSize = 256 + opsErrorLogMaxQueueSize = 8192 +) + +type opsErrorLogJob struct { + ops *service.OpsService + entry *service.OpsInsertErrorLogInput + requestBody []byte +} + +var ( + opsErrorLogOnce sync.Once + opsErrorLogQueue chan opsErrorLogJob + + opsErrorLogStopOnce sync.Once + opsErrorLogWorkersWg sync.WaitGroup + opsErrorLogMu sync.RWMutex + opsErrorLogStopping bool + opsErrorLogQueueLen atomic.Int64 + opsErrorLogEnqueued atomic.Int64 + opsErrorLogDropped atomic.Int64 + opsErrorLogProcessed atomic.Int64 + + opsErrorLogLastDropLogAt atomic.Int64 + + opsErrorLogShutdownCh = make(chan struct{}) + opsErrorLogShutdownOnce sync.Once + opsErrorLogDrained atomic.Bool +) + +func startOpsErrorLogWorkers() { + opsErrorLogMu.Lock() + defer opsErrorLogMu.Unlock() + + if opsErrorLogStopping { + return + } + + workerCount, queueSize := opsErrorLogConfig() + opsErrorLogQueue = make(chan opsErrorLogJob, queueSize) + opsErrorLogQueueLen.Store(0) + + opsErrorLogWorkersWg.Add(workerCount) + for i := 0; i < workerCount; i++ { + go func() { + defer opsErrorLogWorkersWg.Done() + for job := range opsErrorLogQueue { + opsErrorLogQueueLen.Add(-1) + if job.ops == nil || job.entry == nil { + continue + } + func() { + defer func() { + if r := recover(); r != nil { + log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) + _ = job.ops.RecordError(ctx, job.entry, job.requestBody) + cancel() + opsErrorLogProcessed.Add(1) + }() + } + }() + } +} + +func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { + if ops == nil || entry 
== nil { + return + } + select { + case <-opsErrorLogShutdownCh: + return + default: + } + + opsErrorLogMu.RLock() + stopping := opsErrorLogStopping + opsErrorLogMu.RUnlock() + if stopping { + return + } + + opsErrorLogOnce.Do(startOpsErrorLogWorkers) + + opsErrorLogMu.RLock() + defer opsErrorLogMu.RUnlock() + if opsErrorLogStopping || opsErrorLogQueue == nil { + return + } + + select { + case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: + opsErrorLogQueueLen.Add(1) + opsErrorLogEnqueued.Add(1) + default: + // Queue is full; drop to avoid blocking request handling. + opsErrorLogDropped.Add(1) + maybeLogOpsErrorLogDrop() + } +} + +func StopOpsErrorLogWorkers() bool { + opsErrorLogStopOnce.Do(func() { + opsErrorLogShutdownOnce.Do(func() { + close(opsErrorLogShutdownCh) + }) + opsErrorLogDrained.Store(stopOpsErrorLogWorkers()) + }) + return opsErrorLogDrained.Load() +} + +func stopOpsErrorLogWorkers() bool { + opsErrorLogMu.Lock() + opsErrorLogStopping = true + ch := opsErrorLogQueue + if ch != nil { + close(ch) + } + opsErrorLogQueue = nil + opsErrorLogMu.Unlock() + + if ch == nil { + opsErrorLogQueueLen.Store(0) + return true + } + + done := make(chan struct{}) + go func() { + opsErrorLogWorkersWg.Wait() + close(done) + }() + + select { + case <-done: + opsErrorLogQueueLen.Store(0) + return true + case <-time.After(opsErrorLogDrainTimeout): + return false + } +} + +func OpsErrorLogQueueLength() int64 { + return opsErrorLogQueueLen.Load() +} + +func OpsErrorLogQueueCapacity() int { + opsErrorLogMu.RLock() + ch := opsErrorLogQueue + opsErrorLogMu.RUnlock() + if ch == nil { + return 0 + } + return cap(ch) +} + +func OpsErrorLogDroppedTotal() int64 { + return opsErrorLogDropped.Load() +} + +func OpsErrorLogEnqueuedTotal() int64 { + return opsErrorLogEnqueued.Load() +} + +func OpsErrorLogProcessedTotal() int64 { + return opsErrorLogProcessed.Load() +} + +func maybeLogOpsErrorLogDrop() { + now := time.Now().Unix() + + for { + last := opsErrorLogLastDropLogAt.Load() + if last != 0 && now-last < 60 { + return + } + if opsErrorLogLastDropLogAt.CompareAndSwap(last, now) { + break + } + } + + queued := opsErrorLogQueueLen.Load() + queueCap := OpsErrorLogQueueCapacity() + + log.Printf( + "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", + queued, + queueCap, + opsErrorLogEnqueued.Load(), + opsErrorLogDropped.Load(), + opsErrorLogProcessed.Load(), + ) +} + +func opsErrorLogConfig() (workerCount int, queueSize int) { + workerCount = runtime.GOMAXPROCS(0) * 2 + if workerCount < opsErrorLogMinWorkerCount { + workerCount = opsErrorLogMinWorkerCount + } + if workerCount > opsErrorLogMaxWorkerCount { + workerCount = opsErrorLogMaxWorkerCount + } + + queueSize = workerCount * opsErrorLogQueueSizePerWorker + if queueSize < opsErrorLogMinQueueSize { + queueSize = opsErrorLogMinQueueSize + } + if queueSize > opsErrorLogMaxQueueSize { + queueSize = opsErrorLogMaxQueueSize + } + + return workerCount, queueSize +} + +func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) { + if c == nil { + return + } + c.Set(opsModelKey, model) + c.Set(opsStreamKey, stream) + if len(requestBody) > 0 { + c.Set(opsRequestBodyKey, requestBody) + } +} + +func setOpsSelectedAccount(c *gin.Context, accountID int64) { + if c == nil || accountID <= 0 { + return + } + c.Set(opsAccountIDKey, accountID) +} + +type opsCaptureWriter struct { + gin.ResponseWriter + limit int + buf bytes.Buffer +} + 
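+// Write mirrors response bytes into the in-memory buffer (up to limit) once the
+// response status indicates an error (>= 400), then forwards them to the wrapped writer.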
+func (w *opsCaptureWriter) Write(b []byte) (int, error) { + if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { + remaining := w.limit - w.buf.Len() + if len(b) > remaining { + _, _ = w.buf.Write(b[:remaining]) + } else { + _, _ = w.buf.Write(b) + } + } + return w.ResponseWriter.Write(b) +} + +func (w *opsCaptureWriter) WriteString(s string) (int, error) { + if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { + remaining := w.limit - w.buf.Len() + if len(s) > remaining { + _, _ = w.buf.WriteString(s[:remaining]) + } else { + _, _ = w.buf.WriteString(s) + } + } + return w.ResponseWriter.WriteString(s) +} + +// OpsErrorLoggerMiddleware records error responses (status >= 400) into ops_error_logs. +// +// Notes: +// - It buffers response bodies only when status >= 400 to avoid overhead for successful traffic. +// - Streaming errors after the response has started (SSE) may still need explicit logging. +func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { + return func(c *gin.Context) { + w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024} + c.Writer = w + c.Next() + + status := c.Writer.Status() + if status < 400 { + return + } + if ops == nil { + return + } + if !ops.IsMonitoringEnabled(c.Request.Context()) { + return + } + + body := w.buf.Bytes() + parsed := parseOpsErrorResponse(body) + + apiKey, _ := middleware2.GetAPIKeyFromContext(c) + + clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + + model, _ := c.Get(opsModelKey) + streamV, _ := c.Get(opsStreamKey) + accountIDV, _ := c.Get(opsAccountIDKey) + + var modelName string + if s, ok := model.(string); ok { + modelName = s + } + stream := false + if b, ok := streamV.(bool); ok { + stream = b + } + var accountID *int64 + if v, ok := accountIDV.(int64); ok && v > 0 { + accountID = &v + } + + fallbackPlatform := guessPlatformFromPath(c.Request.URL.Path) + platform := resolveOpsPlatform(apiKey, fallbackPlatform) + + requestID := c.Writer.Header().Get("X-Request-Id") + if requestID == "" { + requestID = c.Writer.Header().Get("x-request-id") + } + + phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code) + isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message) + + errorOwner := classifyOpsErrorOwner(phase, parsed.Message) + errorSource := classifyOpsErrorSource(phase, parsed.Message) + + entry := &service.OpsInsertErrorLogInput{ + RequestID: requestID, + ClientRequestID: clientRequestID, + + AccountID: accountID, + Platform: platform, + Model: modelName, + RequestPath: func() string { + if c.Request != nil && c.Request.URL != nil { + return c.Request.URL.Path + } + return "" + }(), + Stream: stream, + UserAgent: c.GetHeader("User-Agent"), + + ErrorPhase: phase, + ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code), + Severity: classifyOpsSeverity(parsed.ErrorType, status), + StatusCode: status, + IsBusinessLimited: isBusinessLimited, + + ErrorMessage: parsed.Message, + // Keep the full captured error body (capture is already capped at 64KB) so the + // service layer can sanitize JSON before truncating for storage. 
+ ErrorBody: string(body), + ErrorSource: errorSource, + ErrorOwner: errorOwner, + + IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status), + RetryCount: 0, + CreatedAt: time.Now(), + } + + if apiKey != nil { + entry.APIKeyID = &apiKey.ID + if apiKey.User != nil { + entry.UserID = &apiKey.User.ID + } + if apiKey.GroupID != nil { + entry.GroupID = apiKey.GroupID + } + // Prefer group platform if present (more stable than inferring from path). + if apiKey.Group != nil && apiKey.Group.Platform != "" { + entry.Platform = apiKey.Group.Platform + } + } + + var clientIP string + if ip := strings.TrimSpace(c.ClientIP()); ip != "" { + clientIP = ip + entry.ClientIP = &clientIP + } + + var requestBody []byte + if v, ok := c.Get(opsRequestBodyKey); ok { + if b, ok := v.([]byte); ok && len(b) > 0 { + requestBody = b + } + } + // Persist only a minimal, whitelisted set of request headers to improve retry fidelity. + // Do NOT store Authorization/Cookie/etc. + entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + + enqueueOpsErrorLog(ops, entry, requestBody) + } +} + +var opsRetryRequestHeaderAllowlist = []string{ + "anthropic-beta", + "anthropic-version", +} + +func extractOpsRetryRequestHeaders(c *gin.Context) *string { + if c == nil || c.Request == nil { + return nil + } + + headers := make(map[string]string, 4) + for _, key := range opsRetryRequestHeaderAllowlist { + v := strings.TrimSpace(c.GetHeader(key)) + if v == "" { + continue + } + // Keep headers small even if a client sends something unexpected. + headers[key] = truncateString(v, 512) + } + if len(headers) == 0 { + return nil + } + + raw, err := json.Marshal(headers) + if err != nil { + return nil + } + s := string(raw) + return &s +} + +type parsedOpsError struct { + ErrorType string + Message string + Code string +} + +func parseOpsErrorResponse(body []byte) parsedOpsError { + if len(body) == 0 { + return parsedOpsError{} + } + + // Fast path: attempt to decode into a generic map. + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return parsedOpsError{Message: truncateString(string(body), 1024)} + } + + // Claude/OpenAI-style gateway error: { type:"error", error:{ type, message } } + if errObj, ok := m["error"].(map[string]any); ok { + t, _ := errObj["type"].(string) + msg, _ := errObj["message"].(string) + // Gemini googleError also uses "error": { code, message, status } + if msg == "" { + if v, ok := errObj["message"]; ok { + msg, _ = v.(string) + } + } + if t == "" { + // Gemini error does not have "type" field. + t = "api_error" + } + // For gemini error, capture numeric code as string for business-limited mapping if needed. + var code string + if v, ok := errObj["code"]; ok { + switch n := v.(type) { + case float64: + code = strconvItoa(int(n)) + case int: + code = strconvItoa(n) + } + } + return parsedOpsError{ErrorType: t, Message: msg, Code: code} + } + + // APIKeyAuth-style: { code:"INSUFFICIENT_BALANCE", message:"..." 
}
+	code, _ := m["code"].(string)
+	msg, _ := m["message"].(string)
+	if code != "" || msg != "" {
+		return parsedOpsError{ErrorType: "api_error", Message: msg, Code: code}
+	}
+
+	return parsedOpsError{Message: truncateString(string(body), 1024)}
+}
+
+func resolveOpsPlatform(apiKey *service.APIKey, fallback string) string {
+	if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform != "" {
+		return apiKey.Group.Platform
+	}
+	return fallback
+}
+
+func guessPlatformFromPath(path string) string {
+	p := strings.ToLower(path)
+	switch {
+	case strings.HasPrefix(p, "/antigravity/"):
+		return service.PlatformAntigravity
+	case strings.HasPrefix(p, "/v1beta/"):
+		return service.PlatformGemini
+	case strings.Contains(p, "/responses"):
+		return service.PlatformOpenAI
+	default:
+		return ""
+	}
+}
+
+func normalizeOpsErrorType(errType string, code string) string {
+	if errType != "" {
+		return errType
+	}
+	switch strings.TrimSpace(code) {
+	case "INSUFFICIENT_BALANCE":
+		return "billing_error"
+	case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
+		return "subscription_error"
+	default:
+		return "api_error"
+	}
+}
+
+func classifyOpsPhase(errType, message, code string) string {
+	msg := strings.ToLower(message)
+	switch strings.TrimSpace(code) {
+	case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
+		return "billing"
+	}
+
+	switch errType {
+	case "authentication_error":
+		return "auth"
+	case "billing_error", "subscription_error":
+		return "billing"
+	case "rate_limit_error":
+		if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
+			return "concurrency"
+		}
+		return "upstream"
+	case "invalid_request_error":
+		return "response"
+	case "upstream_error", "overloaded_error":
+		return "upstream"
+	case "api_error":
+		if strings.Contains(msg, "no available accounts") {
+			return "scheduling"
+		}
+		return "internal"
+	default:
+		return "internal"
+	}
+}
+
+func classifyOpsSeverity(errType string, status int) string {
+	switch errType {
+	case "invalid_request_error", "authentication_error", "billing_error", "subscription_error":
+		return "P3"
+	}
+	if status >= 500 {
+		return "P1"
+	}
+	if status == 429 {
+		return "P1"
+	}
+	if status >= 400 {
+		return "P2"
+	}
+	return "P3"
+}
+
+func classifyOpsIsRetryable(errType string, statusCode int) bool {
+	switch errType {
+	case "authentication_error", "invalid_request_error":
+		return false
+	case "timeout_error":
+		return true
+	case "rate_limit_error":
+		// May be transient (upstream or queue); retry can help.
+		return true
+	case "billing_error", "subscription_error":
+		return false
+	case "upstream_error", "overloaded_error":
+		return statusCode >= 500 || statusCode == 429 || statusCode == 529
+	default:
+		return statusCode >= 500
+	}
+}
+
+func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
+	switch strings.TrimSpace(code) {
+	case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
+		return true
+	}
+	if phase == "billing" || phase == "concurrency" {
+		// SLA/error-rate calculations exclude "user-level business limiting".
+		return true
+	}
+	// Avoid treating upstream rate limits as business-limited.
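+	// Messages mentioning "upstream" indicate provider-side throttling, not a user-level limit.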
+ if errType == "rate_limit_error" && strings.Contains(strings.ToLower(message), "upstream") { + return false + } + _ = status + return false +} + +func classifyOpsErrorOwner(phase string, message string) string { + switch phase { + case "upstream", "network": + return "provider" + case "billing", "concurrency", "auth", "response": + return "client" + default: + if strings.Contains(strings.ToLower(message), "upstream") { + return "provider" + } + return "sub2api" + } +} + +func classifyOpsErrorSource(phase string, message string) string { + switch phase { + case "upstream": + return "upstream_http" + case "network": + return "upstream_network" + case "billing": + return "billing" + case "concurrency": + return "concurrency" + default: + if strings.Contains(strings.ToLower(message), "upstream") { + return "upstream_http" + } + return "internal" + } +} + +func truncateString(s string, max int) string { + if max <= 0 { + return "" + } + if len(s) <= max { + return s + } + cut := s[:max] + // Ensure truncation does not split multi-byte characters. + for len(cut) > 0 && !utf8.ValidString(cut) { + cut = cut[:len(cut)-1] + } + return cut +} + +func strconvItoa(v int) string { + return strconv.Itoa(v) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 1695f8a9..e5d8d077 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -20,6 +20,7 @@ func ProvideAdminHandlers( proxyHandler *admin.ProxyHandler, redeemHandler *admin.RedeemHandler, settingHandler *admin.SettingHandler, + opsHandler *admin.OpsHandler, systemHandler *admin.SystemHandler, subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, @@ -37,6 +38,7 @@ func ProvideAdminHandlers( Proxy: proxyHandler, Redeem: redeemHandler, Setting: settingHandler, + Ops: opsHandler, System: systemHandler, Subscription: subscriptionHandler, Usage: usageHandler, @@ -106,6 +108,7 @@ var ProviderSet = wire.NewSet( admin.NewProxyHandler, admin.NewRedeemHandler, admin.NewSettingHandler, + admin.NewOpsHandler, ProvideSystemHandler, admin.NewSubscriptionHandler, admin.NewUsageHandler, diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index e02a7b0a..8f30107c 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -30,6 +30,20 @@ func adminAuth( settingService *service.SettingService, ) gin.HandlerFunc { return func(c *gin.Context) { + // WebSocket upgrade requests cannot set Authorization headers in browsers. + // For admin WebSocket endpoints (e.g. Ops realtime), allow passing the JWT via + // Sec-WebSocket-Protocol (subprotocol list) using a prefixed token item: + // Sec-WebSocket-Protocol: sub2api-admin, jwt. 
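+		// For example, a browser client could connect with
+		// new WebSocket(url, ["sub2api-admin", "jwt." + token]); only the "jwt."
+		// prefix on one of the protocol list items matters to the extractor below.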
+ if isWebSocketUpgradeRequest(c) { + if token := extractJWTFromWebSocketSubprotocol(c); token != "" { + if !validateJWTForAdmin(c, token, authService, userService) { + return + } + c.Next() + return + } + } + // 检查 x-api-key header(Admin API Key 认证) apiKey := c.GetHeader("x-api-key") if apiKey != "" { @@ -58,6 +72,44 @@ func adminAuth( } } +func isWebSocketUpgradeRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + // RFC6455 handshake uses: + // Connection: Upgrade + // Upgrade: websocket + upgrade := strings.ToLower(strings.TrimSpace(c.GetHeader("Upgrade"))) + if upgrade != "websocket" { + return false + } + connection := strings.ToLower(c.GetHeader("Connection")) + return strings.Contains(connection, "upgrade") +} + +func extractJWTFromWebSocketSubprotocol(c *gin.Context) string { + if c == nil { + return "" + } + raw := strings.TrimSpace(c.GetHeader("Sec-WebSocket-Protocol")) + if raw == "" { + return "" + } + + // The header is a comma-separated list of tokens. We reserve the prefix "jwt." + // for carrying the admin JWT. + for _, part := range strings.Split(raw, ",") { + p := strings.TrimSpace(part) + if strings.HasPrefix(p, "jwt.") { + token := strings.TrimSpace(strings.TrimPrefix(p, "jwt.")) + if token != "" { + return token + } + } + } + return "" +} + // validateAdminAPIKey 验证管理员 API Key func validateAdminAPIKey( c *gin.Context, diff --git a/backend/internal/server/middleware/client_request_id.go b/backend/internal/server/middleware/client_request_id.go new file mode 100644 index 00000000..60d444ce --- /dev/null +++ b/backend/internal/server/middleware/client_request_id.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// ClientRequestID ensures every request has a unique client_request_id in request.Context(). +// +// This is used by the Ops monitoring module for end-to-end request correlation. +func ClientRequestID() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request == nil { + c.Next() + return + } + + if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil { + c.Next() + return + } + + id := uuid.New().String() + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)) + c.Next() + } +} + diff --git a/backend/internal/server/middleware/ws_query_token_auth.go b/backend/internal/server/middleware/ws_query_token_auth.go new file mode 100644 index 00000000..3b8d086a --- /dev/null +++ b/backend/internal/server/middleware/ws_query_token_auth.go @@ -0,0 +1,54 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +// InjectBearerTokenFromQueryForWebSocket copies `?token=` into the Authorization header +// for WebSocket handshake requests on a small allow-list of endpoints. +// +// Why: browsers can't set custom headers on WebSocket handshake, but our admin routes +// are protected by header-based auth. This keeps the token support scoped to WS only. +func InjectBearerTokenFromQueryForWebSocket() gin.HandlerFunc { + return func(c *gin.Context) { + if c == nil || c.Request == nil { + if c != nil { + c.Next() + } + return + } + + // Only GET websocket upgrades. + if c.Request.Method != http.MethodGet { + c.Next() + return + } + if !strings.EqualFold(strings.TrimSpace(c.GetHeader("Upgrade")), "websocket") { + c.Next() + return + } + + // If caller already supplied auth headers, don't override. 
+ if strings.TrimSpace(c.GetHeader("Authorization")) != "" || strings.TrimSpace(c.GetHeader("x-api-key")) != "" { + c.Next() + return + } + + // Allow-list ops websocket endpoints. + path := strings.TrimSpace(c.Request.URL.Path) + if !strings.HasPrefix(path, "/api/v1/admin/ops/ws/") { + c.Next() + return + } + + token := strings.TrimSpace(c.Query("token")) + if token != "" { + c.Request.Header.Set("Authorization", "Bearer "+token) + } + + c.Next() + } +}
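
Reviewer note: a minimal wiring sketch, not part of this patch; the engine setup and
names below are illustrative. The point is the ordering: ClientRequestID is registered
early so OpsErrorLoggerMiddleware can read the correlation ID from the request context,
and the query-token injector must run before the header-based admin auth so browser
WebSocket handshakes on the ops endpoints can authenticate.

package server // illustrative package; actual registration lives in the server setup code

import (
	"github.com/Wei-Shaw/sub2api/internal/handler"
	"github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"

	"github.com/gin-gonic/gin"
)

// buildOpsRouter is a hypothetical helper showing the intended middleware order.
func buildOpsRouter(ops *service.OpsService) *gin.Engine {
	r := gin.New()

	// Every request gets a client_request_id; the ops error logger reads it from
	// the request context when recording an error row.
	r.Use(middleware.ClientRequestID())
	r.Use(handler.OpsErrorLoggerMiddleware(ops))

	// The ?token= injector must be registered before admin auth so that browser
	// WebSocket handshakes under /api/v1/admin/ops/ws/ can pass the Authorization check.
	adminGroup := r.Group("/api/v1/admin", middleware.InjectBearerTokenFromQueryForWebSocket())
	_ = adminGroup // adminAuth(...) and the ops handlers are attached to this group elsewhere.

	return r
}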