package service import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "regexp" "strings" "time" "github.com/tidwall/gjson" ) // monitorHTTPClient 共享一个 http.Client,避免每次检测重建 transport。 // 自定义 Transport 在 dial 时强制再次校验 IP,防止 DNS rebinding 绕过 validateEndpoint。 var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout) // monitorPingHTTPClient 用于 endpoint origin 的 HEAD ping,超时更短。 var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout) // newSSRFSafeHTTPClient 返回一个使用 safeDialContext 的 http.Client。 // 仅供监控模块对外发起请求使用——所有目标都应是公网 endpoint。 func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client { tr := &http.Transport{ DialContext: safeDialContext, ForceAttemptHTTP2: true, MaxIdleConns: 16, IdleConnTimeout: monitorIdleConnTimeout, TLSHandshakeTimeout: monitorTLSHandshakeTimeout, ResponseHeaderTimeout: monitorResponseHeaderTimeout, } return &http.Client{Timeout: timeout, Transport: tr} } // runCheckForModel 对单个 (provider, model) 做一次完整检测。 // 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。 func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string) *CheckResult { res := &CheckResult{ Model: model, Status: MonitorStatusError, CheckedAt: time.Now(), } challenge := generateChallenge() start := time.Now() respText, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt) latency := time.Since(start) latencyMs := int(latency / time.Millisecond) res.LatencyMs = &latencyMs if err != nil { res.Status = MonitorStatusError res.Message = truncateMessage(sanitizeErrorMessage(err.Error())) return res } if statusCode < 200 || statusCode >= 300 { res.Status = MonitorStatusError res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, respText))) return res } if !validateChallenge(respText, challenge.Expected) { res.Status = MonitorStatusFailed res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText))) return res } if latency >= monitorDegradedThreshold { res.Status = MonitorStatusDegraded res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs)) return res } res.Status = MonitorStatusOperational return res } // pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。 // 失败时返回 nil(不影响主状态判定)。 func pingEndpointOrigin(ctx context.Context, endpoint string) *int { origin, err := extractOrigin(endpoint) if err != nil || origin == "" { return nil } req, err := http.NewRequestWithContext(ctx, http.MethodHead, origin, nil) if err != nil { return nil } start := time.Now() resp, err := monitorPingHTTPClient.Do(req) if err != nil { return nil } defer func() { _ = resp.Body.Close() }() _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, monitorPingDiscardMaxBytes)) ms := int(time.Since(start) / time.Millisecond) return &ms } // providerAdapter 描述某个 provider 在 challenge 检测中需要的 4 件事: // - 拼出请求路径(含 model 占位) // - 序列化请求体 // - 构造鉴权头 // - 从响应 JSON 中按 path 提取文本(gjson path) // // 加新 provider 只需要在 providerAdapters 里增加一个条目,无需触碰 callProvider / validateProvider。 type providerAdapter struct { buildPath func(model string) string buildBody func(model, prompt string) ([]byte, error) buildHeaders func(apiKey string) map[string]string textPath string // gjson 提取响应文本的 path } // providerAdapters 全部已支持的 provider。键值即 MonitorProvider* 字符串。 // //nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。 var providerAdapters = map[string]providerAdapter{ MonitorProviderOpenAI: { buildPath: func(string) string { return providerOpenAIPath }, buildBody: func(model, prompt string) ([]byte, error) { return json.Marshal(map[string]any{ "model": model, "messages": []map[string]string{{"role": "user", "content": prompt}}, "max_tokens": monitorChallengeMaxTokens, "stream": false, }) }, buildHeaders: func(apiKey string) map[string]string { return map[string]string{"Authorization": "Bearer " + apiKey} }, textPath: "choices.0.message.content", }, MonitorProviderAnthropic: { buildPath: func(string) string { return providerAnthropicPath }, buildBody: func(model, prompt string) ([]byte, error) { return json.Marshal(map[string]any{ "model": model, "messages": []map[string]string{{"role": "user", "content": prompt}}, "max_tokens": monitorChallengeMaxTokens, }) }, buildHeaders: func(apiKey string) map[string]string { return map[string]string{ "x-api-key": apiKey, "anthropic-version": monitorAnthropicAPIVersion, } }, textPath: "content.0.text", }, MonitorProviderGemini: { // Gemini 把 model 名写在 URL path 上:/v1beta/models/{model}:generateContent buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) }, buildBody: func(_, prompt string) ([]byte, error) { return json.Marshal(map[string]any{ "contents": []map[string]any{ {"parts": []map[string]any{{"text": prompt}}}, }, "generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens}, }) }, // 使用 x-goog-api-key header 而不是 ?key= query,避免 *url.Error 把 key 回填到错误日志。 buildHeaders: func(apiKey string) map[string]string { return map[string]string{"x-goog-api-key": apiKey} }, textPath: "candidates.0.content.parts.0.text", }, } // isSupportedProvider 校验 provider 字符串是否在 adapter 表中。 // 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。 func isSupportedProvider(p string) bool { _, ok := providerAdapters[p] return ok } // callProvider 通过 providerAdapters 分发到具体实现。 // 返回值:响应中提取的文本、HTTP status、网络/序列化错误。 func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string) (string, int, error) { adapter, ok := providerAdapters[provider] if !ok { return "", 0, fmt.Errorf("unsupported provider %q", provider) } body, err := adapter.buildBody(model, prompt) if err != nil { return "", 0, fmt.Errorf("marshal body: %w", err) } full := joinURL(endpoint, adapter.buildPath(model)) respBody, status, err := postRawJSON(ctx, full, body, adapter.buildHeaders(apiKey)) if err != nil { return "", status, err } return gjson.GetBytes(respBody, adapter.textPath).String(), status, nil } // postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。 // adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。 func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) if err != nil { return nil, 0, fmt.Errorf("build request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") for k, v := range headers { req.Header.Set(k, v) } resp, err := monitorHTTPClient.Do(req) if err != nil { return nil, 0, fmt.Errorf("do request: %w", err) } defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(io.LimitReader(resp.Body, monitorResponseMaxBytes)) if err != nil { return nil, resp.StatusCode, fmt.Errorf("read body: %w", err) } return respBody, resp.StatusCode, nil } // joinURL 把 base origin 与 path 拼成完整 URL。 // 容忍 base 末尾有/无斜杠,path 必带前导斜杠。 func joinURL(base, path string) string { base = strings.TrimRight(base, "/") if !strings.HasPrefix(path, "/") { path = "/" + path } return base + path } // extractOrigin 从一个 endpoint URL 中提取 scheme://host[:port] 部分。 func extractOrigin(endpoint string) (string, error) { u, err := url.Parse(endpoint) if err != nil { return "", err } if u.Scheme == "" || u.Host == "" { return "", errors.New("endpoint missing scheme or host") } return u.Scheme + "://" + u.Host, nil } // monitorSensitiveQueryParamRegex 匹配 URL query 中可能泄露凭证的参数: // key / api_key / api-key / access_token / token / authorization / x-api-key。 // 大小写不敏感,匹配 `?name=value` 或 `&name=value` 形式(value 截到 & 或字符串末尾)。 var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`) // monitorAPIKeyPatterns 匹配常见 provider 的 API key 字面量。 // 顺序敏感:sk-ant- 必须放在 sk- 之前,否则会被通用 sk- 模式先消费。 var monitorAPIKeyPatterns = []struct { pattern *regexp.Regexp replace string }{ // Anthropic(带前缀,必须先匹配):sk-ant-xxxxxxx {regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"}, // OpenAI / Anthropic 通用 sk-: sk-xxxxxxx {regexp.MustCompile(`sk-[A-Za-z0-9-]{20,}`), "sk-***REDACTED***"}, // Gemini / Google API Key:固定前缀 + 35 位 {regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"}, // JWT 三段式(Bearer 后常出现):eyJxxx.eyJxxx.signature {regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"}, } // sanitizeErrorMessage 擦除错误/响应文本中可能泄露的 API key。 // 处理两类来源: // 1. URL query 中的 ?key= / ?api_key= 等(Go *url.Error 会回填完整 URL) // 2. 上游 HTTP body 文本里直接出现的 sk-* / AIza* / JWT 等密钥碎片 // // 注意:与 gemini_messages_compat_service.go 的 sanitizeUpstreamErrorMessage 关注点类似但参数集更广, // 监控模块独立维护,避免互相耦合。 func sanitizeErrorMessage(msg string) string { if msg == "" { return msg } msg = monitorSensitiveQueryParamRegex.ReplaceAllString(msg, `${1}REDACTED`) for _, p := range monitorAPIKeyPatterns { msg = p.pattern.ReplaceAllString(msg, p.replace) } return msg } // truncateMessage 把消息按 monitorMessageMaxBytes 截断,避免 DB 列溢出与日志过长。 func truncateMessage(msg string) string { if len(msg) <= monitorMessageMaxBytes { return msg } const ellipsis = "...(truncated)" cutoff := monitorMessageMaxBytes - len(ellipsis) if cutoff < 0 { cutoff = 0 } return msg[:cutoff] + ellipsis }