refactor: 移除 Ops 监控模块
移除未完成的运维监控功能,简化系统架构: - 删除 ops_handler, ops_service, ops_repo 等后端代码 - 删除 ops 相关数据库迁移文件 - 删除前端 OpsDashboard 页面和 API
This commit is contained in:
@@ -1,402 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpsHandler handles ops dashboard endpoints.
|
||||
type OpsHandler struct {
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// NewOpsHandler creates a new OpsHandler.
|
||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
||||
return &OpsHandler{opsService: opsService}
|
||||
}
|
||||
|
||||
// GetMetrics returns the latest ops metrics snapshot.
|
||||
// GET /api/v1/admin/ops/metrics
|
||||
func (h *OpsHandler) GetMetrics(c *gin.Context) {
|
||||
metrics, err := h.opsService.GetLatestMetrics(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get ops metrics")
|
||||
return
|
||||
}
|
||||
response.Success(c, metrics)
|
||||
}
|
||||
|
||||
// ListMetricsHistory returns a time-range slice of metrics for charts.
|
||||
// GET /api/v1/admin/ops/metrics/history
|
||||
//
|
||||
// Query params:
|
||||
// - window_minutes: int (default 1)
|
||||
// - minutes: int (lookback; optional)
|
||||
// - start_time/end_time: RFC3339 timestamps (optional; overrides minutes when provided)
|
||||
// - limit: int (optional; max 100, default 300 for backward compatibility)
|
||||
func (h *OpsHandler) ListMetricsHistory(c *gin.Context) {
|
||||
windowMinutes := 1
|
||||
if v := c.Query("window_minutes"); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
windowMinutes = parsed
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid window_minutes")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limit := 300
|
||||
limitProvided := false
|
||||
if v := c.Query("limit"); v != "" {
|
||||
parsed, err := strconv.Atoi(v)
|
||||
if err != nil || parsed <= 0 || parsed > 5000 {
|
||||
response.BadRequest(c, "Invalid limit (must be 1-5000)")
|
||||
return
|
||||
}
|
||||
limit = parsed
|
||||
limitProvided = true
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
startTime := time.Time{}
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
startTime = parsed
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
endTime = parsed
|
||||
}
|
||||
|
||||
// If explicit range not provided, use lookback minutes.
|
||||
if startTime.IsZero() {
|
||||
if v := c.Query("minutes"); v != "" {
|
||||
minutes, err := strconv.Atoi(v)
|
||||
if err != nil || minutes <= 0 {
|
||||
response.BadRequest(c, "Invalid minutes")
|
||||
return
|
||||
}
|
||||
if minutes > 60*24*7 {
|
||||
minutes = 60 * 24 * 7
|
||||
}
|
||||
startTime = endTime.Add(-time.Duration(minutes) * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Default time range: last 24 hours.
|
||||
if startTime.IsZero() {
|
||||
startTime = endTime.Add(-24 * time.Hour)
|
||||
if !limitProvided {
|
||||
// Metrics are collected at 1-minute cadence; 24h requires ~1440 points.
|
||||
limit = 24 * 60
|
||||
}
|
||||
}
|
||||
|
||||
if startTime.After(endTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.opsService.ListMetricsHistory(c.Request.Context(), windowMinutes, startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to list ops metrics history")
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
// ListErrorLogs lists recent error logs with optional filters.
|
||||
// GET /api/v1/admin/ops/error-logs
|
||||
//
|
||||
// Query params:
|
||||
// - start_time/end_time: RFC3339 timestamps (optional)
|
||||
// - platform: string (optional)
|
||||
// - phase: string (optional)
|
||||
// - severity: string (optional)
|
||||
// - q: string (optional; fuzzy match)
|
||||
// - limit: int (optional; default 100; max 500)
|
||||
func (h *OpsHandler) ListErrorLogs(c *gin.Context) {
|
||||
var filters service.OpsErrorLogFilters
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
startTime, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filters.StartTime = &startTime
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
endTime, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filters.EndTime = &endTime
|
||||
}
|
||||
|
||||
if filters.StartTime != nil && filters.EndTime != nil && filters.StartTime.After(*filters.EndTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
filters.Platform = c.Query("platform")
|
||||
filters.Phase = c.Query("phase")
|
||||
filters.Severity = c.Query("severity")
|
||||
filters.Query = c.Query("q")
|
||||
|
||||
filters.Limit = 100
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 || limit > 500 {
|
||||
response.BadRequest(c, "Invalid limit (must be 1-500)")
|
||||
return
|
||||
}
|
||||
filters.Limit = limit
|
||||
}
|
||||
|
||||
items, total, err := h.opsService.ListErrorLogs(c.Request.Context(), filters)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to list error logs")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
// GetDashboardOverview returns realtime ops dashboard overview.
|
||||
// GET /api/v1/admin/ops/dashboard/overview
|
||||
//
|
||||
// Query params:
|
||||
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
|
||||
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
switch timeRange {
|
||||
case "5m", "30m", "1h", "6h", "24h":
|
||||
default:
|
||||
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.opsService.GetDashboardOverview(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get dashboard overview")
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
// GetProviderHealth returns upstream provider health comparison data.
|
||||
// GET /api/v1/admin/ops/dashboard/providers
|
||||
//
|
||||
// Query params:
|
||||
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
|
||||
func (h *OpsHandler) GetProviderHealth(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
switch timeRange {
|
||||
case "5m", "30m", "1h", "6h", "24h":
|
||||
default:
|
||||
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.opsService.GetProviderHealth(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get provider health")
|
||||
return
|
||||
}
|
||||
|
||||
var totalRequests int64
|
||||
var weightedSuccess float64
|
||||
var bestProvider string
|
||||
var worstProvider string
|
||||
var bestRate float64
|
||||
var worstRate float64
|
||||
hasRate := false
|
||||
|
||||
for _, p := range providers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
totalRequests += p.RequestCount
|
||||
weightedSuccess += (p.SuccessRate / 100) * float64(p.RequestCount)
|
||||
|
||||
if p.RequestCount <= 0 {
|
||||
continue
|
||||
}
|
||||
if !hasRate {
|
||||
bestProvider = p.Name
|
||||
worstProvider = p.Name
|
||||
bestRate = p.SuccessRate
|
||||
worstRate = p.SuccessRate
|
||||
hasRate = true
|
||||
continue
|
||||
}
|
||||
|
||||
if p.SuccessRate > bestRate {
|
||||
bestProvider = p.Name
|
||||
bestRate = p.SuccessRate
|
||||
}
|
||||
if p.SuccessRate < worstRate {
|
||||
worstProvider = p.Name
|
||||
worstRate = p.SuccessRate
|
||||
}
|
||||
}
|
||||
|
||||
avgSuccessRate := 0.0
|
||||
if totalRequests > 0 {
|
||||
avgSuccessRate = (weightedSuccess / float64(totalRequests)) * 100
|
||||
avgSuccessRate = math.Round(avgSuccessRate*100) / 100
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"providers": providers,
|
||||
"summary": gin.H{
|
||||
"total_requests": totalRequests,
|
||||
"avg_success_rate": avgSuccessRate,
|
||||
"best_provider": bestProvider,
|
||||
"worst_provider": worstProvider,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetErrorLogs returns a paginated error log list with multi-dimensional filters.
|
||||
// GET /api/v1/admin/ops/errors
|
||||
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
filter := &service.ErrorLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
startTime, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
endTime, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
|
||||
if filter.StartTime != nil && filter.EndTime != nil && filter.StartTime.After(*filter.EndTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
if errorCodeStr := c.Query("error_code"); errorCodeStr != "" {
|
||||
code, err := strconv.Atoi(errorCodeStr)
|
||||
if err != nil || code < 0 {
|
||||
response.BadRequest(c, "Invalid error_code")
|
||||
return
|
||||
}
|
||||
filter.ErrorCode = &code
|
||||
}
|
||||
|
||||
// Keep both parameter names for compatibility: provider (docs) and platform (legacy).
|
||||
filter.Provider = c.Query("provider")
|
||||
if filter.Provider == "" {
|
||||
filter.Provider = c.Query("platform")
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
accountID, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil || accountID <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &accountID
|
||||
}
|
||||
|
||||
out, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get error logs")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"errors": out.Errors,
|
||||
"total": out.Total,
|
||||
"page": out.Page,
|
||||
"page_size": out.PageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLatencyHistogram returns the latency distribution histogram.
|
||||
// GET /api/v1/admin/ops/dashboard/latency-histogram
|
||||
func (h *OpsHandler) GetLatencyHistogram(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
buckets, err := h.opsService.GetLatencyHistogram(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get latency histogram")
|
||||
return
|
||||
}
|
||||
|
||||
totalRequests := int64(0)
|
||||
for _, b := range buckets {
|
||||
totalRequests += b.Count
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"buckets": buckets,
|
||||
"total_requests": totalRequests,
|
||||
"slow_request_threshold": 1000,
|
||||
})
|
||||
}
|
||||
|
||||
// GetErrorDistribution returns the error distribution.
|
||||
// GET /api/v1/admin/ops/dashboard/errors/distribution
|
||||
func (h *OpsHandler) GetErrorDistribution(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
items, err := h.opsService.GetErrorDistribution(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get error distribution")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"items": items,
|
||||
})
|
||||
}
|
||||
@@ -1,286 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type OpsWSProxyConfig struct {
|
||||
TrustProxy bool
|
||||
TrustedProxies []netip.Prefix
|
||||
OriginPolicy string
|
||||
}
|
||||
|
||||
const (
|
||||
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
|
||||
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
|
||||
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
|
||||
)
|
||||
|
||||
const (
|
||||
OriginPolicyStrict = "strict"
|
||||
OriginPolicyPermissive = "permissive"
|
||||
)
|
||||
|
||||
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return isAllowedOpsWSOrigin(r)
|
||||
},
|
||||
}
|
||||
|
||||
// QPSWSHandler handles realtime QPS push via WebSocket.
|
||||
// GET /api/v1/admin/ops/ws/qps
|
||||
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// Set pong handler
|
||||
if err := conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
|
||||
log.Printf("[OpsWS] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
return conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
})
|
||||
|
||||
// Push QPS data every 2 seconds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Heartbeat ping every 30 seconds
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Fetch 1m window stats for current QPS
|
||||
data, err := h.opsService.GetDashboardOverview(ctx, "5m")
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] get overview failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"type": "qps_update",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"data": gin.H{
|
||||
"qps": data.QPS.Current,
|
||||
"tps": data.TPS.Current,
|
||||
"request_count": data.Errors.TotalCount + int64(data.QPS.Avg1h*60), // Rough estimate
|
||||
},
|
||||
}
|
||||
|
||||
msg, _ := json.Marshal(payload)
|
||||
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||
log.Printf("[OpsWS] write failed: %v", err)
|
||||
return
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("[OpsWS] ping failed: %v", err)
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isAllowedOpsWSOrigin(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
|
||||
case OriginPolicyStrict:
|
||||
return false
|
||||
case OriginPolicyPermissive, "":
|
||||
return true
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil || parsed.Hostname() == "" {
|
||||
return false
|
||||
}
|
||||
originHost := strings.ToLower(parsed.Hostname())
|
||||
|
||||
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
|
||||
reqHost := hostWithoutPort(r.Host)
|
||||
if trustProxyHeaders {
|
||||
xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
|
||||
if xfHost != "" {
|
||||
xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
|
||||
if xfHost != "" {
|
||||
reqHost = hostWithoutPort(xfHost)
|
||||
}
|
||||
}
|
||||
}
|
||||
reqHost = strings.ToLower(reqHost)
|
||||
if reqHost == "" {
|
||||
return false
|
||||
}
|
||||
return originHost == reqHost
|
||||
}
|
||||
|
||||
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if !opsWSProxyConfig.TrustProxy {
|
||||
return false
|
||||
}
|
||||
peerIP, ok := requestPeerIP(r)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
|
||||
}
|
||||
|
||||
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
|
||||
if r == nil {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
if err != nil {
|
||||
host = strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
host = strings.TrimPrefix(host, "[")
|
||||
host = strings.TrimSuffix(host, "]")
|
||||
if host == "" {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
return addr.Unmap(), true
|
||||
}
|
||||
|
||||
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
for _, p := range trusted {
|
||||
if p.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
cfg := OpsWSProxyConfig{
|
||||
TrustProxy: true,
|
||||
TrustedProxies: defaultTrustedProxies(),
|
||||
OriginPolicy: OriginPolicyPermissive,
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.TrustProxy = parsed
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||
prefixes, invalid := parseTrustedProxyList(raw)
|
||||
if len(invalid) > 0 {
|
||||
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
}
|
||||
cfg.TrustedProxies = prefixes
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
|
||||
normalized := strings.ToLower(v)
|
||||
switch normalized {
|
||||
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||
cfg.OriginPolicy = normalized
|
||||
default:
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func defaultTrustedProxies() []netip.Prefix {
|
||||
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
|
||||
return prefixes
|
||||
}
|
||||
|
||||
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
item := strings.TrimSpace(token)
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var (
|
||||
p netip.Prefix
|
||||
err error
|
||||
)
|
||||
if strings.Contains(item, "/") {
|
||||
p, err = netip.ParsePrefix(item)
|
||||
} else {
|
||||
var addr netip.Addr
|
||||
addr, err = netip.ParseAddr(item)
|
||||
if err == nil {
|
||||
addr = addr.Unmap()
|
||||
bits := 128
|
||||
if addr.Is4() {
|
||||
bits = 32
|
||||
}
|
||||
p = netip.PrefixFrom(addr, bits)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil || !p.IsValid() {
|
||||
invalid = append(invalid, item)
|
||||
continue
|
||||
}
|
||||
|
||||
prefixes = append(prefixes, p.Masked())
|
||||
}
|
||||
return prefixes, invalid
|
||||
}
|
||||
|
||||
func hostWithoutPort(hostport string) string {
|
||||
hostport = strings.TrimSpace(hostport)
|
||||
if hostport == "" {
|
||||
return ""
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(hostport); err == nil {
|
||||
return host
|
||||
}
|
||||
if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
|
||||
return strings.Trim(hostport, "[]")
|
||||
}
|
||||
parts := strings.Split(hostport, ":")
|
||||
return parts[0]
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_AllowsEmptyOrigin(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyPermissive}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
|
||||
if !isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected empty Origin to be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_RejectsEmptyOrigin_WhenStrict(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyStrict}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
|
||||
if isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected empty Origin to be rejected under strict policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_UsesXForwardedHostOnlyFromTrustedProxy(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
|
||||
opsWSProxyConfig = OpsWSProxyConfig{
|
||||
TrustProxy: true,
|
||||
TrustedProxies: []netip.Prefix{
|
||||
netip.MustParsePrefix("127.0.0.0/8"),
|
||||
},
|
||||
}
|
||||
|
||||
// Untrusted peer: ignore X-Forwarded-Host and compare against r.Host.
|
||||
{
|
||||
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "192.0.2.1:12345"
|
||||
req.Host = "internal.service.local"
|
||||
req.Header.Set("Origin", "https://public.example.com")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
|
||||
if isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected Origin to be rejected when peer is not a trusted proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// Trusted peer: allow X-Forwarded-Host to participate in Origin validation.
|
||||
{
|
||||
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "127.0.0.1:23456"
|
||||
req.Host = "internal.service.local"
|
||||
req.Header.Set("Origin", "https://public.example.com")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
|
||||
if !isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected Origin to be accepted when peer is a trusted proxy")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicy(t *testing.T) {
|
||||
t.Setenv(envOpsWSOriginPolicy, "STRICT")
|
||||
cfg := loadOpsWSProxyConfigFromEnv()
|
||||
if cfg.OriginPolicy != OriginPolicyStrict {
|
||||
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyStrict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicyInvalidUsesDefault(t *testing.T) {
|
||||
t.Setenv(envOpsWSOriginPolicy, "nope")
|
||||
cfg := loadOpsWSProxyConfigFromEnv()
|
||||
if cfg.OriginPolicy != OriginPolicyPermissive {
|
||||
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyPermissive)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustedProxyList(t *testing.T) {
|
||||
prefixes, invalid := parseTrustedProxyList("10.0.0.1, 10.0.0.0/8, bad, ::1/128")
|
||||
if len(prefixes) != 3 {
|
||||
t.Fatalf("prefixes=%d, want 3", len(prefixes))
|
||||
}
|
||||
if len(invalid) != 1 || invalid[0] != "bad" {
|
||||
t.Fatalf("invalid=%v, want [bad]", invalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestPeerIP_ParsesIPv6(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "[::1]:1234"
|
||||
|
||||
addr, ok := requestPeerIP(req)
|
||||
if !ok {
|
||||
t.Fatalf("expected IPv6 peer IP to parse")
|
||||
}
|
||||
if addr != netip.MustParseAddr("::1") {
|
||||
t.Fatalf("addr=%s, want ::1", addr)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user