From d9b15879824ca9c40e35bd2cfdaaa82da6e743af Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 4 Jan 2026 10:38:13 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat(gateway):=20=E5=AE=9E=E7=8E=B0=20Claud?= =?UTF-8?q?e=20Code=20=E7=B3=BB=E7=BB=9F=E6=8F=90=E7=A4=BA=E8=AF=8D?= =?UTF-8?q?=E6=99=BA=E8=83=BD=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/gateway_service.go | 93 ++++++++++++++++++--- 1 file changed, 81 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index bd6f59f7..0af2c328 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -30,13 +30,15 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." ) // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( - sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) + sseDataRe = regexp.MustCompile(`^data:\s*`) + sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) + claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) ) // allowedHeaders 白名单headers(参考CRS项目) @@ -951,6 +953,76 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { } } +// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 +// 简化判断:User-Agent 匹配 + metadata.user_id 存在 +func isClaudeCodeClient(userAgent string, metadataUserID string) bool { + if metadataUserID == "" { + return false + } + return claudeCliUserAgentRe.MatchString(userAgent) +} + +// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 +// 支持 string 和 []any 两种格式 +func systemIncludesClaudeCodePrompt(system any) bool { + switch v := system.(type) { + case string: + return v == claudeCodeSystemPrompt + case []any: + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + return true + } + } + } + } + return false +} + +// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 +// 处理 null、字符串、数组三种格式 +func injectClaudeCodePrompt(body []byte, system any) []byte { + claudeCodeBlock := map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + } + + var newSystem []any + + switch v := system.(type) { + case nil: + newSystem = []any{claudeCodeBlock} + case string: + if v == "" || v == claudeCodeSystemPrompt { + newSystem = []any{claudeCodeBlock} + } else { + newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} + } + case []any: + newSystem = make([]any, 0, len(v)+1) + newSystem = append(newSystem, claudeCodeBlock) + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + continue + } + } + newSystem = append(newSystem, item) + } + default: + newSystem = []any{claudeCodeBlock} + } + + result, err := sjson.SetBytes(body, "system", newSystem) + if err != nil { + log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + return body + } + return result +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() @@ -962,16 +1034,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A reqModel := parsed.Model reqStream := parsed.Stream - if !parsed.HasSystem { - body, _ = sjson.SetBytes(body, "system", []any{ - map[string]any{ - "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - }) + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if account.IsOAuth() && + !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && + !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) } // 应用模型映射(仅对apikey类型账号) From a11c71cea9168193b9e30670ad1d39f653a7a131 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 4 Jan 2026 10:45:18 +0800 Subject: [PATCH 2/6] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7schedulable=E5=80=BC=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E4=B8=BAfalse=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/admin_service.go | 1 + .../internal/service/gateway_prompt_test.go | 233 ++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 backend/internal/service/gateway_prompt_test.go diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 962b3684..707e728b 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -630,6 +630,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Concurrency: input.Concurrency, Priority: input.Priority, Status: StatusActive, + Schedulable: true, } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go new file mode 100644 index 00000000..b056f8fa --- /dev/null +++ b/backend/internal/service/gateway_prompt_test.go @@ -0,0 +1,233 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsClaudeCodeClient(t *testing.T) { + tests := []struct { + name string + userAgent string + metadataUserID string + want bool + }{ + { + name: "Claude Code client", + userAgent: "claude-cli/1.0.62 (darwin; arm64)", + metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + want: true, + }, + { + name: "Claude Code without version suffix", + userAgent: "claude-cli/2.0.0", + metadataUserID: "session_abc", + want: true, + }, + { + name: "Missing metadata user_id", + userAgent: "claude-cli/1.0.0", + metadataUserID: "", + want: false, + }, + { + name: "Different user agent", + userAgent: "curl/7.68.0", + metadataUserID: "user123", + want: false, + }, + { + name: "Empty user agent", + userAgent: "", + metadataUserID: "user123", + want: false, + }, + { + name: "Similar but not Claude CLI", + userAgent: "claude-api/1.0.0", + metadataUserID: "user123", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSystemIncludesClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + system any + want bool + }{ + { + name: "nil system", + system: nil, + want: false, + }, + { + name: "empty string", + system: "", + want: false, + }, + { + name: "string with Claude Code prompt", + system: claudeCodeSystemPrompt, + want: true, + }, + { + name: "string with different content", + system: "You are a helpful assistant.", + want: false, + }, + { + name: "empty array", + system: []any{}, + want: false, + }, + { + name: "array with Claude Code prompt", + system: []any{ + map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + }, + }, + want: true, + }, + { + name: "array with Claude Code prompt in second position", + system: []any{ + map[string]any{"type": "text", "text": "First prompt"}, + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + }, + want: true, + }, + { + name: "array without Claude Code prompt", + system: []any{ + map[string]any{"type": "text", "text": "Custom prompt"}, + }, + want: false, + }, + { + name: "array with partial match (should not match)", + system: []any{ + map[string]any{"type": "text", "text": "You are Claude"}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := systemIncludesClaudeCodePrompt(tt.system) + require.Equal(t, tt.want, got) + }) + } +} + +func TestInjectClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + body string + system any + wantSystemLen int + wantFirstText string + wantSecondText string + }{ + { + name: "nil system", + body: `{"model":"claude-3"}`, + system: nil, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "empty string system", + body: `{"model":"claude-3"}`, + system: "", + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "string system", + body: `{"model":"claude-3"}`, + system: "Custom prompt", + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom prompt", + }, + { + name: "string system equals Claude Code prompt", + body: `{"model":"claude-3"}`, + system: claudeCodeSystemPrompt, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "array system", + body: `{"model":"claude-3"}`, + system: []any{map[string]any{"type": "text", "text": "Custom"}}, + // Claude Code + Custom = 2 + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom", + }, + { + name: "array system with existing Claude Code prompt (should dedupe)", + body: `{"model":"claude-3"}`, + system: []any{ + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + map[string]any{"type": "text", "text": "Other"}, + }, + // Claude Code at start + Other = 2 (deduped) + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Other", + }, + { + name: "empty array", + body: `{"model":"claude-3"}`, + system: []any{}, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := injectClaudeCodePrompt([]byte(tt.body), tt.system) + + var parsed map[string]any + err := json.Unmarshal(result, &parsed) + require.NoError(t, err) + + system, ok := parsed["system"].([]any) + require.True(t, ok, "system should be an array") + require.Len(t, system, tt.wantSystemLen) + + first, ok := system[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantFirstText, first["text"]) + require.Equal(t, "text", first["type"]) + + // Check cache_control + cc, ok := first["cache_control"].(map[string]any) + require.True(t, ok) + require.Equal(t, "ephemeral", cc["type"]) + + if tt.wantSecondText != "" && len(system) > 1 { + second, ok := system[1].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantSecondText, second["text"]) + } + }) + } +} From 600f9ce2548439fdf87cd8877db1386797fb32c9 Mon Sep 17 00:00:00 2001 From: Yuhao Jiang Date: Sat, 3 Jan 2026 21:04:34 -0600 Subject: [PATCH 3/6] =?UTF-8?q?fix(frontend):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E8=B7=A8=E6=97=B6=E5=8C=BA=E6=97=A5=E6=9C=9F=E8=8C=83=E5=9B=B4?= =?UTF-8?q?=E7=AD=9B=E9=80=89=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当管理员在比服务器时区更早的时区(如芝加哥 UTC-6)访问使用记录页面时, 由于服务器时区(如中国 UTC+8)已经是"明天",导致最新的记录无法显示。 修复方案: - DateRangePicker: 将日期选择器的 max 限制从"今天"改为"明天" - UsageView: 默认和重置时的 endDate 使用"明天"而非"今天" 这样可以确保跨时区场景下用户能看到所有最新记录。 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- frontend/src/components/common/DateRangePicker.vue | 12 ++++++++++-- frontend/src/views/admin/UsageView.vue | 12 +++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/frontend/src/components/common/DateRangePicker.vue b/frontend/src/components/common/DateRangePicker.vue index be641f9b..4fce029f 100644 --- a/frontend/src/components/common/DateRangePicker.vue +++ b/frontend/src/components/common/DateRangePicker.vue @@ -59,7 +59,7 @@ @@ -85,7 +85,7 @@ type="date" v-model="localEndDate" :min="localStartDate" - :max="today" + :max="tomorrow" class="date-picker-input" @change="onDateChange" /> @@ -144,6 +144,14 @@ const today = computed(() => { return `${year}-${month}-${day}` }) +// Tomorrow's date - used for max date to handle timezone differences +// When user is in a timezone behind the server, "today" on server might be "tomorrow" locally +const tomorrow = computed(() => { + const d = new Date() + d.setDate(d.getDate() + 1) + return formatDateToString(d) +}) + // Helper function to format date to YYYY-MM-DD using local timezone const formatDateToString = (date: Date): string => { const year = date.getFullYear() diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index 85a748f4..ac5d1e05 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -888,13 +888,17 @@ const formatLocalDate = (date: Date): string => { } // Initialize date range immediately +// Use tomorrow as end date to handle timezone differences between client and server +// e.g., when server is in Asia/Shanghai and client is in America/Chicago const now = new Date() +const tomorrow = new Date(now) +tomorrow.setDate(tomorrow.getDate() + 1) const weekAgo = new Date(now) weekAgo.setDate(weekAgo.getDate() - 6) // Date range state const startDate = ref(formatLocalDate(weekAgo)) -const endDate = ref(formatLocalDate(now)) +const endDate = ref(formatLocalDate(tomorrow)) const filters = ref({ user_id: undefined, @@ -1215,12 +1219,14 @@ const resetFilters = () => { end_date: undefined } granularity.value = 'day' - // Reset date range to default (last 7 days) + // Reset date range to default (last 7 days, with tomorrow as end to handle timezone differences) const now = new Date() + const tomorrowDate = new Date(now) + tomorrowDate.setDate(tomorrowDate.getDate() + 1) const weekAgo = new Date(now) weekAgo.setDate(weekAgo.getDate() - 6) startDate.value = formatLocalDate(weekAgo) - endDate.value = formatLocalDate(now) + endDate.value = formatLocalDate(tomorrowDate) filters.value.start_date = startDate.value filters.value.end_date = endDate.value pagination.value.page = 1 From 70e9329e64f3104755a645373d510e87a406fb04 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 4 Jan 2026 11:43:58 +0800 Subject: [PATCH 4/6] =?UTF-8?q?feat(proxy):=20=E7=BB=9F=E4=B8=80=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E9=85=8D=E7=BD=AE=E5=B9=B6=E6=94=AF=E6=8C=81=20SOCKS5?= =?UTF-8?q?H=20=E5=8D=8F=E8=AE=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 proxyutil 包,统一 HTTP/HTTPS/SOCKS5/SOCKS5H 代理配置逻辑 - SOCKS5H 支持服务端 DNS 解析,避免本地 DNS 泄露 - 移除 ProxyStrict 宽松模式,代理失败直接返回错误不回退直连 - 前端代理管理页面支持 SOCKS5H 协议的添加/编辑/批量导入 - 补充 IPv6 地址和特殊字符密码的边界测试 --- .../internal/handler/admin/proxy_handler.go | 6 +- backend/internal/pkg/httpclient/pool.go | 37 +--- backend/internal/pkg/proxyutil/dialer.go | 62 ++++++ backend/internal/pkg/proxyutil/dialer_test.go | 204 ++++++++++++++++++ backend/internal/repository/http_upstream.go | 17 +- .../http_upstream_benchmark_test.go | 6 +- .../repository/proxy_probe_service.go | 1 - frontend/src/types/index.ts | 2 +- frontend/src/views/admin/ProxiesView.vue | 12 +- 9 files changed, 303 insertions(+), 44 deletions(-) create mode 100644 backend/internal/pkg/proxyutil/dialer.go create mode 100644 backend/internal/pkg/proxyutil/dialer_test.go diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 0480b312..99557f9a 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -26,7 +26,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler { // CreateProxyRequest represents create proxy request type CreateProxyRequest struct { Name string `json:"name" binding:"required"` - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` Host string `json:"host" binding:"required"` Port int `json:"port" binding:"required,min=1,max=65535"` Username string `json:"username"` @@ -36,7 +36,7 @@ type CreateProxyRequest struct { // UpdateProxyRequest represents update proxy request type UpdateProxyRequest struct { Name string `json:"name"` - Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"` Host string `json:"host"` Port int `json:"port" binding:"omitempty,min=1,max=65535"` Username string `json:"username"` @@ -255,7 +255,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { // BatchCreateProxyItem represents a single proxy in batch create request type BatchCreateProxyItem struct { - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` Host string `json:"host" binding:"required"` Port int `json:"port" binding:"required,min=1,max=65535"` Username string `json:"username"` diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 1028fb84..8a81c09a 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -11,22 +11,20 @@ // 新实现使用统一的客户端池: // 1. 相同配置复用同一 http.Client 实例 // 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 -// 3. 支持 HTTP/HTTPS/SOCKS5 代理 -// 4. 支持严格代理模式(代理失败则返回错误) +// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理 +// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险) package httpclient import ( - "context" "crypto/tls" "fmt" - "net" "net/http" "net/url" "strings" "sync" "time" - "golang.org/x/net/proxy" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // Transport 连接池默认配置 @@ -38,11 +36,10 @@ const ( // Options 定义共享 HTTP 客户端的构建参数 type Options struct { - ProxyURL string // 代理 URL(支持 http/https/socks5) + ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h) Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证 - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 // 可选的连接池参数(不设置则使用默认值) MaxIdleConns int // 最大空闲连接总数(默认 100) @@ -55,6 +52,7 @@ var sharedClients sync.Map // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport +// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 func GetClient(opts Options) (*http.Client, error) { key := buildClientKey(opts) if cached, ok := sharedClients.Load(key); ok { @@ -65,12 +63,7 @@ func GetClient(opts Options) (*http.Client, error) { client, err := buildClient(opts) if err != nil { - if opts.ProxyStrict { - return nil, err - } - fallback := opts - fallback.ProxyURL = "" - client, _ = buildClient(fallback) + return nil, err } actual, _ := sharedClients.LoadOrStore(key, client) @@ -125,31 +118,19 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, err } - switch strings.ToLower(parsed.Scheme) { - case "http", "https": - transport.Proxy = http.ProxyURL(parsed) - case "socks5", "socks5h": - dialer, err := proxy.FromURL(parsed, proxy.Direct) - if err != nil { - return nil, err - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme) + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, err } return transport, nil } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.MaxIdleConns, opts.MaxIdleConnsPerHost, opts.MaxConnsPerHost, diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go new file mode 100644 index 00000000..91b224a2 --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -0,0 +1,62 @@ +// Package proxyutil 提供统一的代理配置功能 +// +// 支持的代理协议: +// - HTTP/HTTPS: 通过 Transport.Proxy 设置 +// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +package proxyutil + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// ConfigureTransportProxy 根据代理 URL 配置 Transport +// +// 支持的协议: +// - http/https: 设置 transport.Proxy +// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// +// 参数: +// - transport: 需要配置的 http.Transport +// - proxyURL: 代理地址,nil 表示直连 +// +// 返回: +// - error: 代理配置错误(协议不支持或 dialer 创建失败) +func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error { + if proxyURL == nil { + return nil + } + + scheme := strings.ToLower(proxyURL.Scheme) + switch scheme { + case "http", "https": + transport.Proxy = http.ProxyURL(proxyURL) + return nil + + case "socks5", "socks5h": + dialer, err := proxy.FromURL(proxyURL, proxy.Direct) + if err != nil { + return fmt.Errorf("create socks5 dialer: %w", err) + } + // 优先使用支持 context 的 DialContext,以支持请求取消和超时 + if contextDialer, ok := dialer.(proxy.ContextDialer); ok { + transport.DialContext = contextDialer.DialContext + } else { + // 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext + // 注意:此回退不支持请求取消和超时控制 + transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + } + return nil + + default: + return fmt.Errorf("unsupported proxy scheme: %s", scheme) + } +} diff --git a/backend/internal/pkg/proxyutil/dialer_test.go b/backend/internal/pkg/proxyutil/dialer_test.go new file mode 100644 index 00000000..f153cc9f --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer_test.go @@ -0,0 +1,204 @@ +package proxyutil + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigureTransportProxy_Nil(t *testing.T) { + transport := &http.Transport{} + err := ConfigureTransportProxy(transport, nil) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy") + assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTP(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("http://proxy.example.com:8080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTPS(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5H(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5h://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext") +} + +func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) { + testCases := []struct { + scheme string + useProxy bool // true = uses Transport.Proxy, false = uses DialContext + }{ + {"HTTP://proxy.example.com:8080", true}, + {"Http://proxy.example.com:8080", true}, + {"HTTPS://proxy.example.com:8443", true}, + {"Https://proxy.example.com:8443", true}, + {"SOCKS5://socks.example.com:1080", false}, + {"Socks5://socks.example.com:1080", false}, + {"SOCKS5H://socks.example.com:1080", false}, + {"Socks5h://socks.example.com:1080", false}, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc.scheme) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + if tc.useProxy { + assert.NotNil(t, transport.Proxy) + assert.Nil(t, transport.DialContext) + } else { + assert.Nil(t, transport.Proxy) + assert.NotNil(t, transport.DialContext) + } + }) + } +} + +func TestConfigureTransportProxy_Unsupported(t *testing.T) { + testCases := []string{ + "ftp://ftp.example.com", + "file:///path/to/file", + "unknown://example.com", + } + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") + }) + } +} + +func TestConfigureTransportProxy_WithAuth(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext") +} + +func TestConfigureTransportProxy_EmptyScheme(t *testing.T) { + transport := &http.Transport{} + // 空 scheme 的 URL + proxyURL := &url.URL{Host: "proxy.example.com:8080"} + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") +} + +func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) { + // 验证代理配置不会覆盖 Transport 的其他配置 + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + } + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved") + assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved") + assert.NotNil(t, transport.DialContext, "DialContext should be set") +} + +func TestConfigureTransportProxy_IPv6(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + {"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"}, + {"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"}, + {"HTTP with IPv6", "http://[::1]:8080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + }) + } +} + +func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + // 密码包含 @ 符号(URL 编码为 %40) + {"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"}, + // 密码包含 : 符号(URL 编码为 %3A) + {"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"}, + // 密码包含 / 符号(URL 编码为 %2F) + {"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"}, + // 复杂密码 + {"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext") + }) + } +} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 180844b5..f0669979 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -225,7 +226,12 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 缓存未命中或需要重建,创建新客户端 settings := s.resolvePoolSettings(isolation, accountConcurrency) - client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)} + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("build transport: %w", err) + } + client := &http.Client{Transport: transport} entry := &upstreamClientEntry{ client: client, proxyKey: proxyKey, @@ -548,6 +554,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // // 返回: // - *http.Transport: 配置好的 Transport 实例 +// - error: 代理配置错误 // // Transport 参数说明: // - MaxIdleConns: 所有主机的最大空闲连接总数 @@ -555,7 +562,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) // - IdleConnTimeout: 空闲连接超时(超时后关闭) // - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport { +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) { transport := &http.Transport{ MaxIdleConns: settings.maxIdleConns, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, @@ -563,10 +570,10 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran IdleConnTimeout: settings.idleConnTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout, } - if proxyURL != nil { - transport.Proxy = http.ProxyURL(proxyURL) + if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { + return nil, err } - return transport + return transport, nil } // trackedBody 带跟踪功能的响应体包装器 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 3219c6da..1e7430a3 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -45,8 +45,12 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { // 每次迭代都创建新客户端,包含 Transport 分配 + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + b.Fatalf("创建 Transport 失败: %v", err) + } httpClientSink = &http.Client{ - Transport: buildUpstreamTransport(settings, parsedProxy), + Transport: transport, } } }) diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 8b288c3c..f5f625f9 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -27,7 +27,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: 15 * time.Second, InsecureSkipVerify: true, - ProxyStrict: true, }) if err != nil { return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index a3e4c25e..2f6aa2c8 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -290,7 +290,7 @@ export interface UpdateGroupRequest { export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' export type AccountType = 'oauth' | 'setup-token' | 'apikey' export type OAuthAddMethod = 'oauth' | 'setup-token' -export type ProxyProtocol = 'http' | 'https' | 'socks5' +export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' // Claude Model type (returned by /v1/models and account models API) export interface ClaudeModel { diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index a5df9bd0..613b503c 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -90,7 +90,7 @@