Merge pull request #325 from slovx2/main
fix(antigravity): 修复Antigravity 频繁429的问题,以及一系列优化,配置增强
This commit is contained in:
@@ -257,6 +257,14 @@ type GatewayConfig struct {
|
|||||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||||
|
|
||||||
|
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||||
|
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||||
|
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||||
|
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
|
||||||
|
|
||||||
|
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
|
||||||
|
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
|
||||||
|
|
||||||
// Scheduling: 账号调度相关配置
|
// Scheduling: 账号调度相关配置
|
||||||
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||||
|
|
||||||
@@ -298,6 +306,9 @@ type GatewaySchedulingConfig struct {
|
|||||||
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
|
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
|
||||||
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
|
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
|
||||||
|
|
||||||
|
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
|
||||||
|
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||||
|
|
||||||
// 负载计算
|
// 负载计算
|
||||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||||
|
|
||||||
@@ -786,6 +797,9 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||||
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||||
viper.SetDefault("gateway.failover_on_400", false)
|
viper.SetDefault("gateway.failover_on_400", false)
|
||||||
|
viper.SetDefault("gateway.max_account_switches", 10)
|
||||||
|
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||||
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||||
@@ -798,11 +812,12 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||||
|
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
||||||
|
|||||||
@@ -541,6 +541,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||||
|
if tokenInfo.ProjectIDMissing {
|
||||||
|
// 先更新凭证
|
||||||
|
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||||
|
Credentials: newCredentials,
|
||||||
|
})
|
||||||
|
if updateErr != nil {
|
||||||
|
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 标记账户为 error
|
||||||
|
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||||
|
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||||
|
"warning": "missing_project_id",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||||
|
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||||
|
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||||
|
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use Anthropic/Claude OAuth service to refresh token
|
// Use Anthropic/Claude OAuth service to refresh token
|
||||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ type GatewayHandler struct {
|
|||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
|
maxAccountSwitches int
|
||||||
|
maxAccountSwitchesGemini int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayHandler creates a new GatewayHandler
|
// NewGatewayHandler creates a new GatewayHandler
|
||||||
@@ -44,8 +46,16 @@ func NewGatewayHandler(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *GatewayHandler {
|
) *GatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
|
maxAccountSwitches := 10
|
||||||
|
maxAccountSwitchesGemini := 3
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||||
|
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||||
|
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||||
|
}
|
||||||
|
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
||||||
|
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &GatewayHandler{
|
return &GatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
@@ -54,6 +64,8 @@ func NewGatewayHandler(
|
|||||||
userService: userService,
|
userService: userService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
|
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
const maxAccountSwitches = 3
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxAccountSwitches = 10
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
sessionKey = "gemini:" + sessionHash
|
sessionKey = "gemini:" + sessionHash
|
||||||
}
|
}
|
||||||
const maxAccountSwitches = 3
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
|
|||||||
gatewayService *service.OpenAIGatewayService
|
gatewayService *service.OpenAIGatewayService
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
|
maxAccountSwitches int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *OpenAIGatewayHandler {
|
) *OpenAIGatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
|
maxAccountSwitches := 3
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||||
|
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||||
|
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &OpenAIGatewayHandler{
|
return &OpenAIGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||||
|
|
||||||
const maxAccountSwitches = 3
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|||||||
@@ -16,15 +16,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// resolveHost 从 URL 解析 host
|
|
||||||
func resolveHost(urlStr string) string {
|
|
||||||
parsed, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return parsed.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||||
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 基础 Headers
|
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
|
||||||
// Accept Header 根据请求类型设置
|
|
||||||
if isStream {
|
|
||||||
req.Header.Set("Accept", "text/event-stream")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 显式设置 Host Header
|
|
||||||
if host := resolveHost(apiURL); host != "" {
|
|
||||||
req.Host = host
|
|
||||||
}
|
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级
|
||||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||||
if isConnectionError(err) {
|
if isConnectionError(err) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return statusCode == http.StatusTooManyRequests
|
return statusCode == http.StatusTooManyRequests ||
|
||||||
|
statusCode == http.StatusRequestTimeout ||
|
||||||
|
statusCode == http.StatusNotFound ||
|
||||||
|
statusCode >= 500
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode 用 authorization code 交换 token
|
// ExchangeCode 用 authorization code 交换 token
|
||||||
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取可用的 URL 列表
|
// 固定顺序:prod -> daily
|
||||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
availableURLs := BaseURLs
|
||||||
if len(availableURLs) == 0 {
|
|
||||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
|
||||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
|
|
||||||
// 检查是否需要 URL 降级
|
// 检查是否需要 URL 降级
|
||||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
|
||||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
var rawResp map[string]any
|
var rawResp map[string]any
|
||||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||||
|
|
||||||
|
// 标记成功的 URL,下次优先使用
|
||||||
|
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||||
return &loadResp, rawResp, nil
|
return &loadResp, rawResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取可用的 URL 列表
|
// 固定顺序:prod -> daily
|
||||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
availableURLs := BaseURLs
|
||||||
if len(availableURLs) == 0 {
|
|
||||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
|
||||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
|
|
||||||
// 检查是否需要 URL 降级
|
// 检查是否需要 URL 降级
|
||||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
|
||||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
var rawResp map[string]any
|
var rawResp map[string]any
|
||||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||||
|
|
||||||
|
// 标记成功的 URL,下次优先使用
|
||||||
|
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||||
return &modelsResp, rawResp, nil
|
return &modelsResp, rawResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -143,9 +143,10 @@ type GeminiResponse struct {
|
|||||||
|
|
||||||
// GeminiCandidate Gemini 候选响应
|
// GeminiCandidate Gemini 候选响应
|
||||||
type GeminiCandidate struct {
|
type GeminiCandidate struct {
|
||||||
Content *GeminiContent `json:"content,omitempty"`
|
Content *GeminiContent `json:"content,omitempty"`
|
||||||
FinishReason string `json:"finishReason,omitempty"`
|
FinishReason string `json:"finishReason,omitempty"`
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
|
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GeminiUsageMetadata Gemini 用量元数据
|
// GeminiUsageMetadata Gemini 用量元数据
|
||||||
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
|
|||||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||||
|
type GeminiGroundingMetadata struct {
|
||||||
|
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
|
||||||
|
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiGroundingChunk Gemini grounding chunk
|
||||||
|
type GeminiGroundingChunk struct {
|
||||||
|
Web *GeminiGroundingWeb `json:"web,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiGroundingWeb Gemini grounding web 信息
|
||||||
|
type GeminiGroundingWeb struct {
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
URI string `json:"uri,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ const (
|
|||||||
"https://www.googleapis.com/auth/cclog " +
|
"https://www.googleapis.com/auth/cclog " +
|
||||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||||
|
|
||||||
// User-Agent(模拟官方客户端)
|
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||||
|
|
||||||
// Session 过期时间
|
// Session 过期时间
|
||||||
SessionTTL = 30 * time.Minute
|
SessionTTL = 30 * time.Minute
|
||||||
@@ -42,22 +42,21 @@ const (
|
|||||||
URLAvailabilityTTL = 5 * time.Minute
|
URLAvailabilityTTL = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||||
// fallback 顺序: sandbox → daily → prod
|
|
||||||
var BaseURLs = []string{
|
var BaseURLs = []string{
|
||||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
"https://cloudcode-pa.googleapis.com", // prod (优先)
|
||||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
|
||||||
"https://cloudcode-pa.googleapis.com", // prod
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BaseURL 默认 URL(保持向后兼容)
|
// BaseURL 默认 URL(保持向后兼容)
|
||||||
var BaseURL = BaseURLs[0]
|
var BaseURL = BaseURLs[0]
|
||||||
|
|
||||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
|
||||||
type URLAvailability struct {
|
type URLAvailability struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
unavailable map[string]time.Time // URL -> 恢复时间
|
unavailable map[string]time.Time // URL -> 恢复时间
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
lastSuccess string // 最近成功请求的 URL,优先使用
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||||
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
|
|||||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
|
||||||
|
func (u *URLAvailability) MarkSuccess(url string) {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
u.lastSuccess = url
|
||||||
|
// 成功后清除该 URL 的不可用标记
|
||||||
|
delete(u.unavailable, url)
|
||||||
|
}
|
||||||
|
|
||||||
// IsAvailable 检查 URL 是否可用
|
// IsAvailable 检查 URL 是否可用
|
||||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||||
u.mu.RLock()
|
u.mu.RLock()
|
||||||
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
|
|||||||
return time.Now().After(expiry)
|
return time.Now().After(expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
// GetAvailableURLs 返回可用的 URL 列表
|
||||||
|
// 最近成功的 URL 优先,其他按默认顺序
|
||||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||||
u.mu.RLock()
|
u.mu.RLock()
|
||||||
defer u.mu.RUnlock()
|
defer u.mu.RUnlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := make([]string, 0, len(BaseURLs))
|
result := make([]string, 0, len(BaseURLs))
|
||||||
|
|
||||||
|
// 如果有最近成功的 URL 且可用,放在最前面
|
||||||
|
if u.lastSuccess != "" {
|
||||||
|
expiry, exists := u.unavailable[u.lastSuccess]
|
||||||
|
if !exists || now.After(expiry) {
|
||||||
|
result = append(result, u.lastSuccess)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加其他可用的 URL(按默认顺序)
|
||||||
for _, url := range BaseURLs {
|
for _, url := range BaseURLs {
|
||||||
|
// 跳过已添加的 lastSuccess
|
||||||
|
if url == u.lastSuccess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
expiry, exists := u.unavailable[url]
|
expiry, exists := u.unavailable[url]
|
||||||
if !exists || now.After(expiry) {
|
if !exists || now.After(expiry) {
|
||||||
result = append(result, url)
|
result = append(result, url)
|
||||||
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
|
|||||||
|
|
||||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
|
||||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
|
||||||
func GenerateMockProjectID() string {
|
|
||||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
|
||||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
|
||||||
|
|
||||||
randBytes, _ := GenerateRandomBytes(7)
|
|
||||||
|
|
||||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
|
||||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
|
||||||
|
|
||||||
// 生成 5 位随机字符(a-z0-9)
|
|
||||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
suffix := make([]byte, 5)
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// webSearchFallbackModel web_search 请求使用的降级模型
|
||||||
|
const webSearchFallbackModel = "gemini-2.5-flash"
|
||||||
|
|
||||||
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
|
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
|
||||||
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
|
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
|
||||||
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
|
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
|
||||||
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
|||||||
// 用于存储 tool_use id -> name 映射
|
// 用于存储 tool_use id -> name 映射
|
||||||
toolIDToName := make(map[string]string)
|
toolIDToName := make(map[string]string)
|
||||||
|
|
||||||
|
// 检测是否有 web_search 工具
|
||||||
|
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
|
||||||
|
requestType := "agent"
|
||||||
|
targetModel := mappedModel
|
||||||
|
if hasWebSearchTool {
|
||||||
|
requestType = "web_search"
|
||||||
|
if targetModel != webSearchFallbackModel {
|
||||||
|
targetModel = webSearchFallbackModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 检测是否启用 thinking
|
// 检测是否启用 thinking
|
||||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||||
|
|
||||||
// 只有 Gemini 模型支持 dummy thought workaround
|
// 只有 Gemini 模型支持 dummy thought workaround
|
||||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
|
||||||
|
|
||||||
// 1. 构建 contents
|
// 1. 构建 contents
|
||||||
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||||
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 构建 systemInstruction
|
// 2. 构建 systemInstruction
|
||||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
|
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
|
||||||
|
|
||||||
// 3. 构建 generationConfig
|
// 3. 构建 generationConfig
|
||||||
reqForConfig := claudeReq
|
reqForConfig := claudeReq
|
||||||
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
|||||||
reqCopy.Thinking = nil
|
reqCopy.Thinking = nil
|
||||||
reqForConfig = &reqCopy
|
reqForConfig = &reqCopy
|
||||||
}
|
}
|
||||||
|
if targetModel != "" && targetModel != reqForConfig.Model {
|
||||||
|
reqCopy := *reqForConfig
|
||||||
|
reqCopy.Model = targetModel
|
||||||
|
reqForConfig = &reqCopy
|
||||||
|
}
|
||||||
generationConfig := buildGenerationConfig(reqForConfig)
|
generationConfig := buildGenerationConfig(reqForConfig)
|
||||||
|
|
||||||
// 4. 构建 tools
|
// 4. 构建 tools
|
||||||
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
|||||||
Project: projectID,
|
Project: projectID,
|
||||||
RequestID: "agent-" + uuid.New().String(),
|
RequestID: "agent-" + uuid.New().String(),
|
||||||
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
||||||
RequestType: "agent",
|
RequestType: requestType,
|
||||||
Model: mappedModel,
|
Model: targetModel,
|
||||||
Request: innerRequest,
|
Request: innerRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
|
|||||||
return antigravityIdentity
|
return antigravityIdentity
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction
|
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
|
const mcpXMLProtocol = `
|
||||||
|
==== MCP XML 工具调用协议 (Workaround) ====
|
||||||
|
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
|
||||||
|
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `。
|
||||||
|
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
|
||||||
|
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
|
||||||
|
===========================================`
|
||||||
|
|
||||||
|
// hasMCPTools 检测是否有 mcp__ 前缀的工具
|
||||||
|
func hasMCPTools(tools []ClaudeTool) bool {
|
||||||
|
for _, tool := range tools {
|
||||||
|
if strings.HasPrefix(tool.Name, "mcp__") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
|
||||||
|
func filterOpenCodePrompt(text string) string {
|
||||||
|
if !strings.Contains(text, "You are an interactive CLI tool") {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
// 提取 "Instructions from:" 及之后的部分
|
||||||
|
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
|
||||||
|
return text[idx:]
|
||||||
|
}
|
||||||
|
// 如果没有自定义指令,返回空
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
|
|
||||||
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
||||||
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
var sysStr string
|
var sysStr string
|
||||||
if err := json.Unmarshal(system, &sysStr); err == nil {
|
if err := json.Unmarshal(system, &sysStr); err == nil {
|
||||||
if strings.TrimSpace(sysStr) != "" {
|
if strings.TrimSpace(sysStr) != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
|
|
||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
|
// 过滤 OpenCode 默认提示词
|
||||||
|
filtered := filterOpenCodePrompt(sysStr)
|
||||||
|
if filtered != "" {
|
||||||
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 尝试解析为数组
|
// 尝试解析为数组
|
||||||
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
||||||
for _, block := range sysBlocks {
|
for _, block := range sysBlocks {
|
||||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
|
|
||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
|
// 过滤 OpenCode 默认提示词
|
||||||
|
filtered := filterOpenCodePrompt(block.Text)
|
||||||
|
if filtered != "" {
|
||||||
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
// 添加用户的 system prompt
|
// 添加用户的 system prompt
|
||||||
parts = append(parts, userSystemParts...)
|
parts = append(parts, userSystemParts...)
|
||||||
|
|
||||||
|
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
|
||||||
|
if hasMCPTools(tools) {
|
||||||
|
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果用户没有提供 Antigravity 身份,添加结束标记
|
||||||
|
if !userHasAntigravityIdentity {
|
||||||
|
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
|
||||||
|
}
|
||||||
|
|
||||||
if len(parts) == 0 {
|
if len(parts) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
|||||||
StopSequences: DefaultStopSequences,
|
StopSequences: DefaultStopSequences,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果请求中指定了 MaxTokens,使用请求值
|
||||||
|
if req.MaxTokens > 0 {
|
||||||
|
config.MaxOutputTokens = req.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
// Thinking 配置
|
// Thinking 配置
|
||||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
||||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
|||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasWebSearchTool(tools []ClaudeTool) bool {
|
||||||
|
for _, tool := range tools {
|
||||||
|
if isWebSearchTool(tool) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWebSearchTool(tool ClaudeTool) bool {
|
||||||
|
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(tool.Name)
|
||||||
|
switch name {
|
||||||
|
case "web_search", "google_search", "web_search_20250305":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buildTools 构建 tools
|
// buildTools 构建 tools
|
||||||
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||||
if len(tools) == 0 {
|
if len(tools) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否有 web_search 工具
|
hasWebSearch := hasWebSearchTool(tools)
|
||||||
hasWebSearch := false
|
|
||||||
for _, tool := range tools {
|
|
||||||
if tool.Name == "web_search" {
|
|
||||||
hasWebSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasWebSearch {
|
|
||||||
// Web Search 工具映射
|
|
||||||
return []GeminiToolDeclaration{{
|
|
||||||
GoogleSearch: &GeminiGoogleSearch{
|
|
||||||
EnhancedContent: &GeminiEnhancedContent{
|
|
||||||
ImageSearch: &GeminiImageSearch{
|
|
||||||
MaxResultCount: 5,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 普通工具
|
// 普通工具
|
||||||
var funcDecls []GeminiFunctionDecl
|
var funcDecls []GeminiFunctionDecl
|
||||||
for _, tool := range tools {
|
for _, tool := range tools {
|
||||||
|
if isWebSearchTool(tool) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 跳过无效工具名称
|
// 跳过无效工具名称
|
||||||
if strings.TrimSpace(tool.Name) == "" {
|
if strings.TrimSpace(tool.Name) == "" {
|
||||||
log.Printf("Warning: skipping tool with empty name")
|
log.Printf("Warning: skipping tool with empty name")
|
||||||
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(funcDecls) == 0 {
|
if len(funcDecls) == 0 {
|
||||||
return nil
|
if !hasWebSearch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Web Search 工具映射
|
||||||
|
return []GeminiToolDeclaration{{
|
||||||
|
GoogleSearch: &GeminiGoogleSearch{
|
||||||
|
EnhancedContent: &GeminiEnhancedContent{
|
||||||
|
ImageSearch: &GeminiImageSearch{
|
||||||
|
MaxResultCount: 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []GeminiToolDeclaration{{
|
return []GeminiToolDeclaration{{
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package antigravity
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||||
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
|
|||||||
p.processPart(&part)
|
p.processPart(&part)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(geminiResp.Candidates) > 0 {
|
||||||
|
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
|
||||||
|
p.processGrounding(grounding)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 刷新剩余内容
|
// 刷新剩余内容
|
||||||
p.flushThinking()
|
p.flushThinking()
|
||||||
p.flushText()
|
p.flushText()
|
||||||
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
|
||||||
|
groundingText := buildGroundingText(grounding)
|
||||||
|
if groundingText == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.flushThinking()
|
||||||
|
p.flushText()
|
||||||
|
p.textBuilder += groundingText
|
||||||
|
p.flushText()
|
||||||
|
}
|
||||||
|
|
||||||
// flushText 刷新 text builder
|
// flushText 刷新 text builder
|
||||||
func (p *NonStreamingProcessor) flushText() {
|
func (p *NonStreamingProcessor) flushText() {
|
||||||
if p.textBuilder == "" {
|
if p.textBuilder == "" {
|
||||||
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||||
|
if grounding == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
if len(grounding.WebSearchQueries) > 0 {
|
||||||
|
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
|
||||||
|
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(grounding.GroundingChunks) > 0 {
|
||||||
|
var links []string
|
||||||
|
for i, chunk := range grounding.GroundingChunks {
|
||||||
|
if chunk.Web == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
title := strings.TrimSpace(chunk.Web.Title)
|
||||||
|
if title == "" {
|
||||||
|
title = "Source"
|
||||||
|
}
|
||||||
|
uri := strings.TrimSpace(chunk.Web.URI)
|
||||||
|
if uri == "" {
|
||||||
|
uri = "#"
|
||||||
|
}
|
||||||
|
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(links) > 0 {
|
||||||
|
_, _ = builder.WriteString("\n\nSources:\n")
|
||||||
|
_, _ = builder.WriteString(strings.Join(links, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
// generateRandomID 生成随机 ID
|
// generateRandomID 生成随机 ID
|
||||||
func generateRandomID() string {
|
func generateRandomID() string {
|
||||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ type StreamingProcessor struct {
|
|||||||
pendingSignature string
|
pendingSignature string
|
||||||
trailingSignature string
|
trailingSignature string
|
||||||
originalModel string
|
originalModel string
|
||||||
|
webSearchQueries []string
|
||||||
|
groundingChunks []GeminiGroundingChunk
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(geminiResp.Candidates) > 0 {
|
||||||
|
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否结束
|
// 检查是否结束
|
||||||
if len(geminiResp.Candidates) > 0 {
|
if len(geminiResp.Candidates) > 0 {
|
||||||
finishReason := geminiResp.Candidates[0].FinishReason
|
finishReason := geminiResp.Candidates[0].FinishReason
|
||||||
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
|||||||
return result.Bytes()
|
return result.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
|
||||||
|
if grounding == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
|
||||||
|
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
|
||||||
|
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// processThinking 处理 thinking
|
// processThinking 处理 thinking
|
||||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||||
var result bytes.Buffer
|
var result bytes.Buffer
|
||||||
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
p.trailingSignature = ""
|
p.trailingSignature = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
|
||||||
|
groundingText := buildGroundingText(&GeminiGroundingMetadata{
|
||||||
|
WebSearchQueries: p.webSearchQueries,
|
||||||
|
GroundingChunks: p.groundingChunks,
|
||||||
|
})
|
||||||
|
if groundingText != "" {
|
||||||
|
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
}))
|
||||||
|
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||||
|
"text": groundingText,
|
||||||
|
}))
|
||||||
|
_, _ = result.Write(p.endBlock())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 确定 stop_reason
|
// 确定 stop_reason
|
||||||
stopReason := "end_turn"
|
stopReason := "end_turn"
|
||||||
if p.usedTool {
|
if p.usedTool {
|
||||||
|
|||||||
@@ -543,6 +543,15 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||||
|
_, err := r.client.Account.Update().
|
||||||
|
Where(dbaccount.IDEQ(id)).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
SetErrorMessage("").
|
||||||
|
Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||||
_, err := r.client.AccountGroup.Create().
|
_, err := r.client.AccountGroup.Create().
|
||||||
SetAccountID(accountID).
|
SetAccountID(accountID).
|
||||||
|
|||||||
@@ -744,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
|
|||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ type AccountRepository interface {
|
|||||||
UpdateLastUsed(ctx context.Context, id int64) error
|
UpdateLastUsed(ctx context.Context, id int64) error
|
||||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||||
|
ClearError(ctx context.Context, id int64) error
|
||||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||||
|
|||||||
@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
|
|||||||
panic("unexpected SetError call")
|
panic("unexpected SetError call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||||
|
panic("unexpected ClearError call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
panic("unexpected SetSchedulable call")
|
panic("unexpected SetSchedulable call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
|||||||
DeleteAccount(ctx context.Context, id int64) error
|
DeleteAccount(ctx context.Context, id int64) error
|
||||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||||
|
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||||
|
|
||||||
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
return s.accountRepo.SetError(ctx, id, errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
||||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
|
|||||||
|
|
||||||
// AntigravityTokenInfo token 信息
|
// AntigravityTokenInfo token 信息
|
||||||
type AntigravityTokenInfo struct {
|
type AntigravityTokenInfo struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
ExpiresAt int64 `json:"expires_at"`
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
ProjectID string `json:"project_id,omitempty"`
|
ProjectID string `json:"project_id,omitempty"`
|
||||||
|
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode 用 authorization code 交换 token
|
// ExchangeCode 用 authorization code 交换 token
|
||||||
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
|||||||
result.ProjectID = loadResp.CloudAICompanionProject
|
result.ProjectID = loadResp.CloudAICompanionProject
|
||||||
}
|
}
|
||||||
|
|
||||||
// 兜底:随机生成 project_id
|
|
||||||
if result.ProjectID == "" {
|
|
||||||
result.ProjectID = antigravity.GenerateMockProjectID()
|
|
||||||
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保留原有的 project_id 和 email
|
// 保留原有的 email
|
||||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
|
||||||
if existingProjectID != "" {
|
|
||||||
tokenInfo.ProjectID = existingProjectID
|
|
||||||
}
|
|
||||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||||
if existingEmail != "" {
|
if existingEmail != "" {
|
||||||
tokenInfo.Email = existingEmail
|
tokenInfo.Email = existingEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||||
|
client := antigravity.NewClient(proxyURL)
|
||||||
|
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||||
|
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||||
|
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||||
|
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
|
tokenInfo.ProjectID = existingProjectID
|
||||||
|
tokenInfo.ProjectIDMissing = true
|
||||||
|
} else {
|
||||||
|
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||||
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
|||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
projectID := account.GetCredential("project_id")
|
projectID := account.GetCredential("project_id")
|
||||||
|
|
||||||
// 如果没有 project_id,生成一个随机的
|
|
||||||
if projectID == "" {
|
|
||||||
projectID = antigravity.GenerateMockProjectID()
|
|
||||||
}
|
|
||||||
|
|
||||||
client := antigravity.NewClient(proxyURL)
|
client := antigravity.NewClient(proxyURL)
|
||||||
|
|
||||||
// 调用 API 获取配额
|
// 调用 API 获取配额
|
||||||
|
|||||||
186
backend/internal/service/antigravity_rate_limit_test.go
Normal file
186
backend/internal/service/antigravity_rate_limit_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubAntigravityUpstream struct {
|
||||||
|
firstBase string
|
||||||
|
secondBase string
|
||||||
|
calls []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||||
|
url := req.URL.String()
|
||||||
|
s.calls = append(s.calls, url)
|
||||||
|
if strings.HasPrefix(url, s.firstBase) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scopeLimitCall struct {
|
||||||
|
accountID int64
|
||||||
|
scope AntigravityQuotaScope
|
||||||
|
resetAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type rateLimitCall struct {
|
||||||
|
accountID int64
|
||||||
|
resetAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubAntigravityAccountRepo struct {
|
||||||
|
AccountRepository
|
||||||
|
scopeCalls []scopeLimitCall
|
||||||
|
rateCalls []rateLimitCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
|
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
oldAvailability := antigravity.DefaultURLAvailability
|
||||||
|
defer func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
antigravity.DefaultURLAvailability = oldAvailability
|
||||||
|
}()
|
||||||
|
|
||||||
|
base1 := "https://ag-1.test"
|
||||||
|
base2 := "https://ag-2.test"
|
||||||
|
antigravity.BaseURLs = []string{base1, base2}
|
||||||
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||||
|
|
||||||
|
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "acc-1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCalled bool
|
||||||
|
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
prefix: "[test]",
|
||||||
|
ctx: context.Background(),
|
||||||
|
account: account,
|
||||||
|
proxyURL: "",
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
quotaScope: AntigravityQuotaScopeClaude,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||||
|
handleErrorCalled = true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
require.False(t, handleErrorCalled)
|
||||||
|
require.Len(t, upstream.calls, 2)
|
||||||
|
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||||
|
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||||
|
|
||||||
|
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||||
|
require.NotEmpty(t, available)
|
||||||
|
require.Equal(t, base2, available[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
|
||||||
|
t.Setenv(antigravityScopeRateLimitEnv, "true")
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
|
body := buildGeminiRateLimitBody("3s")
|
||||||
|
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||||
|
|
||||||
|
require.Len(t, repo.scopeCalls, 1)
|
||||||
|
require.Empty(t, repo.rateCalls)
|
||||||
|
call := repo.scopeCalls[0]
|
||||||
|
require.Equal(t, account.ID, call.accountID)
|
||||||
|
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||||
|
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
|
||||||
|
t.Setenv(antigravityScopeRateLimitEnv, "false")
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
|
body := buildGeminiRateLimitBody("2s")
|
||||||
|
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||||
|
|
||||||
|
require.Len(t, repo.rateCalls, 1)
|
||||||
|
require.Empty(t, repo.scopeCalls)
|
||||||
|
call := repo.rateCalls[0]
|
||||||
|
require.Equal(t, account.ID, call.accountID)
|
||||||
|
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
future := now.Add(10 * time.Minute)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
account.RateLimitResetAt = &future
|
||||||
|
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||||
|
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||||
|
|
||||||
|
account.RateLimitResetAt = nil
|
||||||
|
account.Extra = map[string]any{
|
||||||
|
antigravityQuotaScopesKey: map[string]any{
|
||||||
|
"claude": map[string]any{
|
||||||
|
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||||
|
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGeminiRateLimitBody(delay string) []byte {
|
||||||
|
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
|
||||||
|
}
|
||||||
@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||||
|
if tokenInfo.ProjectIDMissing {
|
||||||
|
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||||
|
}
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
|
|||||||
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
mathrand "math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -918,7 +919,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============ Layer 3: 兜底排队 ============
|
// ============ Layer 3: 兜底排队 ============
|
||||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
|
||||||
for _, acc := range candidates {
|
for _, acc := range candidates {
|
||||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||||
@@ -1318,6 +1319,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sortCandidatesForFallback 根据配置选择排序策略
|
||||||
|
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
|
||||||
|
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
|
||||||
|
if mode == "random" {
|
||||||
|
// 先按优先级排序,然后在同优先级内随机打乱
|
||||||
|
sortAccountsByPriorityOnly(accounts, preferOAuth)
|
||||||
|
shuffleWithinPriority(accounts)
|
||||||
|
} else {
|
||||||
|
// 默认按最后使用时间排序
|
||||||
|
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortAccountsByPriorityOnly 仅按优先级排序
|
||||||
|
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
|
||||||
|
sort.SliceStable(accounts, func(i, j int) bool {
|
||||||
|
a, b := accounts[i], accounts[j]
|
||||||
|
if a.Priority != b.Priority {
|
||||||
|
return a.Priority < b.Priority
|
||||||
|
}
|
||||||
|
if preferOAuth && a.Type != b.Type {
|
||||||
|
return a.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// shuffleWithinPriority 在同优先级内随机打乱顺序
|
||||||
|
func shuffleWithinPriority(accounts []*Account) {
|
||||||
|
if len(accounts) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||||
|
start := 0
|
||||||
|
for start < len(accounts) {
|
||||||
|
priority := accounts[start].Priority
|
||||||
|
end := start + 1
|
||||||
|
for end < len(accounts) && accounts[end].Priority == priority {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
// 对 [start, end) 范围内的账户随机打乱
|
||||||
|
if end-start > 1 {
|
||||||
|
r.Shuffle(end-start, func(i, j int) {
|
||||||
|
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
start = end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
|
|||||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||||
newCredentials, err := refresher.Refresh(ctx, account)
|
newCredentials, err := refresher.Refresh(ctx, account)
|
||||||
if err == nil {
|
|
||||||
// 刷新成功,更新账号credentials
|
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||||
|
if newCredentials != nil {
|
||||||
account.Credentials = newCredentials
|
account.Credentials = newCredentials
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||||
return fmt.Errorf("failed to save credentials: %w", err)
|
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||||
|
if account.Platform == PlatformAntigravity &&
|
||||||
|
account.Status == StatusError &&
|
||||||
|
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||||
|
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||||
|
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
|
||||||
|
} else {
|
||||||
|
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||||
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
|
|||||||
"invalid_client", // 客户端配置错误
|
"invalid_client", // 客户端配置错误
|
||||||
"unauthorized_client", // 客户端未授权
|
"unauthorized_client", // 客户端未授权
|
||||||
"access_denied", // 访问被拒绝
|
"access_denied", // 访问被拒绝
|
||||||
|
"missing_project_id", // 缺少 project_id
|
||||||
}
|
}
|
||||||
for _, needle := range nonRetryable {
|
for _, needle := range nonRetryable {
|
||||||
if strings.Contains(msg, needle) {
|
if strings.Contains(msg, needle) {
|
||||||
|
|||||||
@@ -21,8 +21,20 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Right: Language + Subscriptions + Balance + User Dropdown -->
|
<!-- Right: Docs + Language + Subscriptions + Balance + User Dropdown -->
|
||||||
<div class="flex items-center gap-3">
|
<div class="flex items-center gap-3">
|
||||||
|
<!-- Docs Link -->
|
||||||
|
<a
|
||||||
|
v-if="docUrl"
|
||||||
|
:href="docUrl"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
class="flex items-center gap-1.5 rounded-lg px-2.5 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 hover:text-gray-900 dark:text-dark-400 dark:hover:bg-dark-800 dark:hover:text-white"
|
||||||
|
>
|
||||||
|
<Icon name="book" size="sm" />
|
||||||
|
<span class="hidden sm:inline">{{ t('nav.docs') }}</span>
|
||||||
|
</a>
|
||||||
|
|
||||||
<!-- Language Switcher -->
|
<!-- Language Switcher -->
|
||||||
<LocaleSwitcher />
|
<LocaleSwitcher />
|
||||||
|
|
||||||
@@ -211,6 +223,7 @@ const user = computed(() => authStore.user)
|
|||||||
const dropdownOpen = ref(false)
|
const dropdownOpen = ref(false)
|
||||||
const dropdownRef = ref<HTMLElement | null>(null)
|
const dropdownRef = ref<HTMLElement | null>(null)
|
||||||
const contactInfo = computed(() => appStore.contactInfo)
|
const contactInfo = computed(() => appStore.contactInfo)
|
||||||
|
const docUrl = computed(() => appStore.docUrl)
|
||||||
|
|
||||||
// 只在标准模式的管理员下显示新手引导按钮
|
// 只在标准模式的管理员下显示新手引导按钮
|
||||||
const showOnboardingButton = computed(() => {
|
const showOnboardingButton = computed(() => {
|
||||||
|
|||||||
@@ -196,7 +196,8 @@ export default {
|
|||||||
expand: 'Expand',
|
expand: 'Expand',
|
||||||
logout: 'Logout',
|
logout: 'Logout',
|
||||||
github: 'GitHub',
|
github: 'GitHub',
|
||||||
mySubscriptions: 'My Subscriptions'
|
mySubscriptions: 'My Subscriptions',
|
||||||
|
docs: 'Docs'
|
||||||
},
|
},
|
||||||
|
|
||||||
// Auth
|
// Auth
|
||||||
|
|||||||
@@ -193,7 +193,8 @@ export default {
|
|||||||
expand: '展开',
|
expand: '展开',
|
||||||
logout: '退出登录',
|
logout: '退出登录',
|
||||||
github: 'GitHub',
|
github: 'GitHub',
|
||||||
mySubscriptions: '我的订阅'
|
mySubscriptions: '我的订阅',
|
||||||
|
docs: '文档'
|
||||||
},
|
},
|
||||||
|
|
||||||
// Auth
|
// Auth
|
||||||
|
|||||||
Reference in New Issue
Block a user