Merge pull request #325 from slovx2/main

fix(antigravity): 修复Antigravity 频繁429的问题,以及一系列优化,配置增强
This commit is contained in:
Wesley Liddick
2026-01-19 09:17:15 +08:00
committed by GitHub
28 changed files with 1117 additions and 585 deletions

View File

@@ -257,6 +257,14 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义 // 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"` FailoverOn400 bool `mapstructure:"failover_on_400"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
// Scheduling: 账号调度相关配置 // Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
@@ -298,6 +306,9 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算 // 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
@@ -786,6 +797,9 @@ func setDefaults() {
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false) viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -798,11 +812,12 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)

View File

@@ -541,6 +541,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v newCredentials[k] = v
} }
} }
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
if tokenInfo.ProjectIDMissing {
// 先更新凭证
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
})
return
}
// 成功获取到 project_id如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
return
}
}
} else { } else {
// Use Anthropic/Claude OAuth service to refresh token // Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)

View File

@@ -31,6 +31,8 @@ type GatewayHandler struct {
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
} }
// NewGatewayHandler creates a new GatewayHandler // NewGatewayHandler creates a new GatewayHandler
@@ -44,8 +46,16 @@ func NewGatewayHandler(
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 10
maxAccountSwitchesGemini := 3
if cfg != nil { if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
}
} }
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
@@ -54,6 +64,8 @@ func NewGatewayHandler(
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
} }
} }
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
} }
const maxAccountSwitches = 10 maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0

View File

@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionHash != "" { if sessionHash != "" {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0

View File

@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
} }
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil { if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
} }
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
} }
} }
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key) // Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0

View File

@@ -16,15 +16,6 @@ import (
"time" "time"
) )
// resolveHost 从 URL 解析 host
func resolveHost(urlStr string) string {
parsed, err := url.Parse(urlStr)
if err != nil {
return ""
}
return parsed.Host
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数 // 构建 URL流式请求添加 ?alt=sse 参数
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return nil, err return nil, err
} }
// 基础 Headers // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
// Accept Header 根据请求类型设置
if isStream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
// 显式设置 Host Header
if host := resolveHost(apiURL); host != "" {
req.Host = host
}
return req, nil return req, nil
} }
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
} }
// shouldFallbackToNextURL 判断是否应切换到下一个 URL // shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级 // 与 Antigravity-Manager 保持一致连接错误、429、408、404、5xx 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool { func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) { if isConnectionError(err) {
return true return true
} }
return statusCode == http.StatusTooManyRequests return statusCode == http.StatusTooManyRequests ||
statusCode == http.StatusRequestTimeout ||
statusCode == http.StatusNotFound ||
statusCode >= 500
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
// 获取可用的 URL 列表 // 固定顺序prod -> daily
availableURLs := DefaultURLAvailability.GetAvailableURLs() availableURLs := BaseURLs
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if err != nil { if err != nil {
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue continue
} }
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级 // 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue continue
} }
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var rawResp map[string]any var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &loadResp, rawResp, nil return &loadResp, rawResp, nil
} }
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
// 获取可用的 URL 列表 // 固定顺序prod -> daily
availableURLs := DefaultURLAvailability.GetAvailableURLs() availableURLs := BaseURLs
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if err != nil { if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue continue
} }
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级 // 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue continue
} }
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var rawResp map[string]any var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &modelsResp, rawResp, nil return &modelsResp, rawResp, nil
} }

View File

@@ -143,9 +143,10 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应 // GeminiCandidate Gemini 候选响应
type GeminiCandidate struct { type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"` Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"` FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
} }
// GeminiGroundingMetadata Gemini grounding 元数据Web Search
type GeminiGroundingMetadata struct {
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
}
// GeminiGroundingChunk Gemini grounding chunk
type GeminiGroundingChunk struct {
Web *GeminiGroundingWeb `json:"web,omitempty"`
}
// GeminiGroundingWeb Gemini grounding web 信息
type GeminiGroundingWeb struct {
Title string `json:"title,omitempty"`
URI string `json:"uri,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤) // DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{ var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"}, {Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},

