merge: 合并官方 upstream/main 的 6 个功能更新
合并内容: 1. feat(gateway): Claude Code 系统提示词智能注入 2. fix: 修复创建账号 schedulable 默认值为 false 的 bug 3. fix(frontend): 修复跨时区日期范围筛选问题 4. feat(proxy): SOCKS5H 代理支持(统一代理配置) 5. fix(oauth): 修复 Claude Cookie 添加账号时会话混淆 6. fix(test): 修复 OAuth 账号测试刷新 token 的 bug 新增文件: - backend/internal/pkg/proxyutil/* (SOCKS5H 支持) - backend/internal/service/gateway_prompt_test.go (测试) 来自 upstream: Wei-Shaw/sub2api commits d9b1587..a527559
This commit is contained in:
@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
|
|||||||
// CreateProxyRequest represents create proxy request
|
// CreateProxyRequest represents create proxy request
|
||||||
type CreateProxyRequest struct {
|
type CreateProxyRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||||
Host string `json:"host" binding:"required"`
|
Host string `json:"host" binding:"required"`
|
||||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
@@ -36,7 +36,7 @@ type CreateProxyRequest struct {
|
|||||||
// UpdateProxyRequest represents update proxy request
|
// UpdateProxyRequest represents update proxy request
|
||||||
type UpdateProxyRequest struct {
|
type UpdateProxyRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
|
||||||
Host string `json:"host"`
|
Host string `json:"host"`
|
||||||
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
@@ -255,7 +255,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
|||||||
|
|
||||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||||
type BatchCreateProxyItem struct {
|
type BatchCreateProxyItem struct {
|
||||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||||
Host string `json:"host" binding:"required"`
|
Host string `json:"host" binding:"required"`
|
||||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
|||||||
@@ -11,22 +11,20 @@
|
|||||||
// 新实现使用统一的客户端池:
|
// 新实现使用统一的客户端池:
|
||||||
// 1. 相同配置复用同一 http.Client 实例
|
// 1. 相同配置复用同一 http.Client 实例
|
||||||
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
|
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
|
||||||
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
|
// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理
|
||||||
// 4. 支持严格代理模式(代理失败则返回错误)
|
// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险)
|
||||||
package httpclient
|
package httpclient
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Transport 连接池默认配置
|
// Transport 连接池默认配置
|
||||||
@@ -38,11 +36,10 @@ const (
|
|||||||
|
|
||||||
// Options 定义共享 HTTP 客户端的构建参数
|
// Options 定义共享 HTTP 客户端的构建参数
|
||||||
type Options struct {
|
type Options struct {
|
||||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
|
||||||
Timeout time.Duration // 请求总超时时间
|
Timeout time.Duration // 请求总超时时间
|
||||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
|
||||||
|
|
||||||
// 可选的连接池参数(不设置则使用默认值)
|
// 可选的连接池参数(不设置则使用默认值)
|
||||||
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
||||||
@@ -55,6 +52,7 @@ var sharedClients sync.Map
|
|||||||
|
|
||||||
// GetClient 返回共享的 HTTP 客户端实例
|
// GetClient 返回共享的 HTTP 客户端实例
|
||||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||||
|
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
|
||||||
func GetClient(opts Options) (*http.Client, error) {
|
func GetClient(opts Options) (*http.Client, error) {
|
||||||
key := buildClientKey(opts)
|
key := buildClientKey(opts)
|
||||||
if cached, ok := sharedClients.Load(key); ok {
|
if cached, ok := sharedClients.Load(key); ok {
|
||||||
@@ -65,12 +63,7 @@ func GetClient(opts Options) (*http.Client, error) {
|
|||||||
|
|
||||||
client, err := buildClient(opts)
|
client, err := buildClient(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if opts.ProxyStrict {
|
return nil, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
fallback := opts
|
|
||||||
fallback.ProxyURL = ""
|
|
||||||
client, _ = buildClient(fallback)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
actual, _ := sharedClients.LoadOrStore(key, client)
|
actual, _ := sharedClients.LoadOrStore(key, client)
|
||||||
@@ -125,31 +118,19 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(parsed.Scheme) {
|
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||||
case "http", "https":
|
return nil, err
|
||||||
transport.Proxy = http.ProxyURL(parsed)
|
|
||||||
case "socks5", "socks5h":
|
|
||||||
dialer, err := proxy.FromURL(parsed, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return transport, nil
|
return transport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildClientKey(opts Options) string {
|
func buildClientKey(opts Options) string {
|
||||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
|
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
|
||||||
strings.TrimSpace(opts.ProxyURL),
|
strings.TrimSpace(opts.ProxyURL),
|
||||||
opts.Timeout.String(),
|
opts.Timeout.String(),
|
||||||
opts.ResponseHeaderTimeout.String(),
|
opts.ResponseHeaderTimeout.String(),
|
||||||
opts.InsecureSkipVerify,
|
opts.InsecureSkipVerify,
|
||||||
opts.ProxyStrict,
|
|
||||||
opts.MaxIdleConns,
|
opts.MaxIdleConns,
|
||||||
opts.MaxIdleConnsPerHost,
|
opts.MaxIdleConnsPerHost,
|
||||||
opts.MaxConnsPerHost,
|
opts.MaxConnsPerHost,
|
||||||
|
|||||||
62
backend/internal/pkg/proxyutil/dialer.go
Normal file
62
backend/internal/pkg/proxyutil/dialer.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
// Package proxyutil 提供统一的代理配置功能
|
||||||
|
//
|
||||||
|
// 支持的代理协议:
|
||||||
|
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
|
||||||
|
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
|
||||||
|
package proxyutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigureTransportProxy 根据代理 URL 配置 Transport
|
||||||
|
//
|
||||||
|
// 支持的协议:
|
||||||
|
// - http/https: 设置 transport.Proxy
|
||||||
|
// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - transport: 需要配置的 http.Transport
|
||||||
|
// - proxyURL: 代理地址,nil 表示直连
|
||||||
|
//
|
||||||
|
// 返回:
|
||||||
|
// - error: 代理配置错误(协议不支持或 dialer 创建失败)
|
||||||
|
func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error {
|
||||||
|
if proxyURL == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
scheme := strings.ToLower(proxyURL.Scheme)
|
||||||
|
switch scheme {
|
||||||
|
case "http", "https":
|
||||||
|
transport.Proxy = http.ProxyURL(proxyURL)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case "socks5", "socks5h":
|
||||||
|
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create socks5 dialer: %w", err)
|
||||||
|
}
|
||||||
|
// 优先使用支持 context 的 DialContext,以支持请求取消和超时
|
||||||
|
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||||
|
transport.DialContext = contextDialer.DialContext
|
||||||
|
} else {
|
||||||
|
// 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext
|
||||||
|
// 注意:此回退不支持请求取消和超时控制
|
||||||
|
transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(network, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported proxy scheme: %s", scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
204
backend/internal/pkg/proxyutil/dialer_test.go
Normal file
204
backend/internal/pkg/proxyutil/dialer_test.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package proxyutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_Nil(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
err := ConfigureTransportProxy(transport, nil)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy")
|
||||||
|
assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_HTTP(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy")
|
||||||
|
assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_HTTPS(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy")
|
||||||
|
assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_SOCKS5(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy")
|
||||||
|
assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_SOCKS5H(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse("socks5h://socks.example.com:1080")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy")
|
||||||
|
assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
scheme string
|
||||||
|
useProxy bool // true = uses Transport.Proxy, false = uses DialContext
|
||||||
|
}{
|
||||||
|
{"HTTP://proxy.example.com:8080", true},
|
||||||
|
{"Http://proxy.example.com:8080", true},
|
||||||
|
{"HTTPS://proxy.example.com:8443", true},
|
||||||
|
{"Https://proxy.example.com:8443", true},
|
||||||
|
{"SOCKS5://socks.example.com:1080", false},
|
||||||
|
{"Socks5://socks.example.com:1080", false},
|
||||||
|
{"SOCKS5H://socks.example.com:1080", false},
|
||||||
|
{"Socks5h://socks.example.com:1080", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.scheme, func(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse(tc.scheme)
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
if tc.useProxy {
|
||||||
|
assert.NotNil(t, transport.Proxy)
|
||||||
|
assert.Nil(t, transport.DialContext)
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, transport.Proxy)
|
||||||
|
assert.NotNil(t, transport.DialContext)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_Unsupported(t *testing.T) {
|
||||||
|
testCases := []string{
|
||||||
|
"ftp://ftp.example.com",
|
||||||
|
"file:///path/to/file",
|
||||||
|
"unknown://example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc, func(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse(tc)
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unsupported proxy scheme")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_WithAuth(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_EmptyScheme(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
// 空 scheme 的 URL
|
||||||
|
proxyURL := &url.URL{Host: "proxy.example.com:8080"}
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unsupported proxy scheme")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) {
|
||||||
|
// 验证代理配置不会覆盖 Transport 的其他配置
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
}
|
||||||
|
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
|
||||||
|
|
||||||
|
err := ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved")
|
||||||
|
assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved")
|
||||||
|
assert.NotNil(t, transport.DialContext, "DialContext should be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_IPv6(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
proxyURL string
|
||||||
|
}{
|
||||||
|
{"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"},
|
||||||
|
{"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"},
|
||||||
|
{"HTTP with IPv6", "http://[::1]:8080"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, err := url.Parse(tc.proxyURL)
|
||||||
|
require.NoError(t, err, "URL should be parseable")
|
||||||
|
|
||||||
|
err = ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
proxyURL string
|
||||||
|
}{
|
||||||
|
// 密码包含 @ 符号(URL 编码为 %40)
|
||||||
|
{"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"},
|
||||||
|
// 密码包含 : 符号(URL 编码为 %3A)
|
||||||
|
{"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"},
|
||||||
|
// 密码包含 / 符号(URL 编码为 %2F)
|
||||||
|
{"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"},
|
||||||
|
// 复杂密码
|
||||||
|
{"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
proxyURL, err := url.Parse(tc.proxyURL)
|
||||||
|
require.NoError(t, err, "URL should be parseable")
|
||||||
|
|
||||||
|
err = ConfigureTransportProxy(transport, proxyURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -233,11 +233,17 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createReqClient(proxyURL string) *req.Client {
|
func createReqClient(proxyURL string) *req.Client {
|
||||||
return getSharedReqClient(reqClientOptions{
|
// 禁用 CookieJar,确保每次授权都是干净的会话
|
||||||
ProxyURL: proxyURL,
|
client := req.C().
|
||||||
Timeout: 60 * time.Second,
|
SetTimeout(60 * time.Second).
|
||||||
Impersonate: true,
|
ImpersonateChrome().
|
||||||
})
|
SetCookieJar(nil) // 禁用 CookieJar
|
||||||
|
|
||||||
|
if strings.TrimSpace(proxyURL) != "" {
|
||||||
|
client.SetProxyURL(strings.TrimSpace(proxyURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefix(s string, n int) string {
|
func prefix(s string, n int) string {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,7 +226,12 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
|||||||
|
|
||||||
// 缓存未命中或需要重建,创建新客户端
|
// 缓存未命中或需要重建,创建新客户端
|
||||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||||
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
|
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("build transport: %w", err)
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
entry := &upstreamClientEntry{
|
entry := &upstreamClientEntry{
|
||||||
client: client,
|
client: client,
|
||||||
proxyKey: proxyKey,
|
proxyKey: proxyKey,
|
||||||
@@ -548,6 +554,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
|||||||
//
|
//
|
||||||
// 返回:
|
// 返回:
|
||||||
// - *http.Transport: 配置好的 Transport 实例
|
// - *http.Transport: 配置好的 Transport 实例
|
||||||
|
// - error: 代理配置错误
|
||||||
//
|
//
|
||||||
// Transport 参数说明:
|
// Transport 参数说明:
|
||||||
// - MaxIdleConns: 所有主机的最大空闲连接总数
|
// - MaxIdleConns: 所有主机的最大空闲连接总数
|
||||||
@@ -555,7 +562,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
|||||||
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
||||||
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
||||||
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
||||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
|
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
MaxIdleConns: settings.maxIdleConns,
|
MaxIdleConns: settings.maxIdleConns,
|
||||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||||
@@ -563,10 +570,10 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran
|
|||||||
IdleConnTimeout: settings.idleConnTimeout,
|
IdleConnTimeout: settings.idleConnTimeout,
|
||||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||||
transport.Proxy = http.ProxyURL(proxyURL)
|
return nil, err
|
||||||
}
|
}
|
||||||
return transport
|
return transport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// trackedBody 带跟踪功能的响应体包装器
|
// trackedBody 带跟踪功能的响应体包装器
|
||||||
|
|||||||
@@ -45,8 +45,12 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
|
|||||||
settings := defaultPoolSettings(cfg)
|
settings := defaultPoolSettings(cfg)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
// 每次迭代都创建新客户端,包含 Transport 分配
|
// 每次迭代都创建新客户端,包含 Transport 分配
|
||||||
|
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("创建 Transport 失败: %v", err)
|
||||||
|
}
|
||||||
httpClientSink = &http.Client{
|
httpClientSink = &http.Client{
|
||||||
Transport: buildUpstreamTransport(settings, parsedProxy),
|
Transport: transport,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
Timeout: 15 * time.Second,
|
Timeout: 15 * time.Second,
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ProxyStrict: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
@@ -28,7 +27,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
|
||||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,8 +42,6 @@ type TestEvent struct {
|
|||||||
// AccountTestService handles account testing operations
|
// AccountTestService handles account testing operations
|
||||||
type AccountTestService struct {
|
type AccountTestService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
oauthService *OAuthService
|
|
||||||
openaiOAuthService *OpenAIOAuthService
|
|
||||||
geminiTokenProvider *GeminiTokenProvider
|
geminiTokenProvider *GeminiTokenProvider
|
||||||
antigravityGatewayService *AntigravityGatewayService
|
antigravityGatewayService *AntigravityGatewayService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
@@ -54,16 +50,12 @@ type AccountTestService struct {
|
|||||||
// NewAccountTestService creates a new AccountTestService
|
// NewAccountTestService creates a new AccountTestService
|
||||||
func NewAccountTestService(
|
func NewAccountTestService(
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
oauthService *OAuthService,
|
|
||||||
openaiOAuthService *OpenAIOAuthService,
|
|
||||||
geminiTokenProvider *GeminiTokenProvider,
|
geminiTokenProvider *GeminiTokenProvider,
|
||||||
antigravityGatewayService *AntigravityGatewayService,
|
antigravityGatewayService *AntigravityGatewayService,
|
||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
) *AccountTestService {
|
) *AccountTestService {
|
||||||
return &AccountTestService{
|
return &AccountTestService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
oauthService: oauthService,
|
|
||||||
openaiOAuthService: openaiOAuthService,
|
|
||||||
geminiTokenProvider: geminiTokenProvider,
|
geminiTokenProvider: geminiTokenProvider,
|
||||||
antigravityGatewayService: antigravityGatewayService,
|
antigravityGatewayService: antigravityGatewayService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
@@ -183,22 +175,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
if authToken == "" {
|
if authToken == "" {
|
||||||
return s.sendErrorAndEnd(c, "No access token available")
|
return s.sendErrorAndEnd(c, "No access token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if token needs refresh
|
|
||||||
needRefresh := false
|
|
||||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
|
||||||
if time.Now().Add(5 * time.Minute).After(*expiresAt) {
|
|
||||||
needRefresh = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needRefresh && s.oauthService != nil {
|
|
||||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
|
||||||
}
|
|
||||||
authToken = tokenInfo.AccessToken
|
|
||||||
}
|
|
||||||
} else if account.Type == "apikey" {
|
} else if account.Type == "apikey" {
|
||||||
// API Key - use x-api-key header
|
// API Key - use x-api-key header
|
||||||
useBearer = false
|
useBearer = false
|
||||||
@@ -310,15 +286,6 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
|||||||
return s.sendErrorAndEnd(c, "No access token available")
|
return s.sendErrorAndEnd(c, "No access token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if token is expired and refresh if needed
|
|
||||||
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
|
||||||
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
|
||||||
}
|
|
||||||
authToken = tokenInfo.AccessToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// OAuth uses ChatGPT internal API
|
// OAuth uses ChatGPT internal API
|
||||||
apiURL = chatgptCodexAPIURL
|
apiURL = chatgptCodexAPIURL
|
||||||
chatgptAccountID = account.GetChatGPTAccountID()
|
chatgptAccountID = account.GetChatGPTAccountID()
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ type antigravityUsageCache struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiCacheTTL = 10 * time.Minute
|
apiCacheTTL = 3 * time.Minute
|
||||||
windowStatsCacheTTL = 1 * time.Minute
|
windowStatsCacheTTL = 1 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -630,6 +630,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
Concurrency: input.Concurrency,
|
Concurrency: input.Concurrency,
|
||||||
Priority: input.Priority,
|
Priority: input.Priority,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
}
|
}
|
||||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
233
backend/internal/service/gateway_prompt_test.go
Normal file
233
backend/internal/service/gateway_prompt_test.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsClaudeCodeClient(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userAgent string
|
||||||
|
metadataUserID string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Claude Code client",
|
||||||
|
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
|
||||||
|
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Claude Code without version suffix",
|
||||||
|
userAgent: "claude-cli/2.0.0",
|
||||||
|
metadataUserID: "session_abc",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing metadata user_id",
|
||||||
|
userAgent: "claude-cli/1.0.0",
|
||||||
|
metadataUserID: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Different user agent",
|
||||||
|
userAgent: "curl/7.68.0",
|
||||||
|
metadataUserID: "user123",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty user agent",
|
||||||
|
userAgent: "",
|
||||||
|
metadataUserID: "user123",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Similar but not Claude CLI",
|
||||||
|
userAgent: "claude-api/1.0.0",
|
||||||
|
metadataUserID: "user123",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
system any
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil system",
|
||||||
|
system: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
system: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with Claude Code prompt",
|
||||||
|
system: claudeCodeSystemPrompt,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with different content",
|
||||||
|
system: "You are a helpful assistant.",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty array",
|
||||||
|
system: []any{},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array with Claude Code prompt",
|
||||||
|
system: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": claudeCodeSystemPrompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array with Claude Code prompt in second position",
|
||||||
|
system: []any{
|
||||||
|
map[string]any{"type": "text", "text": "First prompt"},
|
||||||
|
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array without Claude Code prompt",
|
||||||
|
system: []any{
|
||||||
|
map[string]any{"type": "text", "text": "Custom prompt"},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array with partial match (should not match)",
|
||||||
|
system: []any{
|
||||||
|
map[string]any{"type": "text", "text": "You are Claude"},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := systemIncludesClaudeCodePrompt(tt.system)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
system any
|
||||||
|
wantSystemLen int
|
||||||
|
wantFirstText string
|
||||||
|
wantSecondText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil system",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: nil,
|
||||||
|
wantSystemLen: 1,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string system",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: "",
|
||||||
|
wantSystemLen: 1,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string system",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: "Custom prompt",
|
||||||
|
wantSystemLen: 2,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
wantSecondText: "Custom prompt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string system equals Claude Code prompt",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: claudeCodeSystemPrompt,
|
||||||
|
wantSystemLen: 1,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array system",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: []any{map[string]any{"type": "text", "text": "Custom"}},
|
||||||
|
// Claude Code + Custom = 2
|
||||||
|
wantSystemLen: 2,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
wantSecondText: "Custom",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: []any{
|
||||||
|
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||||
|
map[string]any{"type": "text", "text": "Other"},
|
||||||
|
},
|
||||||
|
// Claude Code at start + Other = 2 (deduped)
|
||||||
|
wantSystemLen: 2,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
wantSecondText: "Other",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty array",
|
||||||
|
body: `{"model":"claude-3"}`,
|
||||||
|
system: []any{},
|
||||||
|
wantSystemLen: 1,
|
||||||
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := injectClaudeCodePrompt([]byte(tt.body), tt.system)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(result, &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
system, ok := parsed["system"].([]any)
|
||||||
|
require.True(t, ok, "system should be an array")
|
||||||
|
require.Len(t, system, tt.wantSystemLen)
|
||||||
|
|
||||||
|
first, ok := system[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, tt.wantFirstText, first["text"])
|
||||||
|
require.Equal(t, "text", first["type"])
|
||||||
|
|
||||||
|
// Check cache_control
|
||||||
|
cc, ok := first["cache_control"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ephemeral", cc["type"])
|
||||||
|
|
||||||
|
if tt.wantSecondText != "" && len(system) > 1 {
|
||||||
|
second, ok := system[1].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, tt.wantSecondText, second["text"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -30,13 +30,15 @@ const (
|
|||||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
|
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||||
)
|
)
|
||||||
|
|
||||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var (
|
var (
|
||||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||||
|
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// allowedHeaders 白名单headers(参考CRS项目)
|
// allowedHeaders 白名单headers(参考CRS项目)
|
||||||
@@ -951,6 +953,76 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
|
||||||
|
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
|
||||||
|
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||||
|
if metadataUserID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||||
|
// 支持 string 和 []any 两种格式
|
||||||
|
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||||
|
switch v := system.(type) {
|
||||||
|
case string:
|
||||||
|
return v == claudeCodeSystemPrompt
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||||
|
// 处理 null、字符串、数组三种格式
|
||||||
|
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||||
|
claudeCodeBlock := map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": claudeCodeSystemPrompt,
|
||||||
|
"cache_control": map[string]string{"type": "ephemeral"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var newSystem []any
|
||||||
|
|
||||||
|
switch v := system.(type) {
|
||||||
|
case nil:
|
||||||
|
newSystem = []any{claudeCodeBlock}
|
||||||
|
case string:
|
||||||
|
if v == "" || v == claudeCodeSystemPrompt {
|
||||||
|
newSystem = []any{claudeCodeBlock}
|
||||||
|
} else {
|
||||||
|
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
newSystem = make([]any, 0, len(v)+1)
|
||||||
|
newSystem = append(newSystem, claudeCodeBlock)
|
||||||
|
for _, item := range v {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newSystem = append(newSystem, item)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
newSystem = []any{claudeCodeBlock}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := sjson.SetBytes(body, "system", newSystem)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: failed to inject Claude Code prompt: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -962,16 +1034,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
reqStream := parsed.Stream
|
reqStream := parsed.Stream
|
||||||
|
|
||||||
if !parsed.HasSystem {
|
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||||
body, _ = sjson.SetBytes(body, "system", []any{
|
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||||
map[string]any{
|
if account.IsOAuth() &&
|
||||||
"type": "text",
|
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||||
"cache_control": map[string]string{
|
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||||
"type": "ephemeral",
|
body = injectClaudeCodePrompt(body, parsed.System)
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用模型映射(仅对apikey类型账号)
|
// 应用模型映射(仅对apikey类型账号)
|
||||||
|
|||||||
@@ -59,7 +59,7 @@
|
|||||||
<input
|
<input
|
||||||
type="date"
|
type="date"
|
||||||
v-model="localStartDate"
|
v-model="localStartDate"
|
||||||
:max="localEndDate || today"
|
:max="localEndDate || tomorrow"
|
||||||
class="date-picker-input"
|
class="date-picker-input"
|
||||||
@change="onDateChange"
|
@change="onDateChange"
|
||||||
/>
|
/>
|
||||||
@@ -85,7 +85,7 @@
|
|||||||
type="date"
|
type="date"
|
||||||
v-model="localEndDate"
|
v-model="localEndDate"
|
||||||
:min="localStartDate"
|
:min="localStartDate"
|
||||||
:max="today"
|
:max="tomorrow"
|
||||||
class="date-picker-input"
|
class="date-picker-input"
|
||||||
@change="onDateChange"
|
@change="onDateChange"
|
||||||
/>
|
/>
|
||||||
@@ -144,6 +144,14 @@ const today = computed(() => {
|
|||||||
return `${year}-${month}-${day}`
|
return `${year}-${month}-${day}`
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Tomorrow's date - used for max date to handle timezone differences
|
||||||
|
// When user is in a timezone behind the server, "today" on server might be "tomorrow" locally
|
||||||
|
const tomorrow = computed(() => {
|
||||||
|
const d = new Date()
|
||||||
|
d.setDate(d.getDate() + 1)
|
||||||
|
return formatDateToString(d)
|
||||||
|
})
|
||||||
|
|
||||||
// Helper function to format date to YYYY-MM-DD using local timezone
|
// Helper function to format date to YYYY-MM-DD using local timezone
|
||||||
const formatDateToString = (date: Date): string => {
|
const formatDateToString = (date: Date): string => {
|
||||||
const year = date.getFullYear()
|
const year = date.getFullYear()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Core Type Definitions for TianShuAPI Frontend
|
* Core Type Definitions for Sub2API Frontend
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// ==================== User & Auth Types ====================
|
// ==================== User & Auth Types ====================
|
||||||
@@ -290,7 +290,7 @@ export interface UpdateGroupRequest {
|
|||||||
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
|
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
|
||||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey'
|
export type AccountType = 'oauth' | 'setup-token' | 'apikey'
|
||||||
export type OAuthAddMethod = 'oauth' | 'setup-token'
|
export type OAuthAddMethod = 'oauth' | 'setup-token'
|
||||||
export type ProxyProtocol = 'http' | 'https' | 'socks5'
|
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
|
||||||
|
|
||||||
// Claude Model type (returned by /v1/models and account models API)
|
// Claude Model type (returned by /v1/models and account models API)
|
||||||
export interface ClaudeModel {
|
export interface ClaudeModel {
|
||||||
|
|||||||
@@ -90,7 +90,7 @@
|
|||||||
<template #cell-protocol="{ value }">
|
<template #cell-protocol="{ value }">
|
||||||
<span
|
<span
|
||||||
v-if="value"
|
v-if="value"
|
||||||
:class="['badge', value === 'socks5' ? 'badge-primary' : 'badge-gray']"
|
:class="['badge', value.startsWith('socks5') ? 'badge-primary' : 'badge-gray']"
|
||||||
>
|
>
|
||||||
{{ value.toUpperCase() }}
|
{{ value.toUpperCase() }}
|
||||||
</span>
|
</span>
|
||||||
@@ -628,7 +628,8 @@ const protocolOptions = computed(() => [
|
|||||||
{ value: '', label: t('admin.proxies.allProtocols') },
|
{ value: '', label: t('admin.proxies.allProtocols') },
|
||||||
{ value: 'http', label: 'HTTP' },
|
{ value: 'http', label: 'HTTP' },
|
||||||
{ value: 'https', label: 'HTTPS' },
|
{ value: 'https', label: 'HTTPS' },
|
||||||
{ value: 'socks5', label: 'SOCKS5' }
|
{ value: 'socks5', label: 'SOCKS5' },
|
||||||
|
{ value: 'socks5h', label: 'SOCKS5H' }
|
||||||
])
|
])
|
||||||
|
|
||||||
const statusOptions = computed(() => [
|
const statusOptions = computed(() => [
|
||||||
@@ -641,7 +642,8 @@ const statusOptions = computed(() => [
|
|||||||
const protocolSelectOptions = [
|
const protocolSelectOptions = [
|
||||||
{ value: 'http', label: 'HTTP' },
|
{ value: 'http', label: 'HTTP' },
|
||||||
{ value: 'https', label: 'HTTPS' },
|
{ value: 'https', label: 'HTTPS' },
|
||||||
{ value: 'socks5', label: 'SOCKS5' }
|
{ value: 'socks5', label: 'SOCKS5' },
|
||||||
|
{ value: 'socks5h', label: 'SOCKS5H (服务端解析DNS)' }
|
||||||
]
|
]
|
||||||
|
|
||||||
const editStatusOptions = computed(() => [
|
const editStatusOptions = computed(() => [
|
||||||
@@ -798,8 +800,8 @@ const parseProxyUrl = (
|
|||||||
const trimmed = line.trim()
|
const trimmed = line.trim()
|
||||||
if (!trimmed) return null
|
if (!trimmed) return null
|
||||||
|
|
||||||
// Regex to parse proxy URL
|
// Regex to parse proxy URL (supports http, https, socks5, socks5h)
|
||||||
const regex = /^(https?|socks5):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
|
const regex = /^(https?|socks5h?):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
|
||||||
const match = trimmed.match(regex)
|
const match = trimmed.match(regex)
|
||||||
|
|
||||||
if (!match) return null
|
if (!match) return null
|
||||||
|
|||||||
@@ -888,13 +888,17 @@ const formatLocalDate = (date: Date): string => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize date range immediately
|
// Initialize date range immediately
|
||||||
|
// Use tomorrow as end date to handle timezone differences between client and server
|
||||||
|
// e.g., when server is in Asia/Shanghai and client is in America/Chicago
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
|
const tomorrow = new Date(now)
|
||||||
|
tomorrow.setDate(tomorrow.getDate() + 1)
|
||||||
const weekAgo = new Date(now)
|
const weekAgo = new Date(now)
|
||||||
weekAgo.setDate(weekAgo.getDate() - 6)
|
weekAgo.setDate(weekAgo.getDate() - 6)
|
||||||
|
|
||||||
// Date range state
|
// Date range state
|
||||||
const startDate = ref(formatLocalDate(weekAgo))
|
const startDate = ref(formatLocalDate(weekAgo))
|
||||||
const endDate = ref(formatLocalDate(now))
|
const endDate = ref(formatLocalDate(tomorrow))
|
||||||
|
|
||||||
const filters = ref<AdminUsageQueryParams>({
|
const filters = ref<AdminUsageQueryParams>({
|
||||||
user_id: undefined,
|
user_id: undefined,
|
||||||
@@ -1215,12 +1219,14 @@ const resetFilters = () => {
|
|||||||
end_date: undefined
|
end_date: undefined
|
||||||
}
|
}
|
||||||
granularity.value = 'day'
|
granularity.value = 'day'
|
||||||
// Reset date range to default (last 7 days)
|
// Reset date range to default (last 7 days, with tomorrow as end to handle timezone differences)
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
|
const tomorrowDate = new Date(now)
|
||||||
|
tomorrowDate.setDate(tomorrowDate.getDate() + 1)
|
||||||
const weekAgo = new Date(now)
|
const weekAgo = new Date(now)
|
||||||
weekAgo.setDate(weekAgo.getDate() - 6)
|
weekAgo.setDate(weekAgo.getDate() - 6)
|
||||||
startDate.value = formatLocalDate(weekAgo)
|
startDate.value = formatLocalDate(weekAgo)
|
||||||
endDate.value = formatLocalDate(now)
|
endDate.value = formatLocalDate(tomorrowDate)
|
||||||
filters.value.start_date = startDate.value
|
filters.value.start_date = startDate.value
|
||||||
filters.value.end_date = endDate.value
|
filters.value.end_date = endDate.value
|
||||||
pagination.value.page = 1
|
pagination.value.page = 1
|
||||||
|
|||||||
Reference in New Issue
Block a user