Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
@@ -258,8 +258,43 @@ type GatewayConfig struct {
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
|
||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
|
||||
|
||||
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
|
||||
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
|
||||
|
||||
// Scheduling: 账号调度相关配置
|
||||
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
// Enabled: 是否全局启用TLS指纹功能
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// Profiles: 预定义的TLS指纹配置模板
|
||||
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
|
||||
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
|
||||
}
|
||||
|
||||
// TLSProfileConfig 单个TLS指纹模板的配置
|
||||
type TLSProfileConfig struct {
|
||||
// Name: 模板显示名称
|
||||
Name string `mapstructure:"name"`
|
||||
// EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
|
||||
EnableGREASE bool `mapstructure:"enable_grease"`
|
||||
// CipherSuites: TLS加密套件列表(空则使用内置默认值)
|
||||
CipherSuites []uint16 `mapstructure:"cipher_suites"`
|
||||
// Curves: 椭圆曲线列表(空则使用内置默认值)
|
||||
Curves []uint16 `mapstructure:"curves"`
|
||||
// PointFormats: 点格式列表(空则使用内置默认值)
|
||||
PointFormats []uint8 `mapstructure:"point_formats"`
|
||||
}
|
||||
|
||||
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||
@@ -272,6 +307,9 @@ type GatewaySchedulingConfig struct {
|
||||
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
|
||||
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
|
||||
|
||||
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
|
||||
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||
|
||||
// 负载计算
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
|
||||
@@ -781,6 +819,9 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||
viper.SetDefault("gateway.failover_on_400", false)
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
@@ -793,11 +834,12 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
||||
@@ -809,6 +851,8 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// TokenRefresh
|
||||
|
||||
@@ -173,6 +173,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
@@ -181,6 +182,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,9 +191,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询)
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
@@ -211,12 +213,8 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
var startTime time.Time
|
||||
if accCopy.SessionWindowStart != nil {
|
||||
startTime = *accCopy.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
@@ -545,6 +543,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
|
||||
@@ -200,6 +200,10 @@ func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*se
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
|
||||
return &account, nil
|
||||
|
||||
@@ -161,6 +161,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
|
||||
out.SessionIdleTimeoutMin = &idleTimeout
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
out.EnableTLSFingerprint = &enabled
|
||||
}
|
||||
// 会话ID伪装开关
|
||||
if a.IsSessionIDMaskingEnabled() {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -112,6 +112,15 @@ type Account struct {
|
||||
MaxSessions *int `json:"max_sessions,omitempty"`
|
||||
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
|
||||
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ type GatewayHandler struct {
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -44,8 +46,16 @@ func NewGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 10
|
||||
maxAccountSwitchesGemini := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
||||
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||
}
|
||||
}
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
@@ -54,6 +64,8 @@ func NewGatewayHandler(
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
const maxAccountSwitches = 10
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -16,15 +16,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// resolveHost 从 URL 解析 host
|
||||
func resolveHost(urlStr string) string {
|
||||
parsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Host
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 基础 Headers
|
||||
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
// Accept Header 根据请求类型设置
|
||||
if isStream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 显式设置 Host Header
|
||||
if host := resolveHost(apiURL); host != "" {
|
||||
req.Host = host
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
|
||||
}
|
||||
|
||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级
|
||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
return statusCode == http.StatusTooManyRequests ||
|
||||
statusCode == http.StatusRequestTimeout ||
|
||||
statusCode == http.StatusNotFound ||
|
||||
statusCode >= 500
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
// 标记成功的 URL,下次优先使用
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
// 标记成功的 URL,下次优先使用
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -143,9 +143,10 @@ type GeminiResponse struct {
|
||||
|
||||
// GeminiCandidate Gemini 候选响应
|
||||
type GeminiCandidate struct {
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||
type GeminiGroundingMetadata struct {
|
||||
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
|
||||
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingChunk Gemini grounding chunk
|
||||
type GeminiGroundingChunk struct {
|
||||
Web *GeminiGroundingWeb `json:"web,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingWeb Gemini grounding web 信息
|
||||
type GeminiGroundingWeb struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
URI string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||
|
||||
@@ -32,8 +32,8 @@ const (
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(模拟官方客户端)
|
||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
@@ -42,22 +42,21 @@ const (
|
||||
URLAvailabilityTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
||||
// fallback 顺序: sandbox → daily → prod
|
||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||
var BaseURLs = []string{
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
||||
"https://cloudcode-pa.googleapis.com", // prod
|
||||
"https://cloudcode-pa.googleapis.com", // prod (优先)
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
|
||||
}
|
||||
|
||||
// BaseURL 默认 URL(保持向后兼容)
|
||||
var BaseURL = BaseURLs[0]
|
||||
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
|
||||
type URLAvailability struct {
|
||||
mu sync.RWMutex
|
||||
unavailable map[string]time.Time // URL -> 恢复时间
|
||||
ttl time.Duration
|
||||
lastSuccess string // 最近成功请求的 URL,优先使用
|
||||
}
|
||||
|
||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
|
||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||
}
|
||||
|
||||
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
|
||||
func (u *URLAvailability) MarkSuccess(url string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.lastSuccess = url
|
||||
// 成功后清除该 URL 的不可用标记
|
||||
delete(u.unavailable, url)
|
||||
}
|
||||
|
||||
// IsAvailable 检查 URL 是否可用
|
||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
u.mu.RLock()
|
||||
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
return time.Now().After(expiry)
|
||||
}
|
||||
|
||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
||||
// GetAvailableURLs 返回可用的 URL 列表
|
||||
// 最近成功的 URL 优先,其他按默认顺序
|
||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
result := make([]string, 0, len(BaseURLs))
|
||||
|
||||
// 如果有最近成功的 URL 且可用,放在最前面
|
||||
if u.lastSuccess != "" {
|
||||
expiry, exists := u.unavailable[u.lastSuccess]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, u.lastSuccess)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加其他可用的 URL(按默认顺序)
|
||||
for _, url := range BaseURLs {
|
||||
// 跳过已添加的 lastSuccess
|
||||
if url == u.lastSuccess {
|
||||
continue
|
||||
}
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, url)
|
||||
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
||||
func GenerateMockProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
randBytes, _ := GenerateRandomBytes(7)
|
||||
|
||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
||||
|
||||
// 生成 5 位随机字符(a-z0-9)
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
suffix := make([]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
|
||||
}
|
||||
}
|
||||
|
||||
// webSearchFallbackModel web_search 请求使用的降级模型
|
||||
const webSearchFallbackModel = "gemini-2.5-flash"
|
||||
|
||||
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
|
||||
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
|
||||
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
|
||||
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否有 web_search 工具
|
||||
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
|
||||
requestType := "agent"
|
||||
targetModel := mappedModel
|
||||
if hasWebSearchTool {
|
||||
requestType = "web_search"
|
||||
if targetModel != webSearchFallbackModel {
|
||||
targetModel = webSearchFallbackModel
|
||||
}
|
||||
}
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
}
|
||||
|
||||
// 2. 构建 systemInstruction
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
reqForConfig := claudeReq
|
||||
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
reqCopy.Thinking = nil
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
if targetModel != "" && targetModel != reqForConfig.Model {
|
||||
reqCopy := *reqForConfig
|
||||
reqCopy.Model = targetModel
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
generationConfig := buildGenerationConfig(reqForConfig)
|
||||
|
||||
// 4. 构建 tools
|
||||
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
Project: projectID,
|
||||
RequestID: "agent-" + uuid.New().String(),
|
||||
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
||||
RequestType: "agent",
|
||||
Model: mappedModel,
|
||||
RequestType: requestType,
|
||||
Model: targetModel,
|
||||
Request: innerRequest,
|
||||
}
|
||||
|
||||
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
|
||||
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
|
||||
const mcpXMLProtocol = `
|
||||
==== MCP XML 工具调用协议 (Workaround) ====
|
||||
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
|
||||
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `。
|
||||
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
|
||||
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
|
||||
===========================================`
|
||||
|
||||
// hasMCPTools 检测是否有 mcp__ 前缀的工具
|
||||
func hasMCPTools(tools []ClaudeTool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Name, "mcp__") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
|
||||
func filterOpenCodePrompt(text string) string {
|
||||
if !strings.Contains(text, "You are an interactive CLI tool") {
|
||||
return text
|
||||
}
|
||||
// 提取 "Instructions from:" 及之后的部分
|
||||
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
|
||||
return text[idx:]
|
||||
}
|
||||
// 如果没有自定义指令,返回空
|
||||
return ""
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
|
||||
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
||||
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
var sysStr string
|
||||
if err := json.Unmarshal(system, &sysStr); err == nil {
|
||||
if strings.TrimSpace(sysStr) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(sysStr)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 尝试解析为数组
|
||||
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
||||
for _, block := range sysBlocks {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(block.Text)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
// 添加用户的 system prompt
|
||||
parts = append(parts, userSystemParts...)
|
||||
|
||||
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
|
||||
if hasMCPTools(tools) {
|
||||
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
|
||||
}
|
||||
|
||||
// 如果用户没有提供 Antigravity 身份,添加结束标记
|
||||
if !userHasAntigravityIdentity {
|
||||
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
StopSequences: DefaultStopSequences,
|
||||
}
|
||||
|
||||
// 如果请求中指定了 MaxTokens,使用请求值
|
||||
if req.MaxTokens > 0 {
|
||||
config.MaxOutputTokens = req.MaxTokens
|
||||
}
|
||||
|
||||
// Thinking 配置
|
||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func hasWebSearchTool(tools []ClaudeTool) bool {
|
||||
for _, tool := range tools {
|
||||
if isWebSearchTool(tool) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isWebSearchTool(tool ClaudeTool) bool {
|
||||
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
|
||||
return true
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(tool.Name)
|
||||
switch name {
|
||||
case "web_search", "google_search", "web_search_20250305":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// buildTools 构建 tools
|
||||
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否有 web_search 工具
|
||||
hasWebSearch := false
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "web_search" {
|
||||
hasWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWebSearch {
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
MaxResultCount: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
hasWebSearch := hasWebSearchTool(tools)
|
||||
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for _, tool := range tools {
|
||||
if isWebSearchTool(tool) {
|
||||
continue
|
||||
}
|
||||
// 跳过无效工具名称
|
||||
if strings.TrimSpace(tool.Name) == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
}
|
||||
|
||||
if len(funcDecls) == 0 {
|
||||
return nil
|
||||
if !hasWebSearch {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
MaxResultCount: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
return []GeminiToolDeclaration{{
|
||||
|
||||
@@ -3,6 +3,7 @@ package antigravity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
|
||||
p.processPart(&part)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
|
||||
p.processGrounding(grounding)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新剩余内容
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
|
||||
groundingText := buildGroundingText(grounding)
|
||||
if groundingText == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
p.textBuilder += groundingText
|
||||
p.flushText()
|
||||
}
|
||||
|
||||
// flushText 刷新 text builder
|
||||
func (p *NonStreamingProcessor) flushText() {
|
||||
if p.textBuilder == "" {
|
||||
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
}
|
||||
}
|
||||
|
||||
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
if grounding == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
|
||||
if len(grounding.WebSearchQueries) > 0 {
|
||||
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
|
||||
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
|
||||
}
|
||||
|
||||
if len(grounding.GroundingChunks) > 0 {
|
||||
var links []string
|
||||
for i, chunk := range grounding.GroundingChunks {
|
||||
if chunk.Web == nil {
|
||||
continue
|
||||
}
|
||||
title := strings.TrimSpace(chunk.Web.Title)
|
||||
if title == "" {
|
||||
title = "Source"
|
||||
}
|
||||
uri := strings.TrimSpace(chunk.Web.URI)
|
||||
if uri == "" {
|
||||
uri = "#"
|
||||
}
|
||||
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
|
||||
}
|
||||
|
||||
if len(links) > 0 {
|
||||
_, _ = builder.WriteString("\n\nSources:\n")
|
||||
_, _ = builder.WriteString(strings.Join(links, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
@@ -27,6 +27,8 @@ type StreamingProcessor struct {
|
||||
pendingSignature string
|
||||
trailingSignature string
|
||||
originalModel string
|
||||
webSearchQueries []string
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
|
||||
if grounding == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
|
||||
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
|
||||
}
|
||||
|
||||
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
|
||||
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
|
||||
}
|
||||
}
|
||||
|
||||
// processThinking 处理 thinking
|
||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
|
||||
groundingText := buildGroundingText(&GeminiGroundingMetadata{
|
||||
WebSearchQueries: p.webSearchQueries,
|
||||
GroundingChunks: p.groundingChunks,
|
||||
})
|
||||
if groundingText != "" {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": groundingText,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
}
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
if p.usedTool {
|
||||
|
||||
@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Profile contains TLS fingerprint configuration.
|
||||
type Profile struct {
|
||||
Name string // Profile name for identification
|
||||
CipherSuites []uint16
|
||||
Curves []uint16
|
||||
PointFormats []uint8
|
||||
EnableGREASE bool
|
||||
}
|
||||
|
||||
// Dialer creates TLS connections with custom fingerprints.
|
||||
type Dialer struct {
|
||||
profile *Profile
|
||||
baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
|
||||
// It handles the CONNECT tunnel establishment before performing TLS handshake.
|
||||
type HTTPProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
|
||||
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
|
||||
type SOCKS5ProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
|
||||
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
|
||||
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
//
|
||||
// Note: JA3/JA4 may have slight variations due to:
|
||||
// - Session ticket presence/absence
|
||||
// - Extension negotiation state
|
||||
var (
|
||||
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
|
||||
// Order is critical for JA3 fingerprint matching
|
||||
defaultCipherSuites = []uint16{
|
||||
// TLS 1.3 cipher suites (MUST be first)
|
||||
0x1302, // TLS_AES_256_GCM_SHA384
|
||||
0x1303, // TLS_CHACHA20_POLY1305_SHA256
|
||||
0x1301, // TLS_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE + AES-GCM
|
||||
0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
|
||||
0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// DHE + AES-GCM
|
||||
0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA256/384
|
||||
0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
|
||||
0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
|
||||
|
||||
// DHE-DSS/RSA + AES-GCM
|
||||
0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
|
||||
0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// ChaCha20-Poly1305
|
||||
0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
|
||||
// AES-CCM (256-bit)
|
||||
0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
|
||||
0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
|
||||
0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
|
||||
0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
|
||||
|
||||
// ARIA (256-bit)
|
||||
0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
|
||||
0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// DHE-DSS + AES-GCM (128-bit)
|
||||
0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// AES-CCM (128-bit)
|
||||
0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
|
||||
0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
|
||||
0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
|
||||
0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
|
||||
|
||||
// ARIA (128-bit)
|
||||
0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
|
||||
0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
|
||||
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
|
||||
0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
|
||||
0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
|
||||
0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA (legacy)
|
||||
0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
|
||||
0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
|
||||
0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
|
||||
0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
|
||||
0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
|
||||
0xc09d, // TLS_RSA_WITH_AES_256_CCM
|
||||
0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
|
||||
0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
|
||||
0xc09c, // TLS_RSA_WITH_AES_128_CCM
|
||||
0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// RSA + AES-CBC (non-PFS, legacy)
|
||||
0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
|
||||
0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
|
||||
0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
|
||||
// Renegotiation indication
|
||||
0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
|
||||
}
|
||||
|
||||
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
|
||||
defaultCurves = []utls.CurveID{
|
||||
utls.X25519, // 0x001d
|
||||
utls.CurveP256, // 0x0017 (secp256r1)
|
||||
utls.CurveID(0x001e), // x448
|
||||
utls.CurveP521, // 0x0019 (secp521r1)
|
||||
utls.CurveP384, // 0x0018 (secp384r1)
|
||||
utls.CurveID(0x0100), // ffdhe2048
|
||||
utls.CurveID(0x0101), // ffdhe3072
|
||||
utls.CurveID(0x0102), // ffdhe4096
|
||||
utls.CurveID(0x0103), // ffdhe6144
|
||||
utls.CurveID(0x0104), // ffdhe8192
|
||||
}
|
||||
|
||||
// defaultPointFormats contains all 3 point formats from Claude CLI
|
||||
defaultPointFormats = []uint8{
|
||||
0, // uncompressed
|
||||
1, // ansiX962_compressed_prime
|
||||
2, // ansiX962_compressed_char2
|
||||
}
|
||||
|
||||
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
|
||||
defaultSignatureAlgorithms = []utls.SignatureScheme{
|
||||
0x0403, // ecdsa_secp256r1_sha256
|
||||
0x0503, // ecdsa_secp384r1_sha384
|
||||
0x0603, // ecdsa_secp521r1_sha512
|
||||
0x0807, // ed25519
|
||||
0x0808, // ed448
|
||||
0x0809, // rsa_pss_pss_sha256
|
||||
0x080a, // rsa_pss_pss_sha384
|
||||
0x080b, // rsa_pss_pss_sha512
|
||||
0x0804, // rsa_pss_rsae_sha256
|
||||
0x0805, // rsa_pss_rsae_sha384
|
||||
0x0806, // rsa_pss_rsae_sha512
|
||||
0x0401, // rsa_pkcs1_sha256
|
||||
0x0501, // rsa_pkcs1_sha384
|
||||
0x0601, // rsa_pkcs1_sha512
|
||||
0x0303, // ecdsa_sha224
|
||||
0x0301, // rsa_pkcs1_sha224
|
||||
0x0302, // dsa_sha224
|
||||
0x0402, // dsa_sha256
|
||||
0x0502, // dsa_sha384
|
||||
0x0602, // dsa_sha512
|
||||
}
|
||||
)
|
||||
|
||||
// NewDialer creates a new TLS fingerprint dialer.
|
||||
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
|
||||
// If baseDialer is nil, direct TCP dial is used.
|
||||
func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
|
||||
if baseDialer == nil {
|
||||
baseDialer = (&net.Dialer{}).DialContext
|
||||
}
|
||||
return &Dialer{profile: profile, baseDialer: baseDialer}
|
||||
}
|
||||
|
||||
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
|
||||
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
|
||||
return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
|
||||
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
|
||||
return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
|
||||
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
|
||||
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: Create SOCKS5 dialer
|
||||
var auth *proxy.Auth
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// Determine proxy address
|
||||
proxyAddr := d.proxyURL.Host
|
||||
if d.proxyURL.Port() == "" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
|
||||
}
|
||||
|
||||
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Establish SOCKS5 tunnel to target
|
||||
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_tunnel_established")
|
||||
|
||||
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_socks5_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
|
||||
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
|
||||
func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: TCP connect to proxy server
|
||||
var proxyAddr string
|
||||
if d.proxyURL.Port() != "" {
|
||||
proxyAddr = d.proxyURL.Host
|
||||
} else {
|
||||
// Default ports
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
|
||||
} else {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
|
||||
}
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("connect to proxy: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
|
||||
|
||||
// Step 2: Send CONNECT request to establish tunnel
|
||||
req := &http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
// Add proxy authentication if present
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
req.Header.Set("Proxy-Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
|
||||
if err := req.Write(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
|
||||
return nil, fmt.Errorf("write CONNECT request: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Read CONNECT response
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
|
||||
return nil, fmt.Errorf("read CONNECT response: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
|
||||
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
|
||||
|
||||
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification (reuse the shared method)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection with the configured fingerprint.
|
||||
// This method is designed to be used as http.Transport.DialTLSContext.
|
||||
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Establish TCP connection using base dialer (supports proxy)
|
||||
slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
|
||||
conn, err := d.baseDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
|
||||
|
||||
// Extract hostname for SNI
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_sni_hostname", "host", host)
|
||||
|
||||
// Build ClientHello specification
|
||||
spec := d.buildClientHelloSpec()
|
||||
slog.Debug("tls_fingerprint_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
// Log profile info
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
} else {
|
||||
slog.Debug("tls_fingerprint_using_default_profile")
|
||||
}
|
||||
|
||||
// Create uTLS connection
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
// Apply fingerprint
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_preset_applied")
|
||||
|
||||
// Perform TLS handshake
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_handshake_failed",
|
||||
"error", err,
|
||||
"local_addr", conn.LocalAddr(),
|
||||
"remote_addr", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
|
||||
func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
|
||||
return buildClientHelloSpecFromProfile(d.profile)
|
||||
}
|
||||
|
||||
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
|
||||
func toUTLSCurves(curves []uint16) []utls.CurveID {
|
||||
result := make([]utls.CurveID, len(curves))
|
||||
for i, c := range curves {
|
||||
result[i] = utls.CurveID(c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
|
||||
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
|
||||
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
|
||||
// Get cipher suites
|
||||
var cipherSuites []uint16
|
||||
if profile != nil && len(profile.CipherSuites) > 0 {
|
||||
cipherSuites = profile.CipherSuites
|
||||
} else {
|
||||
cipherSuites = defaultCipherSuites
|
||||
}
|
||||
|
||||
// Get curves
|
||||
var curves []utls.CurveID
|
||||
if profile != nil && len(profile.Curves) > 0 {
|
||||
curves = toUTLSCurves(profile.Curves)
|
||||
} else {
|
||||
curves = defaultCurves
|
||||
}
|
||||
|
||||
// Get point formats
|
||||
var pointFormats []uint8
|
||||
if profile != nil && len(profile.PointFormats) > 0 {
|
||||
pointFormats = profile.PointFormats
|
||||
} else {
|
||||
pointFormats = defaultPointFormats
|
||||
}
|
||||
|
||||
// Check if GREASE is enabled
|
||||
enableGREASE := profile != nil && profile.EnableGREASE
|
||||
|
||||
extensions := make([]utls.TLSExtension, 0, 16)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
// SNI extension - MUST be explicitly added for HelloCustom mode
|
||||
// utls will populate the server name from Config.ServerName
|
||||
extensions = append(extensions, &utls.SNIExtension{})
|
||||
|
||||
// Claude CLI extension order (captured from tshark):
|
||||
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
|
||||
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
|
||||
// signature_algorithms(13), supported_versions(43),
|
||||
// psk_key_exchange_modes(45), key_share(51)
|
||||
extensions = append(extensions,
|
||||
&utls.SupportedPointsExtension{SupportedPoints: pointFormats},
|
||||
&utls.SupportedCurvesExtension{Curves: curves},
|
||||
&utls.SessionTicketExtension{},
|
||||
&utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
|
||||
&utls.GenericExtension{Id: 22},
|
||||
&utls.ExtendedMasterSecretExtension{},
|
||||
&utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
|
||||
&utls.SupportedVersionsExtension{Versions: []uint16{
|
||||
utls.VersionTLS13,
|
||||
utls.VersionTLS12,
|
||||
}},
|
||||
&utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
|
||||
&utls.KeyShareExtension{KeyShares: []utls.KeyShare{
|
||||
{Group: utls.X25519},
|
||||
}},
|
||||
)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
return &utls.ClientHelloSpec{
|
||||
CipherSuites: cipherSuites,
|
||||
CompressionMethods: []uint8{0}, // null compression only (standard)
|
||||
Extensions: extensions,
|
||||
TLSVersMax: utls.VersionTLS13,
|
||||
TLSVersMin: utls.VersionTLS10,
|
||||
}
|
||||
}
|
||||
307
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
307
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
@@ -0,0 +1,307 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests and should be run manually.
|
||||
//
|
||||
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
|
||||
func TestDialerBasicConnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping network test in short mode")
|
||||
}
|
||||
|
||||
// Create a dialer with default profile
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
// Create HTTP client with custom TLS dialer
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Make a request to a known HTTPS endpoint
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
func TestDialerWithProfile(t *testing.T) {
|
||||
// Create two dialers with different profiles
|
||||
profile1 := &Profile{
|
||||
Name: "Profile 1 - No GREASE",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
profile2 := &Profile{
|
||||
Name: "Profile 2 - With GREASE",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
|
||||
dialer1 := NewDialer(profile1, nil)
|
||||
dialer2 := NewDialer(profile2, nil)
|
||||
|
||||
// Build specs and compare
|
||||
// Note: We can't directly compare JA3 without making network requests
|
||||
// but we can verify the specs are different
|
||||
spec1 := dialer1.buildClientHelloSpec()
|
||||
spec2 := dialer2.buildClientHelloSpec()
|
||||
|
||||
// Profile with GREASE should have more extensions
|
||||
if len(spec2.Extensions) <= len(spec1.Extensions) {
|
||||
t.Error("expected GREASE profile to have more extensions")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestHTTPProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("http://proxy.example.com:8080")
|
||||
dialer := NewHTTPProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestSOCKS5ProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("socks5://proxy.example.com:1080")
|
||||
dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildClientHelloSpec tests ClientHello spec construction.
|
||||
func TestBuildClientHelloSpec(t *testing.T) {
|
||||
// Test with nil profile (should use defaults)
|
||||
spec := buildClientHelloSpecFromProfile(nil)
|
||||
|
||||
if len(spec.CipherSuites) == 0 {
|
||||
t.Error("expected cipher suites to be set")
|
||||
}
|
||||
if len(spec.Extensions) == 0 {
|
||||
t.Error("expected extensions to be set")
|
||||
}
|
||||
|
||||
// Verify default cipher suites are used
|
||||
if len(spec.CipherSuites) != len(defaultCipherSuites) {
|
||||
t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
|
||||
}
|
||||
|
||||
// Test with custom profile
|
||||
customProfile := &Profile{
|
||||
Name: "Custom",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{0x1301, 0x1302},
|
||||
}
|
||||
spec = buildClientHelloSpecFromProfile(customProfile)
|
||||
|
||||
if len(spec.CipherSuites) != 2 {
|
||||
t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
|
||||
}
|
||||
}
|
||||
|
||||
// TestToUTLSCurves tests curve ID conversion.
|
||||
func TestToUTLSCurves(t *testing.T) {
|
||||
input := []uint16{0x001d, 0x0017, 0x0018}
|
||||
result := toUTLSCurves(input)
|
||||
|
||||
if len(result) != len(input) {
|
||||
t.Errorf("expected %d curves, got %d", len(input), len(result))
|
||||
}
|
||||
|
||||
for i, curve := range result {
|
||||
if uint16(curve) != input[i] {
|
||||
t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to parse URL without error handling.
|
||||
func mustParseURL(rawURL string) *url.URL {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
171
backend/internal/pkg/tlsfingerprint/registry.go
Normal file
171
backend/internal/pkg/tlsfingerprint/registry.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// DefaultProfileName is the name of the built-in Claude CLI profile.
|
||||
const DefaultProfileName = "claude_cli_v2"
|
||||
|
||||
// Registry manages TLS fingerprint profiles.
|
||||
// It holds a collection of profiles that can be used for TLS fingerprint simulation.
|
||||
// Profiles are selected based on account ID using modulo operation.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
profiles map[string]*Profile
|
||||
profileNames []string // Sorted list of profile names for deterministic selection
|
||||
}
|
||||
|
||||
// NewRegistry creates a new TLS fingerprint profile registry.
|
||||
// It initializes with the built-in default profile.
|
||||
func NewRegistry() *Registry {
|
||||
r := &Registry{
|
||||
profiles: make(map[string]*Profile),
|
||||
profileNames: make([]string, 0),
|
||||
}
|
||||
|
||||
// Register the built-in default profile
|
||||
r.registerBuiltinProfile()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// NewRegistryFromConfig creates a new registry and loads profiles from config.
|
||||
// If the config has custom profiles defined, they will be merged with the built-in default.
|
||||
func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry {
|
||||
r := NewRegistry()
|
||||
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
slog.Debug("tls_registry_disabled", "reason", "disabled or no config")
|
||||
return r
|
||||
}
|
||||
|
||||
// Load custom profiles from config
|
||||
for name, profileCfg := range cfg.Profiles {
|
||||
profile := &Profile{
|
||||
Name: profileCfg.Name,
|
||||
EnableGREASE: profileCfg.EnableGREASE,
|
||||
CipherSuites: profileCfg.CipherSuites,
|
||||
Curves: profileCfg.Curves,
|
||||
PointFormats: profileCfg.PointFormats,
|
||||
}
|
||||
|
||||
// If the profile has empty values, they will use defaults in dialer
|
||||
r.RegisterProfile(name, profile)
|
||||
slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name)
|
||||
}
|
||||
|
||||
slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames)
|
||||
return r
|
||||
}
|
||||
|
||||
// registerBuiltinProfile adds the default Claude CLI profile to the registry.
|
||||
func (r *Registry) registerBuiltinProfile() {
|
||||
defaultProfile := &Profile{
|
||||
Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)",
|
||||
EnableGREASE: false, // Node.js does not use GREASE
|
||||
// Empty slices will cause dialer to use built-in defaults
|
||||
CipherSuites: nil,
|
||||
Curves: nil,
|
||||
PointFormats: nil,
|
||||
}
|
||||
r.RegisterProfile(DefaultProfileName, defaultProfile)
|
||||
}
|
||||
|
||||
// RegisterProfile adds or updates a profile in the registry.
|
||||
func (r *Registry) RegisterProfile(name string, profile *Profile) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Check if this is a new profile
|
||||
_, exists := r.profiles[name]
|
||||
r.profiles[name] = profile
|
||||
|
||||
if !exists {
|
||||
r.profileNames = append(r.profileNames, name)
|
||||
// Keep names sorted for deterministic selection
|
||||
sort.Strings(r.profileNames)
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile returns a profile by name.
|
||||
// Returns nil if the profile does not exist.
|
||||
func (r *Registry) GetProfile(name string) *Profile {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.profiles[name]
|
||||
}
|
||||
|
||||
// GetDefaultProfile returns the built-in default profile.
|
||||
func (r *Registry) GetDefaultProfile() *Profile {
|
||||
return r.GetProfile(DefaultProfileName)
|
||||
}
|
||||
|
||||
// GetProfileByAccountID returns a profile for the given account ID.
|
||||
// The profile is selected using: profileNames[accountID % len(profiles)]
|
||||
// This ensures deterministic profile assignment for each account.
|
||||
func (r *Registry) GetProfileByAccountID(accountID int64) *Profile {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if len(r.profileNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use modulo to select profile index
|
||||
// Use absolute value to handle negative IDs (though unlikely)
|
||||
idx := accountID
|
||||
if idx < 0 {
|
||||
idx = -idx
|
||||
}
|
||||
selectedIndex := int(idx % int64(len(r.profileNames)))
|
||||
selectedName := r.profileNames[selectedIndex]
|
||||
|
||||
return r.profiles[selectedName]
|
||||
}
|
||||
|
||||
// ProfileCount returns the number of registered profiles.
|
||||
func (r *Registry) ProfileCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.profiles)
|
||||
}
|
||||
|
||||
// ProfileNames returns a sorted list of all registered profile names.
|
||||
func (r *Registry) ProfileNames() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent modification
|
||||
names := make([]string, len(r.profileNames))
|
||||
copy(names, r.profileNames)
|
||||
return names
|
||||
}
|
||||
|
||||
// Global registry instance for convenience
|
||||
var globalRegistry *Registry
|
||||
var globalRegistryOnce sync.Once
|
||||
|
||||
// GlobalRegistry returns the global TLS fingerprint registry.
|
||||
// The registry is lazily initialized with the default profile.
|
||||
func GlobalRegistry() *Registry {
|
||||
globalRegistryOnce.Do(func() {
|
||||
globalRegistry = NewRegistry()
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// InitGlobalRegistry initializes the global registry with configuration.
|
||||
// This should be called during application startup.
|
||||
// It is safe to call multiple times; subsequent calls will update the registry.
|
||||
func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry {
|
||||
globalRegistryOnce.Do(func() {
|
||||
globalRegistry = NewRegistryFromConfig(cfg)
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
243
backend/internal/pkg/tlsfingerprint/registry_test.go
Normal file
243
backend/internal/pkg/tlsfingerprint/registry_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Should have exactly one profile (the default)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Should have the default profile
|
||||
profile := r.GetDefaultProfile()
|
||||
if profile == nil {
|
||||
t.Error("expected default profile to exist")
|
||||
}
|
||||
|
||||
// Default profile name should be in the list
|
||||
names := r.ProfileNames()
|
||||
if len(names) != 1 || names[0] != DefaultProfileName {
|
||||
t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProfile(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Register a new profile
|
||||
customProfile := &Profile{
|
||||
Name: "Custom Profile",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
r.RegisterProfile("custom", customProfile)
|
||||
|
||||
// Should now have 2 profiles
|
||||
if r.ProfileCount() != 2 {
|
||||
t.Errorf("expected 2 profiles, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Should be able to retrieve the custom profile
|
||||
retrieved := r.GetProfile("custom")
|
||||
if retrieved == nil {
|
||||
t.Fatal("expected custom profile to exist")
|
||||
}
|
||||
if retrieved.Name != "Custom Profile" {
|
||||
t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name)
|
||||
}
|
||||
if !retrieved.EnableGREASE {
|
||||
t.Error("expected EnableGREASE to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProfile(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Get existing profile
|
||||
profile := r.GetProfile(DefaultProfileName)
|
||||
if profile == nil {
|
||||
t.Error("expected default profile to exist")
|
||||
}
|
||||
|
||||
// Get non-existing profile
|
||||
nonExistent := r.GetProfile("nonexistent")
|
||||
if nonExistent != nil {
|
||||
t.Error("expected nil for non-existent profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProfileByAccountID(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// With only default profile, all account IDs should return the same profile
|
||||
for i := int64(0); i < 10; i++ {
|
||||
profile := r.GetProfileByAccountID(i)
|
||||
if profile == nil {
|
||||
t.Errorf("expected profile for account %d, got nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Add more profiles
|
||||
r.RegisterProfile("profile_a", &Profile{Name: "Profile A"})
|
||||
r.RegisterProfile("profile_b", &Profile{Name: "Profile B"})
|
||||
|
||||
// Now we have 3 profiles: claude_cli_v2, profile_a, profile_b
|
||||
// Names are sorted, so order is: claude_cli_v2, profile_a, profile_b
|
||||
expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"}
|
||||
names := r.ProfileNames()
|
||||
for i, name := range expectedOrder {
|
||||
if names[i] != name {
|
||||
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test modulo selection
|
||||
// Account ID 0 % 3 = 0 -> claude_cli_v2
|
||||
// Account ID 1 % 3 = 1 -> profile_a
|
||||
// Account ID 2 % 3 = 2 -> profile_b
|
||||
// Account ID 3 % 3 = 0 -> claude_cli_v2
|
||||
testCases := []struct {
|
||||
accountID int64
|
||||
expectedName string
|
||||
}{
|
||||
{0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
|
||||
{1, "Profile A"},
|
||||
{2, "Profile B"},
|
||||
{3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
|
||||
{4, "Profile A"},
|
||||
{5, "Profile B"},
|
||||
{100, "Profile A"}, // 100 % 3 = 1
|
||||
{-1, "Profile A"}, // |-1| % 3 = 1
|
||||
{-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
profile := r.GetProfileByAccountID(tc.accountID)
|
||||
if profile == nil {
|
||||
t.Errorf("expected profile for account %d, got nil", tc.accountID)
|
||||
continue
|
||||
}
|
||||
if profile.Name != tc.expectedName {
|
||||
t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegistryFromConfig(t *testing.T) {
|
||||
// Test with nil config
|
||||
r := NewRegistryFromConfig(nil)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Test with disabled config
|
||||
disabledCfg := &config.TLSFingerprintConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
r = NewRegistryFromConfig(disabledCfg)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Test with enabled config and custom profiles
|
||||
enabledCfg := &config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
Profiles: map[string]config.TLSProfileConfig{
|
||||
"custom1": {
|
||||
Name: "Custom Profile 1",
|
||||
EnableGREASE: true,
|
||||
},
|
||||
"custom2": {
|
||||
Name: "Custom Profile 2",
|
||||
EnableGREASE: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
r = NewRegistryFromConfig(enabledCfg)
|
||||
|
||||
// Should have 3 profiles: default + 2 custom
|
||||
if r.ProfileCount() != 3 {
|
||||
t.Errorf("expected 3 profiles, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Check custom profiles exist
|
||||
custom1 := r.GetProfile("custom1")
|
||||
if custom1 == nil || custom1.Name != "Custom Profile 1" {
|
||||
t.Error("expected custom1 profile to exist with correct name")
|
||||
}
|
||||
custom2 := r.GetProfile("custom2")
|
||||
if custom2 == nil || custom2.Name != "Custom Profile 2" {
|
||||
t.Error("expected custom2 profile to exist with correct name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfileNames(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Add profiles in non-alphabetical order
|
||||
r.RegisterProfile("zebra", &Profile{Name: "Zebra"})
|
||||
r.RegisterProfile("alpha", &Profile{Name: "Alpha"})
|
||||
r.RegisterProfile("beta", &Profile{Name: "Beta"})
|
||||
|
||||
names := r.ProfileNames()
|
||||
|
||||
// Should be sorted alphabetically
|
||||
expected := []string{"alpha", "beta", DefaultProfileName, "zebra"}
|
||||
if len(names) != len(expected) {
|
||||
t.Errorf("expected %d names, got %d", len(expected), len(names))
|
||||
}
|
||||
for i, name := range expected {
|
||||
if names[i] != name {
|
||||
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test that returned slice is a copy (modifying it shouldn't affect registry)
|
||||
names[0] = "modified"
|
||||
originalNames := r.ProfileNames()
|
||||
if originalNames[0] == "modified" {
|
||||
t.Error("modifying returned slice should not affect registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Run concurrent reads and writes
|
||||
done := make(chan bool)
|
||||
|
||||
// Writers
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < 100; j++ {
|
||||
r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"})
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Readers
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = r.ProfileCount()
|
||||
_ = r.ProfileNames()
|
||||
_ = r.GetProfileByAccountID(int64(id * j))
|
||||
_ = r.GetProfile(DefaultProfileName)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Test should pass without data races (run with -race flag)
|
||||
}
|
||||
@@ -575,6 +575,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetStatus(service.StatusActive).
|
||||
SetErrorMessage("").
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
_, err := r.client.AccountGroup.Create().
|
||||
SetAccountID(accountID).
|
||||
@@ -993,7 +1002,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
builder.SetSessionWindowEnd(*end)
|
||||
}
|
||||
_, err := builder.Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -12,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
authCacheInvalidateChannel = "auth:cache:invalidate"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
||||
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
||||
}
|
||||
|
||||
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
||||
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
||||
|
||||
// Verify subscription is working
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := pubsub.Close(); err != nil {
|
||||
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msg != nil {
|
||||
handler(msg.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
// 如果未启用 TLS 指纹,直接使用标准请求路径
|
||||
if !enableTLSFingerprint {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TLS 指纹已启用,记录调试日志
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
targetHost = req.URL.Host
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 TLS 指纹 Profile
|
||||
registry := tlsfingerprint.GlobalRegistry()
|
||||
profile := registry.GetProfileByAccountID(accountID)
|
||||
if profile == nil {
|
||||
// 如果获取不到 profile,回退到普通请求
|
||||
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
|
||||
|
||||
// 获取或创建带 TLS 指纹的客户端
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
slog.Debug("tls_fingerprint_evicting_stale_client",
|
||||
"account_id", accountID,
|
||||
"cache_key", cacheKey,
|
||||
"proxy_changed", entry.proxyKey != proxyKey,
|
||||
"pool_changed", entry.poolKey != poolKey)
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
|
||||
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
// - profile: TLS 指纹配置
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 配置错误
|
||||
//
|
||||
// 代理类型处理:
|
||||
// - nil/空: 直连,使用 TLSFingerprintDialer
|
||||
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
|
||||
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
|
||||
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
|
||||
// 根据代理类型选择合适的 TLS 指纹 Dialer
|
||||
if proxyURL == nil {
|
||||
// 直连:使用 TLSFingerprintDialer
|
||||
slog.Debug("tls_fingerprint_transport_direct")
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
transport.DialTLSContext = dialer.DialTLSContext
|
||||
} else {
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
// SOCKS5 代理:使用 SOCKS5ProxyDialer
|
||||
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = socks5Dialer.DialTLSContext
|
||||
case "http", "https":
|
||||
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
|
||||
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = httpDialer.DialTLSContext
|
||||
default:
|
||||
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
|
||||
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
|
||||
@@ -11,8 +11,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
maskedSessionKeyPrefix = "masked_session:"
|
||||
maskedSessionTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// fingerprintKey generates the Redis key for account fingerprint cache.
|
||||
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// maskedSessionKey generates the Redis key for masked session ID cache.
|
||||
func maskedSessionKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
|
||||
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
|
||||
key := maskedSessionKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
|
||||
key := maskedSessionKey(accountID)
|
||||
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]int), nil
|
||||
}
|
||||
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
|
||||
|
||||
// 使用 pipeline 批量执行
|
||||
pipe := c.rdb.Pipeline()
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
key := sessionLimitKey(accountID)
|
||||
// 使用各账号自己的 idleTimeout,如果没有则用默认值
|
||||
idleTimeout := c.defaultIdleTimeout
|
||||
if idleTimeouts != nil {
|
||||
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
|
||||
idleTimeout = t
|
||||
}
|
||||
}
|
||||
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
||||
}
|
||||
|
||||
|
||||
@@ -618,6 +618,14 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGroupRepo struct{}
|
||||
|
||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||
@@ -736,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -576,6 +576,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
|
||||
func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
// 仅支持 Anthropic OAuth/SetupToken 账号
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
// 使上游认为请求来自同一个会话
|
||||
func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
@@ -652,6 +690,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
|
||||
// 逻辑:
|
||||
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
|
||||
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
|
||||
func (a *Account) GetCurrentWindowStartTime() time.Time {
|
||||
now := time.Now()
|
||||
|
||||
// 窗口未过期,使用记录的窗口开始时间
|
||||
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
|
||||
return *a.SessionWindowStart
|
||||
}
|
||||
|
||||
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
|
||||
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
|
||||
@@ -37,6 +37,7 @@ type AccountRepository interface {
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
ClearError(ctx context.Context, id int64) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
|
||||
panic("unexpected SetError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -369,12 +369,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
|
||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return s.accountRepo.SetError(ctx, id, errorMsg)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
// 兜底:随机生成 project_id
|
||||
if result.ProjectID == "" {
|
||||
result.ProjectID = antigravity.GenerateMockProjectID()
|
||||
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 project_id 和 email
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
// 保留原有的 email
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
|
||||
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubAntigravityUpstream struct {
|
||||
firstBase string
|
||||
secondBase string
|
||||
calls []string
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
url := req.URL.String()
|
||||
s.calls = append(s.calls, url)
|
||||
if strings.HasPrefix(url, s.firstBase) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
|
||||
}, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
type scopeLimitCall struct {
|
||||
accountID int64
|
||||
scope AntigravityQuotaScope
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type rateLimitCall struct {
|
||||
accountID int64
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
scopeCalls []scopeLimitCall
|
||||
rateCalls []rateLimitCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
base1 := "https://ag-1.test"
|
||||
base2 := "https://ag-2.test"
|
||||
antigravity.BaseURLs = []string{base1, base2}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCalled bool
|
||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
handleErrorCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.False(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
require.NotEmpty(t, available)
|
||||
require.Equal(t, base2, available[0])
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "true")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("3s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
call := repo.scopeCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "false")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("2s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.rateCalls, 1)
|
||||
require.Empty(t, repo.scopeCalls)
|
||||
call := repo.rateCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(10 * time.Minute)
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
account.RateLimitResetAt = &future
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
|
||||
account.RateLimitResetAt = nil
|
||||
account.Extra = map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
}
|
||||
|
||||
func buildGeminiRateLimitBody(delay string) []byte {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
|
||||
}
|
||||
@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
|
||||
// This should be called after the service is fully initialized.
|
||||
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
|
||||
if s.cache == nil || s.authCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}); err != nil {
|
||||
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
||||
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
// Publish invalidation message to other instances
|
||||
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
|
||||
@@ -65,6 +65,10 @@ type APIKeyCache interface {
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
|
||||
// Pub/Sub for L1 cache invalidation across instances
|
||||
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
|
||||
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
|
||||
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
|
||||
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
|
||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
|
||||
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -445,11 +447,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
excludedIDsList = append(excludedIDsList, id)
|
||||
}
|
||||
slog.Debug("account_scheduling_starting",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"excluded_ids", excludedIDsList)
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
// 提取会话 UUID(用于会话数量限制)
|
||||
sessionUUID := extractSessionUUID(metadataUserID)
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
@@ -475,41 +486,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 复制排除列表,用于会话限制拒绝时的重试
|
||||
localExcluded := make(map[int64]struct{})
|
||||
for k, v := range excludedIDs {
|
||||
localExcluded[k] = v
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
|
||||
for {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
localExcluded[account.ID] = struct{}{} // 排除此账号
|
||||
continue // 重新选择
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 对于等待计划的情况,也需要先检查会话限制
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
localExcluded[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
@@ -625,7 +658,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
@@ -643,15 +676,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
@@ -714,7 +752,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -732,20 +770,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
|
||||
acc := routingAvailable[0].account
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
|
||||
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
|
||||
// 遍历找到第一个满足会话限制的账号
|
||||
for _, item := range routingAvailable {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: item.account.ID,
|
||||
MaxConcurrency: item.account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
@@ -773,7 +817,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
@@ -787,15 +831,22 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -845,7 +896,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -895,7 +946,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -913,8 +964,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
|
||||
for _, acc := range candidates {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个账号
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
@@ -928,7 +983,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -936,7 +991,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -1093,7 +1148,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err == nil {
|
||||
slog.Debug("account_scheduling_list_snapshot",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"use_mixed", useMixed,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
}
|
||||
return accounts, useMixed, err
|
||||
}
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
@@ -1106,6 +1178,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
@@ -1115,6 +1191,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
slog.Debug("account_scheduling_list_mixed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"raw_count", len(accounts),
|
||||
"filtered_count", len(filtered))
|
||||
for _, acc := range filtered {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1129,8 +1219,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
slog.Debug("account_scheduling_list_single",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1196,12 +1303,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
{
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -1234,15 +1337,16 @@ checkSchedulability:
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
maxSessions := account.GetMaxSessions()
|
||||
if maxSessions <= 0 || sessionUUID == "" {
|
||||
if maxSessions <= 0 || sessionID == "" {
|
||||
return true // 未启用会话限制或无会话ID
|
||||
}
|
||||
|
||||
@@ -1252,7 +1356,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
|
||||
if err != nil {
|
||||
// 失败开放:缓存错误时允许通过
|
||||
return true
|
||||
@@ -1260,18 +1364,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
return allowed
|
||||
}
|
||||
|
||||
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
|
||||
// 格式: user_{64位hex}_account__session_{uuid}
|
||||
func extractSessionUUID(metadataUserID string) string {
|
||||
if metadataUserID == "" {
|
||||
return ""
|
||||
}
|
||||
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
@@ -1301,6 +1393,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
})
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
|
||||
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
|
||||
if mode == "random" {
|
||||
// 先按优先级排序,然后在同优先级内随机打乱
|
||||
sortAccountsByPriorityOnly(accounts, preferOAuth)
|
||||
shuffleWithinPriority(accounts)
|
||||
} else {
|
||||
// 默认按最后使用时间排序
|
||||
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// sortAccountsByPriorityOnly 仅按优先级排序
|
||||
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
if a.Priority != b.Priority {
|
||||
return a.Priority < b.Priority
|
||||
}
|
||||
if preferOAuth && a.Type != b.Type {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// shuffleWithinPriority 在同优先级内随机打乱顺序
|
||||
func shuffleWithinPriority(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
start := 0
|
||||
for start < len(accounts) {
|
||||
priority := accounts[start].Priority
|
||||
end := start + 1
|
||||
for end < len(accounts) && accounts[end].Priority == priority {
|
||||
end++
|
||||
}
|
||||
// 对 [start, end) 范围内的账户随机打乱
|
||||
if end-start > 1 {
|
||||
r.Shuffle(end-start, func(i, j int) {
|
||||
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
|
||||
})
|
||||
}
|
||||
start = end
|
||||
}
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
@@ -2158,6 +2300,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
@@ -2172,7 +2318,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -2246,7 +2392,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
@@ -2278,7 +2424,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
resp = retryResp2
|
||||
break
|
||||
@@ -2393,6 +2539,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印重试耗尽后的错误响应
|
||||
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2420,6 +2570,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2549,9 +2703,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
fingerprint = fp
|
||||
|
||||
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
@@ -2770,6 +2925,10 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
@@ -3478,7 +3637,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
@@ -3500,7 +3659,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
@@ -3578,12 +3737,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil {
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import "net/http"
|
||||
// - 支持可选代理配置
|
||||
// - 支持账户级连接池隔离
|
||||
// - 实现类负责连接池管理和复用
|
||||
// - 支持可选的 TLS 指纹伪装
|
||||
type HTTPUpstream interface {
|
||||
// Do 执行 HTTP 请求
|
||||
//
|
||||
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - 响应体可能已被包装以跟踪请求生命周期
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象,由调用方构建
|
||||
// - proxyURL: 代理服务器地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应,调用方必须关闭 Body
|
||||
// - error: 请求错误(网络错误、超时等)
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
|
||||
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -49,6 +51,13 @@ type Fingerprint struct {
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
|
||||
// 返回的 sessionID 是一个 UUID 格式的字符串
|
||||
// 如果不存在或已过期(15分钟无请求),返回空字符串
|
||||
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
|
||||
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
|
||||
// 每次调用都会刷新 TTL
|
||||
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
||||
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
// 先执行常规的 RewriteUserID 逻辑
|
||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
||||
if err != nil {
|
||||
return newBody, err
|
||||
}
|
||||
|
||||
// 检查是否启用会话ID伪装
|
||||
if !account.IsSessionIDMaskingEnabled() {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 解析重写后的 body,提取 user_id
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 查找 _session_ 的位置,替换其后的内容
|
||||
const sessionMarker = "_session_"
|
||||
idx := strings.LastIndex(userID, sessionMarker)
|
||||
if idx == -1 {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 获取或生成固定的伪装 session ID
|
||||
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
if maskedSessionID == "" {
|
||||
// 首次或已过期,生成新的伪装 session ID
|
||||
maskedSessionID = generateRandomUUID()
|
||||
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
|
||||
}
|
||||
|
||||
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
|
||||
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
|
||||
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
||||
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
|
||||
|
||||
slog.Debug("session_id_masking_applied",
|
||||
"account_id", account.ID,
|
||||
"before", userID,
|
||||
"after", newUserID,
|
||||
)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||
func generateRandomUUID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// fallback: 使用时间戳生成
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
b = h[:16]
|
||||
}
|
||||
|
||||
// 设置 UUID v4 版本和变体位
|
||||
b[6] = (b[6] & 0x0f) | 0x40
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
|
||||
func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
|
||||
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := false
|
||||
// 先尝试临时不可调度规则(401除外)
|
||||
// 如果匹配成功,直接返回,不执行后续禁用逻辑
|
||||
if statusCode != 401 {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 400:
|
||||
// 只有当错误信息包含 "organization has been disabled" 时才禁用
|
||||
if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
|
||||
msg := "Organization disabled (400): " + upstreamMsg
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return true
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
|
||||
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
account.Credentials = newCredentials
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
@@ -196,6 +197,8 @@ func ProvideOpsScheduledReportService(
|
||||
|
||||
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
|
||||
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
|
||||
// Start Pub/Sub subscriber for L1 cache invalidation across instances
|
||||
apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user