View File

@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs" "https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent模拟官方客户端 // User-Agent与 Antigravity-Manager 保持一致
UserAgent = "antigravity/1.104.0 darwin/arm64" UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间 // Session 过期时间
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
@@ -42,22 +42,21 @@ const (
URLAvailabilityTTL = 5 * time.Minute URLAvailabilityTTL = 5 * time.Minute
) )
// BaseURLs 定义 Antigravity API 端点,按优先级排序 // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
// fallback 顺序: sandbox → daily → prod
var BaseURLs = []string{ var BaseURLs = []string{
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox "https://cloudcode-pa.googleapis.com", // prod (优先)
"https://daily-cloudcode-pa.googleapis.com", // daily "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
"https://cloudcode-pa.googleapis.com", // prod
} }
// BaseURL 默认 URL保持向后兼容 // BaseURL 默认 URL保持向后兼容
var BaseURL = BaseURLs[0] var BaseURL = BaseURLs[0]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级
type URLAvailability struct { type URLAvailability struct {
mu sync.RWMutex mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间 unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration ttl time.Duration
lastSuccess string // 最近成功请求的 URL优先使用
} }
// DefaultURLAvailability 全局 URL 可用性管理器 // DefaultURLAvailability 全局 URL 可用性管理器
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u.unavailable[url] = time.Now().Add(u.ttl) u.unavailable[url] = time.Now().Add(u.ttl)
} }
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
func (u *URLAvailability) MarkSuccess(url string) {
u.mu.Lock()
defer u.mu.Unlock()
u.lastSuccess = url
// 成功后清除该 URL 的不可用标记
delete(u.unavailable, url)
}
// IsAvailable 检查 URL 是否可用 // IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool { func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock() u.mu.RLock()
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return time.Now().After(expiry) return time.Now().After(expiry)
} }
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) // GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string { func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock() u.mu.RLock()
defer u.mu.RUnlock() defer u.mu.RUnlock()
now := time.Now() now := time.Now()
result := make([]string, 0, len(BaseURLs)) result := make([]string, 0, len(BaseURLs))
// 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
}
}
// 添加其他可用的 URL按默认顺序
for _, url := range BaseURLs { for _, url := range BaseURLs {
// 跳过已添加的 lastSuccess
if url == u.lastSuccess {
continue
}
expiry, exists := u.unavailable[url] expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) { if !exists || now.After(expiry) {
result = append(result, url) result = append(result, url)
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
} }
// GenerateMockProjectID 生成随机 project_id当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符a-z0-9
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}

View File

