Monitor:
- callProvider now returns both textPath-extracted text and raw body;
runCheckForModel uses rawBody on non-2xx so history.message stops being
"upstream HTTP 503: " with empty body (gjson textPath produces "" for
error responses like {"error":{"message":"No available accounts..."}})
- truncateForErrorBody collapses whitespace then caps at 300 bytes
(monitorErrorBodySnippetMaxBytes); final truncateMessage still enforces
the 500-byte DB column cap
Frontend:
- MonitorFormDialog: primary_model input text color and ModelTagInput tags
now both track form.provider (via new getPlatformTextClass + existing
getPlatformTagClass with platform prop).
(cherry-picked from 1d3b0418; dropped the gateway_handler logging changes, which are out of scope for this PR)
324 lines
12 KiB
Go
324 lines
12 KiB
Go
package service
|
||
|
||
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/tidwall/gjson"
)
|
||
|
||
// monitorHTTPClient 共享一个 http.Client,避免每次检测重建 transport。
|
||
// 自定义 Transport 在 dial 时强制再次校验 IP,防止 DNS rebinding 绕过 validateEndpoint。
|
||
var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout)
|
||
|
||
// monitorPingHTTPClient 用于 endpoint origin 的 HEAD ping,超时更短。
|
||
var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout)
|
||
|
||
// newSSRFSafeHTTPClient 返回一个使用 safeDialContext 的 http.Client。
|
||
// 仅供监控模块对外发起请求使用——所有目标都应是公网 endpoint。
|
||
func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
|
||
tr := &http.Transport{
|
||
DialContext: safeDialContext,
|
||
ForceAttemptHTTP2: true,
|
||
MaxIdleConns: 16,
|
||
IdleConnTimeout: monitorIdleConnTimeout,
|
||
TLSHandshakeTimeout: monitorTLSHandshakeTimeout,
|
||
ResponseHeaderTimeout: monitorResponseHeaderTimeout,
|
||
}
|
||
return &http.Client{Timeout: timeout, Transport: tr}
|
||
}
|
||
|
||
// runCheckForModel 对单个 (provider, model) 做一次完整检测。
|
||
// 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。
|
||
func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string) *CheckResult {
|
||
res := &CheckResult{
|
||
Model: model,
|
||
Status: MonitorStatusError,
|
||
CheckedAt: time.Now(),
|
||
}
|
||
|
||
challenge := generateChallenge()
|
||
|
||
start := time.Now()
|
||
respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt)
|
||
latency := time.Since(start)
|
||
latencyMs := int(latency / time.Millisecond)
|
||
res.LatencyMs = &latencyMs
|
||
|
||
if err != nil {
|
||
res.Status = MonitorStatusError
|
||
res.Message = truncateMessage(sanitizeErrorMessage(err.Error()))
|
||
return res
|
||
}
|
||
if statusCode < 200 || statusCode >= 300 {
|
||
// 错误路径:用 rawBody 而非 respText(gjson textPath 抽取在错误响应里通常为空,
|
||
// 会丢掉真正的上游错误信息,例如 `{"error":{"message":"No available accounts ..."}}`)。
|
||
res.Status = MonitorStatusError
|
||
bodySnippet := truncateForErrorBody(rawBody)
|
||
res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, bodySnippet)))
|
||
return res
|
||
}
|
||
|
||
if !validateChallenge(respText, challenge.Expected) {
|
||
res.Status = MonitorStatusFailed
|
||
res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText)))
|
||
return res
|
||
}
|
||
|
||
if latency >= monitorDegradedThreshold {
|
||
res.Status = MonitorStatusDegraded
|
||
res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs))
|
||
return res
|
||
}
|
||
|
||
res.Status = MonitorStatusOperational
|
||
return res
|
||
}
|
||
|
||
// pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。
|
||
// 失败时返回 nil(不影响主状态判定)。
|
||
func pingEndpointOrigin(ctx context.Context, endpoint string) *int {
|
||
origin, err := extractOrigin(endpoint)
|
||
if err != nil || origin == "" {
|
||
return nil
|
||
}
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, origin, nil)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
start := time.Now()
|
||
resp, err := monitorPingHTTPClient.Do(req)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, monitorPingDiscardMaxBytes))
|
||
ms := int(time.Since(start) / time.Millisecond)
|
||
return &ms
|
||
}
|
||
|
||
// providerAdapter bundles the four provider-specific pieces a challenge check needs.
//
// Supporting a new provider only requires a new entry in providerAdapters;
// callProvider and validateProvider do not need to change.
type providerAdapter struct {
	// buildPath returns the request path; it receives the model name for
	// providers that embed the model in the path.
	buildPath func(model string) string
	// buildBody serializes the request body for the given model and prompt.
	buildBody func(model, prompt string) ([]byte, error)
	// buildHeaders returns the authentication headers for the given API key.
	buildHeaders func(apiKey string) map[string]string
	// textPath is the gjson path used to extract the reply text from the response JSON.
	textPath string
}
|
||
|
||
// providerAdapters 全部已支持的 provider。键值即 MonitorProvider* 字符串。
|
||
//
|
||
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
|
||
var providerAdapters = map[string]providerAdapter{
|
||
MonitorProviderOpenAI: {
|
||
buildPath: func(string) string { return providerOpenAIPath },
|
||
buildBody: func(model, prompt string) ([]byte, error) {
|
||
return json.Marshal(map[string]any{
|
||
"model": model,
|
||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||
"max_tokens": monitorChallengeMaxTokens,
|
||
"stream": false,
|
||
})
|
||
},
|
||
buildHeaders: func(apiKey string) map[string]string {
|
||
return map[string]string{"Authorization": "Bearer " + apiKey}
|
||
},
|
||
textPath: "choices.0.message.content",
|
||
},
|
||
MonitorProviderAnthropic: {
|
||
buildPath: func(string) string { return providerAnthropicPath },
|
||
buildBody: func(model, prompt string) ([]byte, error) {
|
||
return json.Marshal(map[string]any{
|
||
"model": model,
|
||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||
"max_tokens": monitorChallengeMaxTokens,
|
||
})
|
||
},
|
||
buildHeaders: func(apiKey string) map[string]string {
|
||
return map[string]string{
|
||
"x-api-key": apiKey,
|
||
"anthropic-version": monitorAnthropicAPIVersion,
|
||
}
|
||
},
|
||
textPath: "content.0.text",
|
||
},
|
||
MonitorProviderGemini: {
|
||
// Gemini 把 model 名写在 URL path 上:/v1beta/models/{model}:generateContent
|
||
buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) },
|
||
buildBody: func(_, prompt string) ([]byte, error) {
|
||
return json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"parts": []map[string]any{{"text": prompt}}},
|
||
},
|
||
"generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens},
|
||
})
|
||
},
|
||
// 使用 x-goog-api-key header 而不是 ?key= query,避免 *url.Error 把 key 回填到错误日志。
|
||
buildHeaders: func(apiKey string) map[string]string {
|
||
return map[string]string{"x-goog-api-key": apiKey}
|
||
},
|
||
textPath: "candidates.0.content.parts.0.text",
|
||
},
|
||
}
|
||
|
||
// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
|
||
// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
|
||
func isSupportedProvider(p string) bool {
|
||
_, ok := providerAdapters[p]
|
||
return ok
|
||
}
|
||
|
||
// callProvider 通过 providerAdapters 分发到具体实现。
|
||
//
|
||
// 返回值:
|
||
// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串
|
||
// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包
|
||
// - status: HTTP 状态码
|
||
// - err: 网络 / 序列化错误
|
||
func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string) (extractedText, rawBody string, status int, err error) {
|
||
adapter, ok := providerAdapters[provider]
|
||
if !ok {
|
||
return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
|
||
}
|
||
body, err := adapter.buildBody(model, prompt)
|
||
if err != nil {
|
||
return "", "", 0, fmt.Errorf("marshal body: %w", err)
|
||
}
|
||
full := joinURL(endpoint, adapter.buildPath(model))
|
||
respBytes, status, err := postRawJSON(ctx, full, body, adapter.buildHeaders(apiKey))
|
||
if err != nil {
|
||
return "", "", status, err
|
||
}
|
||
return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
|
||
}
|
||
|
||
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
|
||
// adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。
|
||
func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) {
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("build request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "application/json")
|
||
for k, v := range headers {
|
||
req.Header.Set(k, v)
|
||
}
|
||
|
||
resp, err := monitorHTTPClient.Do(req)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("do request: %w", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, monitorResponseMaxBytes))
|
||
if err != nil {
|
||
return nil, resp.StatusCode, fmt.Errorf("read body: %w", err)
|
||
}
|
||
return respBody, resp.StatusCode, nil
|
||
}
|
||
|
||
// joinURL concatenates a base origin and a path into a full URL.
// Trailing slashes on base are dropped, and path is given a leading slash
// when it lacks one, so exactly the path's own slashes separate the two parts.
func joinURL(base, path string) string {
	trimmed := strings.TrimRight(base, "/")
	if strings.HasPrefix(path, "/") {
		return trimmed + path
	}
	return trimmed + "/" + path
}
|
||
|
||
// extractOrigin returns the scheme://host[:port] portion of an endpoint URL.
// It fails when the endpoint cannot be parsed or lacks a scheme or a host.
func extractOrigin(endpoint string) (string, error) {
	parsed, err := url.Parse(endpoint)
	switch {
	case err != nil:
		return "", err
	case parsed.Scheme == "" || parsed.Host == "":
		return "", errors.New("endpoint missing scheme or host")
	}
	return parsed.Scheme + "://" + parsed.Host, nil
}
|
||
|
||
// monitorSensitiveQueryParamRegex matches URL query parameters that may carry
// credentials: key / api_key / api-key / access_token / token / authorization /
// x-api-key. Matching is case-insensitive and covers both `?name=value` and
// `&name=value`; the value runs up to the next `&`, whitespace, quote, or the
// end of the string. Group 1 captures the `?name=` prefix so the replacement
// can keep it.
var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`)

// monitorAPIKeyPatterns matches API key literals of common providers.
// Order matters: sk-ant- must come before the generic sk- pattern, otherwise
// the generic pattern would consume Anthropic keys first.
var monitorAPIKeyPatterns = []struct {
	pattern *regexp.Regexp
	replace string
}{
	// Anthropic (prefixed, must match first): sk-ant-xxxxxxx
	{regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"},
	// Generic sk- keys (OpenAI and others): sk-xxxxxxx. The character class
	// includes `_` so keys containing underscores (e.g. sk-proj-...) are redacted
	// in full instead of being skipped or cut off at the first underscore.
	{regexp.MustCompile(`sk-[A-Za-z0-9_-]{20,}`), "sk-***REDACTED***"},
	// Gemini / Google API key: fixed prefix followed by 35 characters.
	{regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"},
	// Three-segment JWT (commonly seen after "Bearer"): eyJxxx.eyJxxx.signature
	{regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"},
}

// sanitizeErrorMessage redacts API keys that may leak through error or
// response text. It covers two sources:
//  1. credentials in a URL query such as ?key= / ?api_key= (Go's *url.Error
//     echoes the full URL)
//  2. key literals such as sk-* / AIza* / JWTs appearing directly in upstream
//     HTTP body text
//
// Note: this is similar in purpose to sanitizeUpstreamErrorMessage in
// gemini_messages_compat_service.go but covers a wider parameter set; the
// monitor module maintains its own copy to avoid coupling the two.
//
// The redaction is idempotent: the replacement strings do not match any pattern.
func sanitizeErrorMessage(msg string) string {
	if msg == "" {
		return msg
	}
	msg = monitorSensitiveQueryParamRegex.ReplaceAllString(msg, `${1}REDACTED`)
	for _, p := range monitorAPIKeyPatterns {
		msg = p.pattern.ReplaceAllString(msg, p.replace)
	}
	return msg
}
|
||
|
||
// truncateMessage 把消息按 monitorMessageMaxBytes 截断,避免 DB 列溢出与日志过长。
|
||
func truncateMessage(msg string) string {
|
||
if len(msg) <= monitorMessageMaxBytes {
|
||
return msg
|
||
}
|
||
const ellipsis = "...(truncated)"
|
||
cutoff := monitorMessageMaxBytes - len(ellipsis)
|
||
if cutoff < 0 {
|
||
cutoff = 0
|
||
}
|
||
return msg[:cutoff] + ellipsis
|
||
}
|
||
|
||
// truncateForErrorBody 把上游错误响应 body 压到 monitorErrorBodySnippetMaxBytes 以内,
|
||
// 并顺手把连续空白折成一个空格:上游 HTML 错误页常含大量缩进/换行,保留会浪费预算。
|
||
// 被 truncateMessage 做最终总截断兜底,所以这里只负责 body 自身的精简。
|
||
func truncateForErrorBody(body string) string {
|
||
body = strings.Join(strings.Fields(body), " ")
|
||
if len(body) <= monitorErrorBodySnippetMaxBytes {
|
||
return body
|
||
}
|
||
const ellipsis = "...(body truncated)"
|
||
cutoff := monitorErrorBodySnippetMaxBytes - len(ellipsis)
|
||
if cutoff < 0 {
|
||
cutoff = 0
|
||
}
|
||
return body[:cutoff] + ellipsis
|
||
}
|