diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 75e84e13..20e33317 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- usageService := service.NewUsageService(usageLogRepository, userRepository, client)
+ usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
billingCache := repository.NewBillingCache(redisClient)
@@ -88,8 +88,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
- tempUnschedCache := repository.NewTempUnschedCache(redisClient)
- rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
+ rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
@@ -100,8 +99,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
- antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
- accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
+ antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
+ accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index 0480b312..99557f9a 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -26,7 +26,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
// CreateProxyRequest represents create proxy request
type CreateProxyRequest struct {
Name string `json:"name" binding:"required"`
- Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
+ Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
@@ -36,7 +36,7 @@ type CreateProxyRequest struct {
// UpdateProxyRequest represents update proxy request
type UpdateProxyRequest struct {
Name string `json:"name"`
- Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
+ Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
Host string `json:"host"`
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
Username string `json:"username"`
@@ -255,7 +255,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
- Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
+ Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go
index 1028fb84..8a81c09a 100644
--- a/backend/internal/pkg/httpclient/pool.go
+++ b/backend/internal/pkg/httpclient/pool.go
@@ -11,22 +11,20 @@
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
-// 3. 支持 HTTP/HTTPS/SOCKS5 代理
-// 4. 支持严格代理模式(代理失败则返回错误)
+// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理
+// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险)
package httpclient
import (
- "context"
"crypto/tls"
"fmt"
- "net"
"net/http"
"net/url"
"strings"
"sync"
"time"
- "golang.org/x/net/proxy"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// Transport 连接池默认配置
@@ -38,11 +36,10 @@ const (
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
- ProxyURL string // 代理 URL(支持 http/https/socks5)
+ ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
- ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100)
@@ -55,6 +52,7 @@ var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
+// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
@@ -65,12 +63,7 @@ func GetClient(opts Options) (*http.Client, error) {
client, err := buildClient(opts)
if err != nil {
- if opts.ProxyStrict {
- return nil, err
- }
- fallback := opts
- fallback.ProxyURL = ""
- client, _ = buildClient(fallback)
+ return nil, err
}
actual, _ := sharedClients.LoadOrStore(key, client)
@@ -125,31 +118,19 @@ func buildTransport(opts Options) (*http.Transport, error) {
return nil, err
}
- switch strings.ToLower(parsed.Scheme) {
- case "http", "https":
- transport.Proxy = http.ProxyURL(parsed)
- case "socks5", "socks5h":
- dialer, err := proxy.FromURL(parsed, proxy.Direct)
- if err != nil {
- return nil, err
- }
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
- return dialer.Dial(network, addr)
- }
- default:
- return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
+ if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
+ return nil, err
}
return transport, nil
}
func buildClientKey(opts Options) string {
- return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
+ return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
- opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go
new file mode 100644
index 00000000..91b224a2
--- /dev/null
+++ b/backend/internal/pkg/proxyutil/dialer.go
@@ -0,0 +1,62 @@
+// Package proxyutil 提供统一的代理配置功能
+//
+// 支持的代理协议:
+// - HTTP/HTTPS: 通过 Transport.Proxy 设置
+// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
+package proxyutil
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "golang.org/x/net/proxy"
+)
+
+// ConfigureTransportProxy 根据代理 URL 配置 Transport
+//
+// 支持的协议:
+// - http/https: 设置 transport.Proxy
+// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
+//
+// 参数:
+// - transport: 需要配置的 http.Transport
+// - proxyURL: 代理地址,nil 表示直连
+//
+// 返回:
+// - error: 代理配置错误(协议不支持或 dialer 创建失败)
+func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error {
+ if proxyURL == nil {
+ return nil
+ }
+
+ scheme := strings.ToLower(proxyURL.Scheme)
+ switch scheme {
+ case "http", "https":
+ transport.Proxy = http.ProxyURL(proxyURL)
+ return nil
+
+ case "socks5", "socks5h":
+ dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
+ if err != nil {
+ return fmt.Errorf("create socks5 dialer: %w", err)
+ }
+ // 优先使用支持 context 的 DialContext,以支持请求取消和超时
+ if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
+ transport.DialContext = contextDialer.DialContext
+ } else {
+ // 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext
+ // 注意:此回退不支持请求取消和超时控制
+ transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
+ return dialer.Dial(network, addr)
+ }
+ }
+ return nil
+
+ default:
+ return fmt.Errorf("unsupported proxy scheme: %s", scheme)
+ }
+}
diff --git a/backend/internal/pkg/proxyutil/dialer_test.go b/backend/internal/pkg/proxyutil/dialer_test.go
new file mode 100644
index 00000000..f153cc9f
--- /dev/null
+++ b/backend/internal/pkg/proxyutil/dialer_test.go
@@ -0,0 +1,204 @@
+package proxyutil
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestConfigureTransportProxy_Nil(t *testing.T) {
+ transport := &http.Transport{}
+ err := ConfigureTransportProxy(transport, nil)
+
+ require.NoError(t, err)
+ assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy")
+ assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext")
+}
+
+func TestConfigureTransportProxy_HTTP(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse("http://proxy.example.com:8080")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy")
+ assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext")
+}
+
+func TestConfigureTransportProxy_HTTPS(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy")
+ assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext")
+}
+
+func TestConfigureTransportProxy_SOCKS5(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy")
+ assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext")
+}
+
+func TestConfigureTransportProxy_SOCKS5H(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse("socks5h://socks.example.com:1080")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy")
+ assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext")
+}
+
+func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ useProxy bool // true = uses Transport.Proxy, false = uses DialContext
+ }{
+ {"HTTP://proxy.example.com:8080", true},
+ {"Http://proxy.example.com:8080", true},
+ {"HTTPS://proxy.example.com:8443", true},
+ {"Https://proxy.example.com:8443", true},
+ {"SOCKS5://socks.example.com:1080", false},
+ {"Socks5://socks.example.com:1080", false},
+ {"SOCKS5H://socks.example.com:1080", false},
+ {"Socks5h://socks.example.com:1080", false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse(tc.scheme)
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ if tc.useProxy {
+ assert.NotNil(t, transport.Proxy)
+ assert.Nil(t, transport.DialContext)
+ } else {
+ assert.Nil(t, transport.Proxy)
+ assert.NotNil(t, transport.DialContext)
+ }
+ })
+ }
+}
+
+func TestConfigureTransportProxy_Unsupported(t *testing.T) {
+ testCases := []string{
+ "ftp://ftp.example.com",
+ "file:///path/to/file",
+ "unknown://example.com",
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc, func(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse(tc)
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unsupported proxy scheme")
+ })
+ }
+}
+
+func TestConfigureTransportProxy_WithAuth(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext")
+}
+
+func TestConfigureTransportProxy_EmptyScheme(t *testing.T) {
+ transport := &http.Transport{}
+ // 空 scheme 的 URL
+ proxyURL := &url.URL{Host: "proxy.example.com:8080"}
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unsupported proxy scheme")
+}
+
+func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) {
+ // 验证代理配置不会覆盖 Transport 的其他配置
+ transport := &http.Transport{
+ MaxIdleConns: 100,
+ MaxIdleConnsPerHost: 10,
+ }
+ proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
+
+ err := ConfigureTransportProxy(transport, proxyURL)
+
+ require.NoError(t, err)
+ assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved")
+ assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved")
+ assert.NotNil(t, transport.DialContext, "DialContext should be set")
+}
+
+func TestConfigureTransportProxy_IPv6(t *testing.T) {
+ testCases := []struct {
+ name string
+ proxyURL string
+ }{
+ {"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"},
+ {"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"},
+ {"HTTP with IPv6", "http://[::1]:8080"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, err := url.Parse(tc.proxyURL)
+ require.NoError(t, err, "URL should be parseable")
+
+ err = ConfigureTransportProxy(transport, proxyURL)
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) {
+ testCases := []struct {
+ name string
+ proxyURL string
+ }{
+ // 密码包含 @ 符号(URL 编码为 %40)
+ {"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"},
+ // 密码包含 : 符号(URL 编码为 %3A)
+ {"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"},
+ // 密码包含 / 符号(URL 编码为 %2F)
+ {"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"},
+ // 复杂密码
+ {"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ transport := &http.Transport{}
+ proxyURL, err := url.Parse(tc.proxyURL)
+ require.NoError(t, err, "URL should be parseable")
+
+ err = ConfigureTransportProxy(transport, proxyURL)
+ require.NoError(t, err)
+ assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext")
+ })
+ }
+}
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index b03b5415..35e7f535 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -233,11 +233,17 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createReqClient(proxyURL string) *req.Client {
- return getSharedReqClient(reqClientOptions{
- ProxyURL: proxyURL,
- Timeout: 60 * time.Second,
- Impersonate: true,
- })
+ // 禁用 CookieJar,确保每次授权都是干净的会话
+ client := req.C().
+ SetTimeout(60 * time.Second).
+ ImpersonateChrome().
+ SetCookieJar(nil) // 禁用 CookieJar
+
+ if strings.TrimSpace(proxyURL) != "" {
+ client.SetProxyURL(strings.TrimSpace(proxyURL))
+ }
+
+ return client
}
func prefix(s string, n int) string {
diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go
index 180844b5..f0669979 100644
--- a/backend/internal/repository/http_upstream.go
+++ b/backend/internal/repository/http_upstream.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -225,7 +226,12 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
- client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
+ transport, err := buildUpstreamTransport(settings, parsedProxy)
+ if err != nil {
+ s.mu.Unlock()
+ return nil, fmt.Errorf("build transport: %w", err)
+ }
+ client := &http.Client{Transport: transport}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
@@ -548,6 +554,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
+// - error: 代理配置错误
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
@@ -555,7 +562,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
-func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
+func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
@@ -563,10 +570,10 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
- if proxyURL != nil {
- transport.Proxy = http.ProxyURL(proxyURL)
+ if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
+ return nil, err
}
- return transport
+ return transport, nil
}
// trackedBody 带跟踪功能的响应体包装器
diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go
index 3219c6da..1e7430a3 100644
--- a/backend/internal/repository/http_upstream_benchmark_test.go
+++ b/backend/internal/repository/http_upstream_benchmark_test.go
@@ -45,8 +45,12 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
+ transport, err := buildUpstreamTransport(settings, parsedProxy)
+ if err != nil {
+ b.Fatalf("创建 Transport 失败: %v", err)
+ }
httpClientSink = &http.Client{
- Transport: buildUpstreamTransport(settings, parsedProxy),
+ Transport: transport,
}
}
})
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index 8b288c3c..f5f625f9 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -27,7 +27,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
- ProxyStrict: true,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 47335019..4df87e9e 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -13,7 +13,6 @@ import (
"net/http"
"regexp"
"strings"
- "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
@@ -28,7 +27,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
- testOpenAIAPIURL = "https://api.openai.com/v1/responses"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
)
@@ -44,8 +42,6 @@ type TestEvent struct {
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
- oauthService *OAuthService
- openaiOAuthService *OpenAIOAuthService
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
@@ -54,16 +50,12 @@ type AccountTestService struct {
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
- oauthService *OAuthService,
- openaiOAuthService *OpenAIOAuthService,
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
- oauthService: oauthService,
- openaiOAuthService: openaiOAuthService,
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
@@ -183,22 +175,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
-
- // Check if token needs refresh
- needRefresh := false
- if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
- if time.Now().Add(5 * time.Minute).After(*expiresAt) {
- needRefresh = true
- }
- }
-
- if needRefresh && s.oauthService != nil {
- tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
- }
- authToken = tokenInfo.AccessToken
- }
} else if account.Type == "apikey" {
// API Key - use x-api-key header
useBearer = false
@@ -296,64 +272,77 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
}
- // Set SSE headers early
+ // Determine authentication method and API URL
+ var authToken string
+ var apiURL string
+ var isOAuth bool
+ var chatgptAccountID string
+
+ if account.IsOAuth() {
+ isOAuth = true
+ // OAuth - use Bearer token with ChatGPT internal API
+ authToken = account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+
+ // OAuth uses ChatGPT internal API
+ apiURL = chatgptCodexAPIURL
+ chatgptAccountID = account.GetChatGPTAccountID()
+ } else if account.Type == "apikey" {
+ // API Key - use Platform API
+ authToken = account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
+ } else {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
+ }
+
+ // Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
+ // Create OpenAI Responses API payload
+ payload := createOpenAITestPayload(testModelID, isOAuth)
+ payloadBytes, _ := json.Marshal(payload)
+
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+
+ // Set common headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+
+ // Set OAuth-specific headers for ChatGPT internal API
+ if isOAuth {
+ req.Host = "chatgpt.com"
+ req.Header.Set("accept", "text/event-stream")
+ if chatgptAccountID != "" {
+ req.Header.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ }
+
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
- if account.IsOAuth() {
- // OAuth - use ChatGPT internal API (Responses API)
- return s.testOpenAIOAuthAccount(c, ctx, account, testModelID, proxyURL)
- }
-
- // API Key - try Chat Completions API first, fallback to Responses API
- return s.testOpenAIApiKeyAccount(c, ctx, account, testModelID, proxyURL)
-}
-
-// testOpenAIOAuthAccount tests OAuth account using ChatGPT internal API
-func (s *AccountTestService) testOpenAIOAuthAccount(c *gin.Context, ctx context.Context, account *Account, testModelID, proxyURL string) error {
- authToken := account.GetOpenAIAccessToken()
- if authToken == "" {
- return s.sendErrorAndEnd(c, "No access token available")
- }
-
- // Check if token is expired and refresh if needed
- if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
- tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
- }
- authToken = tokenInfo.AccessToken
- }
-
- // Create Responses API payload
- payload := createOpenAITestPayload(testModelID, true)
- payloadBytes, _ := json.Marshal(payload)
-
- req, err := http.NewRequestWithContext(ctx, "POST", chatgptCodexAPIURL, bytes.NewReader(payloadBytes))
- if err != nil {
- return s.sendErrorAndEnd(c, "Failed to create request")
- }
-
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+authToken)
- req.Host = "chatgpt.com"
- req.Header.Set("accept", "text/event-stream")
- if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
- req.Header.Set("chatgpt-account-id", chatgptAccountID)
- }
-
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
@@ -365,153 +354,10 @@ func (s *AccountTestService) testOpenAIOAuthAccount(c *gin.Context, ctx context.
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
+ // Process SSE stream
return s.processOpenAIStream(c, resp.Body)
}
-// testOpenAIApiKeyAccount tests API Key account, trying Chat Completions first, then Responses API
-func (s *AccountTestService) testOpenAIApiKeyAccount(c *gin.Context, ctx context.Context, account *Account, testModelID, proxyURL string) error {
- authToken := account.GetOpenAIApiKey()
- if authToken == "" {
- return s.sendErrorAndEnd(c, "No API key available")
- }
-
- baseURL := account.GetOpenAIBaseURL()
- if baseURL == "" {
- baseURL = "https://api.openai.com"
- }
- baseURL = strings.TrimSuffix(baseURL, "/")
-
- // Try Chat Completions API first (more compatible with third-party proxies)
- chatCompletionsURL := baseURL + "/v1/chat/completions"
- chatPayload := createOpenAIChatCompletionsPayload(testModelID)
- chatPayloadBytes, _ := json.Marshal(chatPayload)
-
- req, err := http.NewRequestWithContext(ctx, "POST", chatCompletionsURL, bytes.NewReader(chatPayloadBytes))
- if err != nil {
- return s.sendErrorAndEnd(c, "Failed to create request")
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+authToken)
-
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- // Network error, try Responses API
- s.sendEvent(c, TestEvent{Type: "info", Text: "Chat Completions API failed, trying Responses API..."})
- return s.tryOpenAIResponsesAPI(c, ctx, account, testModelID, baseURL, authToken, proxyURL)
- }
- defer func() { _ = resp.Body.Close() }()
-
- if resp.StatusCode == http.StatusOK {
- // Chat Completions API succeeded
- return s.processOpenAIChatCompletionsStream(c, resp.Body)
- }
-
- // Chat Completions API failed, try Responses API
- _ = resp.Body.Close()
- s.sendEvent(c, TestEvent{Type: "info", Text: "Chat Completions API failed, trying Responses API..."})
- return s.tryOpenAIResponsesAPI(c, ctx, account, testModelID, baseURL, authToken, proxyURL)
-}
-
-// tryOpenAIResponsesAPI tries the OpenAI Responses API as fallback
-func (s *AccountTestService) tryOpenAIResponsesAPI(c *gin.Context, ctx context.Context, account *Account, testModelID, baseURL, authToken, proxyURL string) error {
- responsesURL := baseURL + "/v1/responses"
- payload := createOpenAITestPayload(testModelID, false)
- payloadBytes, _ := json.Marshal(payload)
-
- req, err := http.NewRequestWithContext(ctx, "POST", responsesURL, bytes.NewReader(payloadBytes))
- if err != nil {
- return s.sendErrorAndEnd(c, "Failed to create request")
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+authToken)
-
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
- }
- defer func() { _ = resp.Body.Close() }()
-
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
- }
-
- return s.processOpenAIStream(c, resp.Body)
-}
-
-// createOpenAIChatCompletionsPayload creates a test payload for OpenAI Chat Completions API
-func createOpenAIChatCompletionsPayload(modelID string) map[string]any {
- return map[string]any{
- "model": modelID,
- "messages": []map[string]any{
- {
- "role": "user",
- "content": "hi",
- },
- },
- "stream": true,
- "max_tokens": 100,
- }
-}
-
-// processOpenAIChatCompletionsStream processes the SSE stream from OpenAI Chat Completions API
-func (s *AccountTestService) processOpenAIChatCompletionsStream(c *gin.Context, body io.Reader) error {
- reader := bufio.NewReader(body)
-
- for {
- line, err := reader.ReadString('\n')
- if err != nil {
- if err == io.EOF {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
- return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
- }
-
- line = strings.TrimSpace(line)
- if line == "" || !sseDataPrefix.MatchString(line) {
- continue
- }
-
- jsonStr := sseDataPrefix.ReplaceAllString(line, "")
- if jsonStr == "[DONE]" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
-
- var data map[string]any
- if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
- continue
- }
-
- // Handle Chat Completions format: choices[0].delta.content
- if choices, ok := data["choices"].([]any); ok && len(choices) > 0 {
- if choice, ok := choices[0].(map[string]any); ok {
- // Check finish_reason
- if finishReason, ok := choice["finish_reason"].(string); ok && finishReason != "" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
- // Extract content from delta
- if delta, ok := choice["delta"].(map[string]any); ok {
- if content, ok := delta["content"].(string); ok && content != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: content})
- }
- }
- }
- }
-
- // Handle error
- if errData, ok := data["error"].(map[string]any); ok {
- errorMsg := "Unknown error"
- if msg, ok := errData["message"].(string); ok {
- errorMsg = msg
- }
- return s.sendErrorAndEnd(c, errorMsg)
- }
- }
-}
-
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
@@ -748,11 +594,11 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
line = strings.TrimSpace(line)
- if line == "" || !sseDataPrefix.MatchString(line) {
+ if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
- jsonStr := sseDataPrefix.ReplaceAllString(line, "")
+ jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
@@ -771,7 +617,13 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok {
- // Extract content first (before checking finishReason)
+ // Check for completion
+ if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+
+ // Extract content
if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
@@ -783,12 +635,6 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
}
}
-
- // Check for completion after extracting content
- if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
}
}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 0a1227dd..2c27354a 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -78,7 +78,7 @@ type antigravityUsageCache struct {
}
const (
- apiCacheTTL = 10 * time.Minute
+ apiCacheTTL = 3 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 9003f5a1..be07f37f 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -661,6 +661,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: StatusActive,
+ Schedulable: true,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go
new file mode 100644
index 00000000..b056f8fa
--- /dev/null
+++ b/backend/internal/service/gateway_prompt_test.go
@@ -0,0 +1,233 @@
+package service
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsClaudeCodeClient(t *testing.T) {
+ tests := []struct {
+ name string
+ userAgent string
+ metadataUserID string
+ want bool
+ }{
+ {
+ name: "Claude Code client",
+ userAgent: "claude-cli/1.0.62 (darwin; arm64)",
+ metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
+ want: true,
+ },
+ {
+ name: "Claude Code without version suffix",
+ userAgent: "claude-cli/2.0.0",
+ metadataUserID: "session_abc",
+ want: true,
+ },
+ {
+ name: "Missing metadata user_id",
+ userAgent: "claude-cli/1.0.0",
+ metadataUserID: "",
+ want: false,
+ },
+ {
+ name: "Different user agent",
+ userAgent: "curl/7.68.0",
+ metadataUserID: "user123",
+ want: false,
+ },
+ {
+ name: "Empty user agent",
+ userAgent: "",
+ metadataUserID: "user123",
+ want: false,
+ },
+ {
+ name: "Similar but not Claude CLI",
+ userAgent: "claude-api/1.0.0",
+ metadataUserID: "user123",
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
+ tests := []struct {
+ name string
+ system any
+ want bool
+ }{
+ {
+ name: "nil system",
+ system: nil,
+ want: false,
+ },
+ {
+ name: "empty string",
+ system: "",
+ want: false,
+ },
+ {
+ name: "string with Claude Code prompt",
+ system: claudeCodeSystemPrompt,
+ want: true,
+ },
+ {
+ name: "string with different content",
+ system: "You are a helpful assistant.",
+ want: false,
+ },
+ {
+ name: "empty array",
+ system: []any{},
+ want: false,
+ },
+ {
+ name: "array with Claude Code prompt",
+ system: []any{
+ map[string]any{
+ "type": "text",
+ "text": claudeCodeSystemPrompt,
+ },
+ },
+ want: true,
+ },
+ {
+ name: "array with Claude Code prompt in second position",
+ system: []any{
+ map[string]any{"type": "text", "text": "First prompt"},
+ map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
+ },
+ want: true,
+ },
+ {
+ name: "array without Claude Code prompt",
+ system: []any{
+ map[string]any{"type": "text", "text": "Custom prompt"},
+ },
+ want: false,
+ },
+ {
+ name: "array with partial match (should not match)",
+ system: []any{
+ map[string]any{"type": "text", "text": "You are Claude"},
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := systemIncludesClaudeCodePrompt(tt.system)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestInjectClaudeCodePrompt(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ system any
+ wantSystemLen int
+ wantFirstText string
+ wantSecondText string
+ }{
+ {
+ name: "nil system",
+ body: `{"model":"claude-3"}`,
+ system: nil,
+ wantSystemLen: 1,
+ wantFirstText: claudeCodeSystemPrompt,
+ },
+ {
+ name: "empty string system",
+ body: `{"model":"claude-3"}`,
+ system: "",
+ wantSystemLen: 1,
+ wantFirstText: claudeCodeSystemPrompt,
+ },
+ {
+ name: "string system",
+ body: `{"model":"claude-3"}`,
+ system: "Custom prompt",
+ wantSystemLen: 2,
+ wantFirstText: claudeCodeSystemPrompt,
+ wantSecondText: "Custom prompt",
+ },
+ {
+ name: "string system equals Claude Code prompt",
+ body: `{"model":"claude-3"}`,
+ system: claudeCodeSystemPrompt,
+ wantSystemLen: 1,
+ wantFirstText: claudeCodeSystemPrompt,
+ },
+ {
+ name: "array system",
+ body: `{"model":"claude-3"}`,
+ system: []any{map[string]any{"type": "text", "text": "Custom"}},
+ // Claude Code + Custom = 2
+ wantSystemLen: 2,
+ wantFirstText: claudeCodeSystemPrompt,
+ wantSecondText: "Custom",
+ },
+ {
+ name: "array system with existing Claude Code prompt (should dedupe)",
+ body: `{"model":"claude-3"}`,
+ system: []any{
+ map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
+ map[string]any{"type": "text", "text": "Other"},
+ },
+ // Claude Code at start + Other = 2 (deduped)
+ wantSystemLen: 2,
+ wantFirstText: claudeCodeSystemPrompt,
+ wantSecondText: "Other",
+ },
+ {
+ name: "empty array",
+ body: `{"model":"claude-3"}`,
+ system: []any{},
+ wantSystemLen: 1,
+ wantFirstText: claudeCodeSystemPrompt,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := injectClaudeCodePrompt([]byte(tt.body), tt.system)
+
+ var parsed map[string]any
+ err := json.Unmarshal(result, &parsed)
+ require.NoError(t, err)
+
+ system, ok := parsed["system"].([]any)
+ require.True(t, ok, "system should be an array")
+ require.Len(t, system, tt.wantSystemLen)
+
+ first, ok := system[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, tt.wantFirstText, first["text"])
+ require.Equal(t, "text", first["type"])
+
+ // Check cache_control
+ cc, ok := first["cache_control"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "ephemeral", cc["type"])
+
+ if tt.wantSecondText != "" && len(system) > 1 {
+ second, ok := system[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, tt.wantSecondText, second["text"])
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 4e4b180a..327f19f9 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -30,13 +30,15 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
+ claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
- sseDataRe = regexp.MustCompile(`^data:\s*`)
- sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
+ sseDataRe = regexp.MustCompile(`^data:\s*`)
+ sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
+ claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
)
// allowedHeaders 白名单headers(参考CRS项目)
@@ -951,6 +953,76 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
}
+// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
+// 简化判断:User-Agent 匹配 + metadata.user_id 存在
+func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
+ if metadataUserID == "" {
+ return false
+ }
+ return claudeCliUserAgentRe.MatchString(userAgent)
+}
+
+// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
+// 支持 string 和 []any 两种格式
+func systemIncludesClaudeCodePrompt(system any) bool {
+ switch v := system.(type) {
+ case string:
+ return v == claudeCodeSystemPrompt
+ case []any:
+ for _, item := range v {
+ if m, ok := item.(map[string]any); ok {
+ if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
+// 处理 null、字符串、数组三种格式
+func injectClaudeCodePrompt(body []byte, system any) []byte {
+ claudeCodeBlock := map[string]any{
+ "type": "text",
+ "text": claudeCodeSystemPrompt,
+ "cache_control": map[string]string{"type": "ephemeral"},
+ }
+
+ var newSystem []any
+
+ switch v := system.(type) {
+ case nil:
+ newSystem = []any{claudeCodeBlock}
+ case string:
+ if v == "" || v == claudeCodeSystemPrompt {
+ newSystem = []any{claudeCodeBlock}
+ } else {
+ newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
+ }
+ case []any:
+ newSystem = make([]any, 0, len(v)+1)
+ newSystem = append(newSystem, claudeCodeBlock)
+ for _, item := range v {
+ if m, ok := item.(map[string]any); ok {
+ if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
+ continue
+ }
+ }
+ newSystem = append(newSystem, item)
+ }
+ default:
+ newSystem = []any{claudeCodeBlock}
+ }
+
+ result, err := sjson.SetBytes(body, "system", newSystem)
+ if err != nil {
+ log.Printf("Warning: failed to inject Claude Code prompt: %v", err)
+ return body
+ }
+ return result
+}
+
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -962,16 +1034,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqModel := parsed.Model
reqStream := parsed.Stream
- if !parsed.HasSystem {
- body, _ = sjson.SetBytes(body, "system", []any{
- map[string]any{
- "type": "text",
- "text": "You are Claude Code, Anthropic's official CLI for Claude.",
- "cache_control": map[string]string{
- "type": "ephemeral",
- },
- },
- })
+ // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
+ // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
+ if account.IsOAuth() &&
+ !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
+ !strings.Contains(strings.ToLower(reqModel), "haiku") &&
+ !systemIncludesClaudeCodePrompt(parsed.System) {
+ body = injectClaudeCodePrompt(body, parsed.System)
}
// 应用模型映射(仅对apikey类型账号)
diff --git a/frontend/src/components/common/DateRangePicker.vue b/frontend/src/components/common/DateRangePicker.vue
index be641f9b..4fce029f 100644
--- a/frontend/src/components/common/DateRangePicker.vue
+++ b/frontend/src/components/common/DateRangePicker.vue
@@ -59,7 +59,7 @@
@@ -85,7 +85,7 @@
type="date"
v-model="localEndDate"
:min="localStartDate"
- :max="today"
+ :max="tomorrow"
class="date-picker-input"
@change="onDateChange"
/>
@@ -144,6 +144,14 @@ const today = computed(() => {
return `${year}-${month}-${day}`
})
+// Tomorrow's date - used for max date to handle timezone differences
+// When user is in a timezone behind the server, "today" on server might be "tomorrow" locally
+const tomorrow = computed(() => {
+ const d = new Date()
+ d.setDate(d.getDate() + 1)
+ return formatDateToString(d)
+})
+
// Helper function to format date to YYYY-MM-DD using local timezone
const formatDateToString = (date: Date): string => {
const year = date.getFullYear()
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index d998e27b..04db3731 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -290,7 +290,7 @@ export interface UpdateGroupRequest {
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
export type AccountType = 'oauth' | 'setup-token' | 'apikey'
export type OAuthAddMethod = 'oauth' | 'setup-token'
-export type ProxyProtocol = 'http' | 'https' | 'socks5'
+export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
// Claude Model type (returned by /v1/models and account models API)
export interface ClaudeModel {
diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue
index a5df9bd0..613b503c 100644
--- a/frontend/src/views/admin/ProxiesView.vue
+++ b/frontend/src/views/admin/ProxiesView.vue
@@ -90,7 +90,7 @@
{{ value.toUpperCase() }}
@@ -628,7 +628,8 @@ const protocolOptions = computed(() => [
{ value: '', label: t('admin.proxies.allProtocols') },
{ value: 'http', label: 'HTTP' },
{ value: 'https', label: 'HTTPS' },
- { value: 'socks5', label: 'SOCKS5' }
+ { value: 'socks5', label: 'SOCKS5' },
+ { value: 'socks5h', label: 'SOCKS5H' }
])
const statusOptions = computed(() => [
@@ -641,7 +642,8 @@ const statusOptions = computed(() => [
const protocolSelectOptions = [
{ value: 'http', label: 'HTTP' },
{ value: 'https', label: 'HTTPS' },
- { value: 'socks5', label: 'SOCKS5' }
+ { value: 'socks5', label: 'SOCKS5' },
+ { value: 'socks5h', label: 'SOCKS5H (服务端解析DNS)' }
]
const editStatusOptions = computed(() => [
@@ -798,8 +800,8 @@ const parseProxyUrl = (
const trimmed = line.trim()
if (!trimmed) return null
- // Regex to parse proxy URL
- const regex = /^(https?|socks5):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
+ // Regex to parse proxy URL (supports http, https, socks5, socks5h)
+ const regex = /^(https?|socks5h?):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
const match = trimmed.match(regex)
if (!match) return null
diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue
index 85a748f4..ac5d1e05 100644
--- a/frontend/src/views/admin/UsageView.vue
+++ b/frontend/src/views/admin/UsageView.vue
@@ -888,13 +888,17 @@ const formatLocalDate = (date: Date): string => {
}
// Initialize date range immediately
+// Use tomorrow as end date to handle timezone differences between client and server
+// e.g., when server is in Asia/Shanghai and client is in America/Chicago
const now = new Date()
+const tomorrow = new Date(now)
+tomorrow.setDate(tomorrow.getDate() + 1)
const weekAgo = new Date(now)
weekAgo.setDate(weekAgo.getDate() - 6)
// Date range state
const startDate = ref(formatLocalDate(weekAgo))
-const endDate = ref(formatLocalDate(now))
+const endDate = ref(formatLocalDate(tomorrow))
const filters = ref({
user_id: undefined,
@@ -1215,12 +1219,14 @@ const resetFilters = () => {
end_date: undefined
}
granularity.value = 'day'
- // Reset date range to default (last 7 days)
+ // Reset date range to default (last 7 days, with tomorrow as end to handle timezone differences)
const now = new Date()
+ const tomorrowDate = new Date(now)
+ tomorrowDate.setDate(tomorrowDate.getDate() + 1)
const weekAgo = new Date(now)
weekAgo.setDate(weekAgo.getDate() - 6)
startDate.value = formatLocalDate(weekAgo)
- endDate.value = formatLocalDate(now)
+ endDate.value = formatLocalDate(tomorrowDate)
filters.value.start_date = startDate.value
filters.value.end_date = endDate.value
pagination.value.page = 1