@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
} }
} }
// webSearchFallbackModel web_search 请求使用的降级模型
const webSearchFallbackModel = "gemini-2.5-flash"
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions()) return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否有 web_search 工具
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
requestType := "agent"
targetModel := mappedModel
if hasWebSearchTool {
requestType = "web_search"
if targetModel != webSearchFallbackModel {
targetModel = webSearchFallbackModel
}
}
// 检测是否启用 thinking // 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
// 1. 构建 contents // 1. 构建 contents
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
} }
// 2. 构建 systemInstruction // 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
// 3. 构建 generationConfig // 3. 构建 generationConfig
reqForConfig := claudeReq reqForConfig := claudeReq
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy.Thinking = nil reqCopy.Thinking = nil
reqForConfig = &reqCopy reqForConfig = &reqCopy
} }
if targetModel != "" && targetModel != reqForConfig.Model {
reqCopy := *reqForConfig
reqCopy.Model = targetModel
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig) generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools // 4. 构建 tools
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project: projectID, Project: projectID,
RequestID: "agent-" + uuid.New().String(), RequestID: "agent-" + uuid.New().String(),
UserAgent: "antigravity", // 固定值,与官方客户端一致 UserAgent: "antigravity", // 固定值,与官方客户端一致
RequestType: "agent", RequestType: requestType,
Model: mappedModel, Model: targetModel,
Request: innerRequest, Request: innerRequest,
} }
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
return antigravityIdentity return antigravityIdentity
} }
// buildSystemInstruction 构建 systemInstruction // mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent { const mcpXMLProtocol = `
==== MCP XML 工具调用协议 (Workaround) ====
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
===========================================`
// hasMCPTools 检测是否有 mcp__ 前缀的工具
func hasMCPTools(tools []ClaudeTool) bool {
for _, tool := range tools {
if strings.HasPrefix(tool.Name, "mcp__") {
return true
}
}
return false
}
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
func filterOpenCodePrompt(text string) string {
if !strings.Contains(text, "You are an interactive CLI tool") {
return text
}
// 提取 "Instructions from:" 及之后的部分
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
return text[idx:]
}
// 如果没有自定义指令,返回空
return ""
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart var parts []GeminiPart
// 先解析用户的 system prompt检测是否已包含 Antigravity identity // 先解析用户的 system prompt检测是否已包含 Antigravity identity
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var sysStr string var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil { if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" { if strings.TrimSpace(sysStr) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") { if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
} }
} else { } else {
// 尝试解析为数组 // 尝试解析为数组
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil { if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks { for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" { if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") { if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
} }
} }
} }
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt // 添加用户的 system prompt
parts = append(parts, userSystemParts...) parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
}
// 如果用户没有提供 Antigravity 身份,添加结束标记
if !userHasAntigravityIdentity {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 { if len(parts) == 0 {
return nil return nil
} }
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences: DefaultStopSequences, StopSequences: DefaultStopSequences,
} }
// 如果请求中指定了 MaxTokens使用请求值
if req.MaxTokens > 0 {
config.MaxOutputTokens = req.MaxTokens
}
// Thinking 配置 // Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" { if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{ config.ThinkingConfig = &GeminiThinkingConfig{
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return config return config
} }
func hasWebSearchTool(tools []ClaudeTool) bool {
for _, tool := range tools {
if isWebSearchTool(tool) {
return true
}
}
return false
}
func isWebSearchTool(tool ClaudeTool) bool {
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
return true
}
name := strings.TrimSpace(tool.Name)
switch name {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}
// buildTools 构建 tools // buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 { if len(tools) == 0 {
return nil return nil
} }
// 检查是否有 web_search 工具 hasWebSearch := hasWebSearchTool(tools)
hasWebSearch := false
for _, tool := range tools {
if tool.Name == "web_search" {
hasWebSearch = true
break
}
}
if hasWebSearch {
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for _, tool := range tools { for _, tool := range tools {
if isWebSearchTool(tool) {
continue
}
// 跳过无效工具名称 // 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" { if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
if len(funcDecls) == 0 { if len(funcDecls) == 0 {
return nil if !hasWebSearch {
return nil
}
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
} }
return []GeminiToolDeclaration{{ return []GeminiToolDeclaration{{

View File

@@ -3,6 +3,7 @@ package antigravity
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
) )
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p.processPart(&part) p.processPart(&part)
} }
if len(geminiResp.Candidates) > 0 {
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
p.processGrounding(grounding)
}
}
// 刷新剩余内容 // 刷新剩余内容
p.flushThinking() p.flushThinking()
p.flushText() p.flushText()
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
} }
} }
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
groundingText := buildGroundingText(grounding)
if groundingText == "" {
return
}
p.flushThinking()
p.flushText()
p.textBuilder += groundingText
p.flushText()
}
// flushText 刷新 text builder // flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() { func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" { if p.textBuilder == "" {
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
} }
} }
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
if grounding == nil {
return ""
}
var builder strings.Builder
if len(grounding.WebSearchQueries) > 0 {
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
}
if len(grounding.GroundingChunks) > 0 {
var links []string
for i, chunk := range grounding.GroundingChunks {
if chunk.Web == nil {
continue
}
title := strings.TrimSpace(chunk.Web.Title)
if title == "" {
title = "Source"
}
uri := strings.TrimSpace(chunk.Web.URI)
if uri == "" {
uri = "#"
}
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
}
if len(links) > 0 {
_, _ = builder.WriteString("\n\nSources:\n")
_, _ = builder.WriteString(strings.Join(links, "\n"))
}
}
return builder.String()
}
// generateRandomID 生成随机 ID // generateRandomID 生成随机 ID
func generateRandomID() string { func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

View File

@@ -27,6 +27,8 @@ type StreamingProcessor struct {
pendingSignature string pendingSignature string
trailingSignature string trailingSignature string
originalModel string originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
// 累计 usage // 累计 usage
inputTokens int inputTokens int
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
} }
} }
if len(geminiResp.Candidates) > 0 {
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
}
// 检查是否结束 // 检查是否结束
if len(geminiResp.Candidates) > 0 { if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason finishReason := geminiResp.Candidates[0].FinishReason
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return result.Bytes() return result.Bytes()
} }
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
if grounding == nil {
return
}
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
}
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
}
}
// processThinking 处理 thinking // processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte { func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer var result bytes.Buffer
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p.trailingSignature = "" p.trailingSignature = ""
} }
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
groundingText := buildGroundingText(&GeminiGroundingMetadata{
WebSearchQueries: p.webSearchQueries,
GroundingChunks: p.groundingChunks,
})
if groundingText != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": groundingText,
}))
_, _ = result.Write(p.endBlock())
}
}
// 确定 stop_reason // 确定 stop_reason
stopReason := "end_turn" stopReason := "end_turn"
if p.usedTool { if p.usedTool {

View File

@@ -543,6 +543,15 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return nil return nil
} }
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
_, err := r.client.AccountGroup.Create(). _, err := r.client.AccountGroup.Create().
SetAccountID(accountID). SetAccountID(accountID).

View File

@@ -744,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
return errors.New("not implemented") return errors.New("not implemented")
} }
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }

View File

@@ -37,6 +37,7 @@ type AccountRepository interface {
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error

View File

@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
panic("unexpected SetError call") panic("unexpected SetError call")
} }
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
panic("unexpected ClearError call")
}
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call") panic("unexpected SetSchedulable call")
} }

View File

@@ -42,6 +42,7 @@ type AdminService interface {
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
return account, nil return account, nil
} }
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return s.accountRepo.SetError(ctx, id, errorMsg)
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err

File diff suppressed because it is too large Load Diff

View File

@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
// AntigravityTokenInfo token 信息 // AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct { type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.ProjectID = loadResp.CloudAICompanionProject result.ProjectID = loadResp.CloudAICompanionProject
} }
// 兜底:随机生成 project_id
if result.ProjectID == "" {
result.ProjectID = antigravity.GenerateMockProjectID()
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
}
return result, nil return result, nil
} }
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, err return nil, err
} }
// 保留原有的 project_id 和 email // 保留原有的 email
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
existingEmail := strings.TrimSpace(account.GetCredential("email")) existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" { if existingEmail != "" {
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
}
return tokenInfo, nil return tokenInfo, nil
} }

View File

@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id") projectID := account.GetCredential("project_id")
// 如果没有 project_id生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL) client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额 // 调用 API 获取配额

View File

@@ -0,0 +1,186 @@
//go:build unit
package service
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
type stubAntigravityUpstream struct {
firstBase string
secondBase string
calls []string
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String()
s.calls = append(s.calls, url)
if strings.HasPrefix(url, s.firstBase) {
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
}
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
base1 := "https://ag-1.test"
base2 := "https://ag-2.test"
antigravity.BaseURLs = []string{base1, base2}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.False(t, handleErrorCalled)
require.Len(t, upstream.calls, 2)
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available)
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "false")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("2s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.rateCalls, 1)
require.Empty(t, repo.scopeCalls)
call := repo.rateCalls[0]
require.Equal(t, account.ID, call.accountID)
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
}
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute)
account := &Account{
ID: 1,
Name: "acc",
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
}
account.RateLimitResetAt = &future
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
func buildGeminiRateLimitBody(delay string) []byte {
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
}

View File

@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
} }
} }
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id可能无法使用Antigravity")
}
return newCredentials, nil return newCredentials, nil
} }

View File

@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }

View File

@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
mathrand "math/rand"
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
@@ -918,7 +919,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
// ============ Layer 3: 兜底排队 ============ // ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates { for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额) // 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
@@ -1318,6 +1319,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
}) })
} }
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
if mode == "random" {
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly(accounts, preferOAuth)
shuffleWithinPriority(accounts)
} else {
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func shuffleWithinPriority(accounts []*Account) {
if len(accounts) <= 1 {
return
}
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
start := 0
for start < len(accounts) {
priority := accounts[start].Priority
end := start + 1
for end < len(accounts) && accounts[end].Priority == priority {
end++
}
// 对 [start, end) 范围内的账户随机打乱
if end-start > 1 {
r.Shuffle(end-start, func(i, j int) {
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
})
}
start = end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini

View File

@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }

View File

@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account) newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功更新账号credentials // 如果有新凭证,先更新(即使有错误也要保存 token
if newCredentials != nil {
account.Credentials = newCredentials account.Credentials = newCredentials
if err := s.accountRepo.Update(ctx, account); err != nil { if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", err) return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}
if err == nil {
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error现在成功获取到了清除错误状态
if account.Platform == PlatformAntigravity &&
account.Status == StatusError &&
strings.Contains(account.ErrorMessage, "missing_project_id:") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
} else {
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
}
} }
// 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理) // 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
"invalid_client", // 客户端配置错误 "invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权 "unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝 "access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
} }
for _, needle := range nonRetryable { for _, needle := range nonRetryable {
if strings.Contains(msg, needle) { if strings.Contains(msg, needle) {

View File

@@ -21,8 +21,20 @@
</div> </div>
</div> </div>
<!-- Right: Language + Subscriptions + Balance + User Dropdown --> <!-- Right: Docs + Language + Subscriptions + Balance + User Dropdown -->
<div class="flex items-center gap-3"> <div class="flex items-center gap-3">
<!-- Docs Link -->
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="flex items-center gap-1.5 rounded-lg px-2.5 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 hover:text-gray-900 dark:text-dark-400 dark:hover:bg-dark-800 dark:hover:text-white"
>
<Icon name="book" size="sm" />
<span class="hidden sm:inline">{{ t('nav.docs') }}</span>
</a>
<!-- Language Switcher --> <!-- Language Switcher -->
<LocaleSwitcher /> <LocaleSwitcher />
@@ -211,6 +223,7 @@ const user = computed(() => authStore.user)
const dropdownOpen = ref(false) const dropdownOpen = ref(false)
const dropdownRef = ref<HTMLElement | null>(null) const dropdownRef = ref<HTMLElement | null>(null)
const contactInfo = computed(() => appStore.contactInfo) const contactInfo = computed(() => appStore.contactInfo)
const docUrl = computed(() => appStore.docUrl)
// 只在标准模式的管理员下显示新手引导按钮 // 只在标准模式的管理员下显示新手引导按钮
const showOnboardingButton = computed(() => { const showOnboardingButton = computed(() => {

View File

@@ -196,7 +196,8 @@ export default {
expand: 'Expand', expand: 'Expand',
logout: 'Logout', logout: 'Logout',
github: 'GitHub', github: 'GitHub',
mySubscriptions: 'My Subscriptions' mySubscriptions: 'My Subscriptions',
docs: 'Docs'
}, },
// Auth // Auth

View File

@@ -193,7 +193,8 @@ export default {
expand: '展开', expand: '展开',
logout: '退出登录', logout: '退出登录',
github: 'GitHub', github: 'GitHub',
mySubscriptions: '我的订阅' mySubscriptions: '我的订阅',
docs: '文档'
}, },
// Auth // Auth