Files
sub2api/backend/internal/service/gemini_oauth_service.go
IanShaw027 48764e15a5 test(gemini): 添加 Drive API 和 OAuth 服务单元测试
- 新增 drive_client_test.go:Drive API 客户端单元测试
- 新增 gemini_oauth_service_test.go:OAuth 服务单元测试
- 重构 account_handler.go:改进 RefreshTier API 实现
- 优化 drive_client.go:增强错误处理和重试逻辑
- 完善 repository 和 service 层:支持批量 tier 刷新
- 更新迁移文件编号:017 -> 024(避免冲突)
2026-01-01 15:07:16 +08:00

836 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// Tier identifiers produced by inferGoogleOneTier for google_one accounts.
// TierGoogleOneUnknown is also used as the fallback when the Drive lookup fails.
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
// Byte-size units (1024-based) and the Drive storage limits, in bytes, that
// inferGoogleOneTier compares an account's storage limit against.
const (
GB = 1024 * 1024 * 1024
TB = 1024 * GB
StorageTierUnlimited = 100 * TB // 100TB
StorageTierAIPremium = 2 * TB // 2TB
StorageTierStandard = 200 * GB // 200GB
StorageTierBasic = 100 * GB // 100GB
StorageTierFree = 15 * GB // 15GB
)
// GeminiOAuthService implements the Gemini OAuth flows in this file: building
// authorization URLs, exchanging authorization codes, refreshing tokens and
// detecting the project ID / tier of an account.
type GeminiOAuthService struct {
// sessionStore keeps pending OAuth sessions between GenerateAuthURL and ExchangeCode.
sessionStore *geminicli.SessionStore
// proxyRepo resolves a proxy ID to a proxy whose URL is used for outbound calls.
proxyRepo ProxyRepository
// oauthClient performs the code exchange and refresh-token requests.
oauthClient GeminiOAuthClient
// codeAssist is used by fetchProjectID (LoadCodeAssist / OnboardUser); may be nil.
codeAssist GeminiCliCodeAssistClient
// cfg supplies the operator-configured Gemini OAuth client ID, secret and scopes.
cfg *config.Config
}
// GeminiOAuthCapabilities is the result of GetOAuthConfig.
type GeminiOAuthCapabilities struct {
// AIStudioOAuthEnabled is true when a custom (non built-in) OAuth client is configured.
AIStudioOAuthEnabled bool `json:"ai_studio_oauth_enabled"`
// RequiredRedirectURIs lists the redirect URIs reported by GetOAuthConfig
// (currently only geminicli.AIStudioOAuthRedirectURI).
RequiredRedirectURIs []string `json:"required_redirect_uris"`
}
// NewGeminiOAuthService wires the supplied dependencies into a new
// GeminiOAuthService and gives it a freshly created session store.
func NewGeminiOAuthService(
	proxyRepo ProxyRepository,
	oauthClient GeminiOAuthClient,
	codeAssist GeminiCliCodeAssistClient,
	cfg *config.Config,
) *GeminiOAuthService {
	svc := new(GeminiOAuthService)
	svc.sessionStore = geminicli.NewSessionStore()
	svc.proxyRepo = proxyRepo
	svc.oauthClient = oauthClient
	svc.codeAssist = codeAssist
	svc.cfg = cfg
	return svc
}
// GetOAuthConfig reports which OAuth capabilities are available.
//
// AI Studio OAuth counts as enabled only when the operator configured both a
// client ID and a client secret (after trimming whitespace) and that pair is
// not the built-in Gemini CLI client.
func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities {
	id := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID)
	secret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret)

	configured := id != "" && secret != ""
	usesBuiltinClient := id == geminicli.GeminiCLIOAuthClientID &&
		secret == geminicli.GeminiCLIOAuthClientSecret

	caps := &GeminiOAuthCapabilities{
		RequiredRedirectURIs: []string{geminicli.AIStudioOAuthRedirectURI},
	}
	caps.AIStudioOAuthEnabled = configured && !usesBuiltinClient
	return caps
}
// GeminiAuthURLResult is returned by GenerateAuthURL.
type GeminiAuthURLResult struct {
// AuthURL is the authorization URL built by geminicli.BuildAuthorizationURL.
AuthURL string `json:"auth_url"`
// SessionID identifies the stored OAuth session; ExchangeCode looks it up by this ID.
SessionID string `json:"session_id"`
// State is the generated state value that ExchangeCode compares against the session.
State string `json:"state"`
}
// GenerateAuthURL starts an OAuth flow and returns the authorization URL
// together with the session ID and state needed by ExchangeCode.
//
// Parameters:
//   - proxyID: optional proxy; lookup is best effort and a failed lookup
//     simply leaves the session without a proxy URL.
//   - redirectURI: accepted for interface compatibility; the effective redirect
//     URI is always chosen from oauthType (see below).
//   - projectID: optional project ID, stored trimmed in the session.
//   - oauthType: "code_assist", "google_one" or "ai_studio".
//
// The session is written to the store only after every validation step and the
// URL construction have succeeded, so failed calls do not leave orphaned
// sessions behind.
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
	state, err := geminicli.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("failed to generate state: %w", err)
	}
	codeVerifier, err := geminicli.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("failed to generate code verifier: %w", err)
	}
	codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier)
	sessionID, err := geminicli.GenerateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	var proxyURL string
	if proxyID != nil {
		// Best effort: a proxy that cannot be resolved is ignored.
		proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	// OAuth client selection:
	// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
	// - google_one: same as code_assist, uses built-in client for personal Google accounts.
	// - ai_studio: requires a user-provided OAuth client.
	oauthCfg := geminicli.OAuthConfig{
		ClientID:     s.cfg.Gemini.OAuth.ClientID,
		ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
		Scopes:       s.cfg.Gemini.OAuth.Scopes,
	}
	if oauthType == "code_assist" || oauthType == "google_one" {
		oauthCfg.ClientID = ""
		oauthCfg.ClientSecret = ""
	}

	effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
	if err != nil {
		return nil, err
	}
	isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
		effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret

	// AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted).
	if oauthType == "ai_studio" && isBuiltinClient {
		return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client (GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET). If you don't want to configure an OAuth client, please use an AI Studio API Key account instead")
	}

	// Redirect URI strategy:
	// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
	// - everything else: use the AI Studio localhost callback for the manual copy/paste flow
	if oauthType == "code_assist" {
		redirectURI = geminicli.GeminiCLIRedirectURI
	} else {
		redirectURI = geminicli.AIStudioOAuthRedirectURI
	}

	session := &geminicli.OAuthSession{
		State:        state,
		CodeVerifier: codeVerifier,
		ProxyURL:     proxyURL,
		RedirectURI:  redirectURI,
		ProjectID:    strings.TrimSpace(projectID),
		OAuthType:    oauthType,
		CreatedAt:    time.Now(),
	}

	authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
	if err != nil {
		return nil, err
	}

	// Persist the session only now that the caller will actually receive its ID.
	s.sessionStore.Set(sessionID, session)

	return &GeminiAuthURLResult{
		AuthURL:   authURL,
		SessionID: sessionID,
		State:     state,
	}, nil
}
// GeminiExchangeCodeInput carries the caller-supplied values for ExchangeCode.
type GeminiExchangeCodeInput struct {
// SessionID is the ID returned by GenerateAuthURL.
SessionID string
// State must be non-blank and equal to the state stored in the session.
State string
// Code is the authorization code to exchange.
Code string
// ProxyID optionally overrides the proxy stored in the session.
ProxyID *int64
// NOTE(review): ExchangeCode in this file takes the OAuth type from the stored
// session and does not read this field.
OAuthType string // "code_assist" or "ai_studio"
}
// GeminiTokenInfo is the token data returned by ExchangeCode, RefreshToken and
// RefreshAccountToken, and consumed by BuildAccountCredentials.
type GeminiTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds as reported by the token response.
ExpiresIn int64 `json:"expires_in"`
// ExpiresAt is a Unix timestamp (seconds) computed with a safety window subtracted.
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" or "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
// tierIDPattern matches the characters allowed in a tier_id: ASCII letters,
// digits, underscore, hyphen and slash (for tier paths). It is compiled once
// at package initialization instead of on every validation call.
var tierIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`)

// validateTierID checks the format and length of a tier_id.
//
// An empty tier_id is valid. Otherwise the value must be at most 64 bytes long
// and consist only of the characters accepted by tierIDPattern. It returns nil
// when the value is acceptable and a descriptive error otherwise.
func validateTierID(tierID string) error {
	if tierID == "" {
		return nil // Empty is allowed
	}
	if len(tierID) > 64 {
		return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
	}
	if !tierIDPattern.MatchString(tierID) {
		return fmt.Errorf("tier_id contains invalid characters")
	}
	return nil
}
// extractTierIDFromAllowedTiers picks a tier ID from a LoadCodeAssist response.
//
// Selection order:
//  1. the first tier flagged IsDefault whose trimmed ID is non-empty;
//  2. otherwise the first tier whose trimmed ID is non-empty;
//  3. otherwise the literal "LEGACY".
//
// The default tier is tracked explicitly rather than via a "LEGACY" sentinel,
// so a default tier whose ID is itself "LEGACY" is not overridden by an
// earlier non-default tier.
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
	firstNonEmpty := ""
	for _, tier := range allowedTiers {
		id := strings.TrimSpace(tier.ID)
		if id == "" {
			continue
		}
		if tier.IsDefault {
			return id
		}
		if firstNonEmpty == "" {
			firstNonEmpty = id
		}
	}
	if firstNonEmpty != "" {
		return firstNonEmpty
	}
	return "LEGACY"
}
// inferGoogleOneTier maps a Drive storage limit (in bytes) to a tier constant.
//
// Non-positive limits and limits below StorageTierFree yield
// TierGoogleOneUnknown. The unlimited check is strict (>), so a limit of
// exactly StorageTierUnlimited is classified as TierAIPremium; every other
// threshold is inclusive (>=).
func inferGoogleOneTier(storageBytes int64) string {
	switch {
	case storageBytes <= 0:
		return TierGoogleOneUnknown
	case storageBytes > StorageTierUnlimited:
		return TierGoogleOneUnlimited
	case storageBytes >= StorageTierAIPremium:
		return TierAIPremium
	case storageBytes >= StorageTierStandard:
		return TierGoogleOneStandard
	case storageBytes >= StorageTierBasic:
		return TierGoogleOneBasic
	case storageBytes >= StorageTierFree:
		return TierFree
	default:
		return TierGoogleOneUnknown
	}
}
// FetchGoogleOneTier queries the Drive storage quota with the given access
// token (optionally through proxyURL) and infers the tier from the storage
// limit.
//
// On success it returns the inferred tier and the storage info. On failure it
// logs the error (with a dedicated message when the error text contains
// "status 403", i.e. the Drive scope was not granted) and returns
// TierGoogleOneUnknown, a nil storage info and the error.
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
	quota, err := geminicli.NewDriveClient().GetStorageQuota(ctx, accessToken, proxyURL)
	if err == nil {
		return inferGoogleOneTier(quota.Limit), quota, nil
	}

	if strings.Contains(err.Error(), "status 403") {
		fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
	} else {
		fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
	}
	return TierGoogleOneUnknown, nil, err
}
// RefreshAccountGoogleOneTier re-detects the Google One tier of a single account.
//
// The account must carry oauth_type "google_one" and a non-empty access_token
// in its credentials. On success it returns the inferred tier plus fresh
// copies of the account's extra and credentials maps: extra gains the Drive
// storage fields (when storage info is available) and credentials gains
// tier_id. The account itself is not modified. Any validation or Drive lookup
// failure is returned as an error with empty results.
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
	ctx context.Context,
	account *Account,
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
	if account == nil {
		return "", nil, nil, fmt.Errorf("account is nil")
	}

	// Only google_one OAuth accounts are eligible.
	if kind, isString := account.Credentials["oauth_type"].(string); !isString || kind != "google_one" {
		return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
	}

	token, isString := account.Credentials["access_token"].(string)
	if !isString || token == "" {
		return "", nil, nil, fmt.Errorf("missing access_token")
	}

	// Use the account's proxy only when both the ID and the loaded proxy are present.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	inferredTier, quota, fetchErr := s.FetchGoogleOneTier(ctx, token, proxyURL)
	if fetchErr != nil {
		return "", nil, nil, fetchErr
	}

	// Copy the existing extra fields, then overlay the Drive metadata.
	mergedExtra := make(map[string]any)
	for key, value := range account.Extra {
		mergedExtra[key] = value
	}
	if quota != nil {
		mergedExtra["drive_storage_limit"] = quota.Limit
		mergedExtra["drive_storage_usage"] = quota.Usage
		mergedExtra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
	}

	// Copy the existing credentials and record the new tier.
	updatedCreds := make(map[string]any)
	for key, value := range account.Credentials {
		updatedCreds[key] = value
	}
	updatedCreds["tier_id"] = inferredTier

	return inferredTier, mergedExtra, updatedCreds, nil
}
// ExchangeCode completes an OAuth flow started by GenerateAuthURL.
//
// It validates the session and state, exchanges the authorization code for
// tokens, deletes the session, and then enriches the result depending on the
// session's OAuth type (defaulting to "code_assist" when unset):
//   - code_assist: project_id is required; it is auto-detected when the session
//     has none, and the tier is fetched as well (falling back to "LEGACY").
//   - google_one: the tier is inferred from Drive storage; on failure the tier
//     is TierGoogleOneUnknown. Drive metadata is returned in Extra when available.
//   - ai_studio: requires a custom OAuth client; no tier is set.
//
// It returns an error for a nil input, an unknown/expired session, a state
// mismatch, a failed code exchange, or a missing project_id in code_assist mode.
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
	if input == nil {
		return nil, fmt.Errorf("exchange code input is nil")
	}
	session, ok := s.sessionStore.Get(input.SessionID)
	if !ok {
		return nil, fmt.Errorf("session not found or expired")
	}
	if strings.TrimSpace(input.State) == "" || input.State != session.State {
		return nil, fmt.Errorf("invalid state")
	}

	// A proxy passed with the request overrides the one stored in the session
	// (best effort: an unresolved proxy keeps the session's proxy URL).
	proxyURL := session.ProxyURL
	if input.ProxyID != nil {
		proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}
	redirectURI := session.RedirectURI

	// Resolve oauth_type early (defaults to code_assist for backward compatibility).
	oauthType := session.OAuthType
	if oauthType == "" {
		oauthType = "code_assist"
	}

	// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
	if oauthType == "ai_studio" {
		effectiveCfg, err := geminicli.EffectiveOAuthConfig(geminicli.OAuthConfig{
			ClientID:     s.cfg.Gemini.OAuth.ClientID,
			ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
			Scopes:       s.cfg.Gemini.OAuth.Scopes,
		}, "ai_studio")
		if err != nil {
			return nil, err
		}
		isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
			effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret
		if isBuiltinClient {
			return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize")
		}
	}

	// code_assist always uses the built-in client and its fixed redirect URI.
	if oauthType == "code_assist" {
		redirectURI = geminicli.GeminiCLIRedirectURI
	}

	tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code: %w", err)
	}
	sessionProjectID := strings.TrimSpace(session.ProjectID)
	s.sessionStore.Delete(input.SessionID)

	// Expiry: subtract a 5-minute safety window (network latency, clock skew),
	// but never go below now+minTTL so a tiny expires_in cannot produce a
	// timestamp in the past and trigger a refresh storm.
	const safetyWindow = 300 // 5 minutes
	const minTTL = 30        // minimum 30 seconds
	expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
	minExpiresAt := time.Now().Unix() + minTTL
	if expiresAt < minExpiresAt {
		expiresAt = minExpiresAt
	}

	projectID := sessionProjectID
	var tierID string
	var extra map[string]any

	switch oauthType {
	case "code_assist":
		if projectID == "" {
			var fetchErr error
			projectID, tierID, fetchErr = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
			if fetchErr != nil {
				// Warn only; the missing project_id is reported below.
				fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", fetchErr)
			}
		} else {
			// The user supplied a project_id; LoadCodeAssist is still needed for the tier.
			_, fetchedTierID, fetchErr := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
			if fetchErr != nil {
				fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", fetchErr)
			} else {
				tierID = fetchedTierID
			}
		}
		if strings.TrimSpace(projectID) == "" {
			return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
		}
		// Fall back to the default tier when none was detected.
		if tierID == "" {
			tierID = "LEGACY"
		}
	case "google_one":
		// Personal Google account: no project_id required. Use distinct local
		// names so the outer tierID is assigned rather than shadowed.
		fetchedTierID, storageInfo, fetchErr := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
		if fetchErr != nil {
			// Log warning but don't block - use fallback
			fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", fetchErr)
			fetchedTierID = TierGoogleOneUnknown
		}
		tierID = fetchedTierID
		// Store Drive info in extra field for caching
		if storageInfo != nil {
			extra = map[string]any{
				"drive_storage_limit":   storageInfo.Limit,
				"drive_storage_usage":   storageInfo.Usage,
				"drive_tier_updated_at": time.Now().Format(time.RFC3339),
			}
		}
	}

	// In ai_studio mode tierID stays empty and project_id is optional.
	return &GeminiTokenInfo{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		ExpiresAt:    expiresAt,
		Scope:        tokenResp.Scope,
		ProjectID:    projectID,
		TierID:       tierID,
		OAuthType:    oauthType,
		Extra:        extra,
	}, nil
}
// RefreshToken exchanges a refresh token for a new access token, retrying
// transient failures.
//
// Up to four attempts are made (one initial try plus three retries) with
// exponential backoff of 1s, 2s and 4s, capped at 30s. Errors classified by
// isNonRetryableGeminiOAuthError are returned immediately. The backoff wait
// honors ctx: if ctx is done while waiting, the call returns an error wrapping
// ctx.Err() instead of sleeping through the cancellation.
//
// The returned ExpiresAt has a 5-minute safety window subtracted and is never
// earlier than 30 seconds from now.
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
	var lastErr error
	for attempt := 0; attempt <= 3; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoff > 30*time.Second {
				backoff = 30 * time.Second
			}
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("token refresh canceled while waiting to retry (last error: %v): %w", lastErr, ctx.Err())
			case <-timer.C:
			}
		}

		tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
		if err == nil {
			// Expiry: subtract a 5-minute safety window (network latency, clock
			// skew), with a lower bound so a tiny expires_in cannot yield a past
			// timestamp and cause a refresh storm.
			const safetyWindow = 300 // 5 minutes
			const minTTL = 30        // minimum 30 seconds
			expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
			minExpiresAt := time.Now().Unix() + minTTL
			if expiresAt < minExpiresAt {
				expiresAt = minExpiresAt
			}
			return &GeminiTokenInfo{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				TokenType:    tokenResp.TokenType,
				ExpiresIn:    tokenResp.ExpiresIn,
				ExpiresAt:    expiresAt,
				Scope:        tokenResp.Scope,
			}, nil
		}
		if isNonRetryableGeminiOAuthError(err) {
			return nil, err
		}
		lastErr = err
	}
	return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr)
}
// isNonRetryableGeminiOAuthError reports whether err's message contains one of
// the OAuth error codes for which retrying cannot help: invalid_grant,
// invalid_client, unauthorized_client or access_denied.
//
// A nil error is treated as retryable-irrelevant and yields false instead of
// panicking.
func isNonRetryableGeminiOAuthError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	nonRetryable := []string{
		"invalid_grant",
		"invalid_client",
		"unauthorized_client",
		"access_denied",
	}
	for _, needle := range nonRetryable {
		if strings.Contains(msg, needle) {
			return true
		}
	}
	return false
}
// RefreshAccountToken refreshes the OAuth token of a Gemini OAuth account and
// fills in the account-level fields of the result.
//
// The account must be non-nil, on the Gemini platform, of OAuth type, and carry
// a refresh_token. The oauth_type credential selects the flow and defaults to
// "code_assist". After a successful refresh:
//   - code_assist: the existing project_id/tier_id are preserved; missing values
//     are auto-detected, and a still-missing project_id is an error.
//   - google_one: the cached tier is reused while drive_tier_updated_at is at
//     most 24 hours old; otherwise the tier is re-fetched from Drive, falling
//     back to the cached tier or TierGoogleOneUnknown.
//   - other types: only the existing project_id is carried over.
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
	if account == nil {
		return nil, fmt.Errorf("account is nil")
	}
	if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
		return nil, fmt.Errorf("account is not a Gemini OAuth account")
	}
	refreshToken := account.GetCredential("refresh_token")
	if strings.TrimSpace(refreshToken) == "" {
		return nil, fmt.Errorf("no refresh token available")
	}

	// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
	oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
	if oauthType == "" {
		oauthType = "code_assist"
	}

	var proxyURL string
	if account.ProxyID != nil {
		// Best effort: an unresolved proxy means the refresh goes out without one.
		proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	tokenInfo, err := s.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
	// Backward compatibility:
	// Older versions could refresh Code Assist tokens using a user-provided OAuth client when configured.
	// If the refresh token was originally issued to that custom client, forcing the built-in client will
	// fail with "unauthorized_client". In that case, retry with the custom client (ai_studio path) when available.
	if err != nil && oauthType == "code_assist" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
		if alt, altErr := s.RefreshToken(ctx, "ai_studio", refreshToken, proxyURL); altErr == nil {
			tokenInfo = alt
			err = nil
		}
	}
	if err != nil {
		// Provide a more actionable error for common OAuth client mismatch issues.
		if strings.Contains(err.Error(), "unauthorized_client") {
			return nil, fmt.Errorf("%w (OAuth client mismatch: the refresh_token is bound to the OAuth client used during authorization; please re-authorize this account or restore the original GEMINI_OAUTH_CLIENT_ID/SECRET)", err)
		}
		return nil, err
	}
	tokenInfo.OAuthType = oauthType

	// Preserve account's project_id when present.
	existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
	if existingProjectID != "" {
		tokenInfo.ProjectID = existingProjectID
	}
	// Tier previously stored on the account, if any (backward compatible).
	existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))

	// For Code Assist, project_id is required. Auto-detect if missing.
	// For AI Studio OAuth, project_id is optional and should not block refresh.
	switch oauthType {
	case "code_assist":
		// Start from the stored tier or the default so tier_id always has a value.
		if existingTierID != "" {
			tokenInfo.TierID = existingTierID
		} else {
			tokenInfo.TierID = "LEGACY" // default
		}
		// Try to auto-detect project_id and tier_id when either is missing.
		needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
		if needDetect {
			projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
			if err != nil {
				fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
			} else {
				if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
					tokenInfo.ProjectID = projectID
				}
				// Only overwrite the tier when none was stored and detection found one.
				if existingTierID == "" && tierID != "" {
					tokenInfo.TierID = tierID
				}
			}
		}
		if strings.TrimSpace(tokenInfo.ProjectID) == "" {
			return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
		}
	case "google_one":
		// Check if tier cache is stale (> 24 hours)
		needsRefresh := true
		if account.Extra != nil {
			if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
				if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
					if time.Since(updatedAt) <= 24*time.Hour {
						needsRefresh = false
						// Use cached tier
						if existingTierID != "" {
							tokenInfo.TierID = existingTierID
						}
					}
				}
			}
		}
		if needsRefresh {
			tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
			if err == nil && storageInfo != nil {
				tokenInfo.TierID = tierID
				tokenInfo.Extra = map[string]any{
					"drive_storage_limit":   storageInfo.Limit,
					"drive_storage_usage":   storageInfo.Usage,
					"drive_tier_updated_at": time.Now().Format(time.RFC3339),
				}
			} else {
				// Fallback to cached or unknown
				if existingTierID != "" {
					tokenInfo.TierID = existingTierID
				} else {
					tokenInfo.TierID = TierGoogleOneUnknown
				}
			}
		}
	}
	return tokenInfo, nil
}
// BuildAccountCredentials flattens a GeminiTokenInfo into the credentials map
// stored on an account.
//
// access_token and expires_at (as a decimal string) are always present. The
// remaining string fields are written only when non-empty; tier_id is
// additionally dropped when it fails validateTierID, so an invalid tier never
// blocks account creation. Entries from tokenInfo.Extra are copied last and
// therefore win over same-named keys written earlier.
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
	out := make(map[string]any)
	out["access_token"] = tokenInfo.AccessToken
	out["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	putNonEmpty := func(key, value string) {
		if value != "" {
			out[key] = value
		}
	}

	putNonEmpty("refresh_token", tokenInfo.RefreshToken)
	putNonEmpty("token_type", tokenInfo.TokenType)
	putNonEmpty("scope", tokenInfo.Scope)
	putNonEmpty("project_id", tokenInfo.ProjectID)
	// An invalid tier_id is skipped silently rather than reported.
	if validateTierID(tokenInfo.TierID) == nil {
		putNonEmpty("tier_id", tokenInfo.TierID)
	}
	putNonEmpty("oauth_type", tokenInfo.OAuthType)

	// Extra metadata (Drive info), if any.
	for key, value := range tokenInfo.Extra {
		out[key] = value
	}
	return out
}
// Stop calls Stop on the service's session store.
// NOTE(review): presumably this halts the store's background cleanup — confirm
// against geminicli.SessionStore.
func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
// fetchProjectID determines the Code Assist project ID and tier for an access
// token. It returns (projectID, tierID, error).
//
// Steps:
//  1. LoadCodeAssist: the tier is extracted from its allowed tiers ("LEGACY"
//     when no response is available); if it reports a project, that is returned.
//  2. OnboardUser is polled up to 5 times, 2 seconds apart, until it is done.
//     A completed response may carry the project as a string or as an object
//     with an "id" field.
//  3. Whenever onboarding fails, completes without a project, or times out, the
//     Cloud Resource Manager project list is used as a fallback.
//
// The wait between polls honors ctx; on cancellation the wrapped ctx.Err() is
// returned. The tier is returned even when an error occurs.
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
	if s.codeAssist == nil {
		return "", "", errors.New("code assist client not configured")
	}
	loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)

	// Extract tierID from response (works whether CloudAICompanionProject is set or not)
	tierID := "LEGACY"
	if loadResp != nil {
		tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
	}

	// If LoadCodeAssist returned a project, use it
	if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
		return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
	}

	req := &geminicli.OnboardUserRequest{
		TierID: tierID,
		Metadata: geminicli.LoadCodeAssistMetadata{
			IDEType:    "ANTIGRAVITY",
			Platform:   "PLATFORM_UNSPECIFIED",
			PluginType: "GEMINI",
		},
	}

	const (
		maxAttempts  = 5
		pollInterval = 2 * time.Second
	)
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
		if err != nil {
			// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
			fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
			if fbErr == nil && strings.TrimSpace(fallback) != "" {
				return strings.TrimSpace(fallback), tierID, nil
			}
			return "", tierID, err
		}
		if resp != nil && resp.Done {
			if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
				switch v := resp.Response.CloudAICompanionProject.(type) {
				case string:
					return strings.TrimSpace(v), tierID, nil
				case map[string]any:
					if id, ok := v["id"].(string); ok {
						return strings.TrimSpace(id), tierID, nil
					}
				}
			}
			fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
			if fbErr == nil && strings.TrimSpace(fallback) != "" {
				return strings.TrimSpace(fallback), tierID, nil
			}
			return "", tierID, errors.New("onboardUser completed but no project_id returned")
		}

		// No point waiting after the last attempt.
		if attempt == maxAttempts {
			break
		}
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return "", tierID, fmt.Errorf("onboardUser polling canceled: %w", ctx.Err())
		case <-timer.C:
		}
	}

	fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
	if fbErr == nil && strings.TrimSpace(fallback) != "" {
		return strings.TrimSpace(fallback), tierID, nil
	}
	if loadErr != nil {
		return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
	}
	return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
// googleCloudProject is one entry of the Cloud Resource Manager project list,
// limited to the fields fetchProjectIDFromResourceManager reads.
type googleCloudProject struct {
ProjectID string `json:"projectId"`
// DisplayName is decoded from the JSON "name" field.
DisplayName string `json:"name"`
// LifecycleState is compared against "ACTIVE" when selecting a project.
LifecycleState string `json:"lifecycleState"`
}
// googleCloudProjectsResponse is the decoded body of the Cloud Resource Manager
// v1 projects listing; only the "projects" array is read.
type googleCloudProjectsResponse struct {
Projects []googleCloudProject `json:"projects"`
}
// fetchProjectIDFromResourceManager lists the caller's projects via the Cloud
// Resource Manager v1 API and picks one project ID.
//
// Only projects whose lifecycle state is "ACTIVE" and whose ID is non-blank are
// considered. Preference order among them:
//  1. the first project that looks like a companion project (ID contains
//     "cloud-ai-companion", or name contains "cloud ai companion" / "code assist");
//  2. the first project whose ID or name contains "default";
//  3. the first active project.
//
// Matching is case-insensitive. An error is returned when the request fails,
// the status is not 200, the body cannot be parsed, or no active project exists.
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
	const endpoint = "https://cloudresourcemanager.googleapis.com/v1/projects"
	const requestTimeout = 30 * time.Second

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create resource manager request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)

	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL: strings.TrimSpace(proxyURL),
		Timeout:  requestTimeout,
	})
	if err != nil {
		// Fall back to a plain client when the shared client cannot be built.
		client = &http.Client{Timeout: requestTimeout}
	}

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("resource manager request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read resource manager response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(payload))
	}

	var listing googleCloudProjectsResponse
	if err := json.Unmarshal(payload, &listing); err != nil {
		return "", fmt.Errorf("failed to parse resource manager response: %w", err)
	}

	// Single pass: return a companion project as soon as one is seen, while
	// remembering the first "default" and the first active project as fallbacks.
	var firstActive, firstDefault string
	for _, project := range listing.Projects {
		projectID := strings.TrimSpace(project.ProjectID)
		if project.LifecycleState != "ACTIVE" || projectID == "" {
			continue
		}
		lowerID := strings.ToLower(projectID)
		lowerName := strings.ToLower(strings.TrimSpace(project.DisplayName))

		if strings.Contains(lowerID, "cloud-ai-companion") ||
			strings.Contains(lowerName, "cloud ai companion") ||
			strings.Contains(lowerName, "code assist") {
			return projectID, nil
		}
		if firstActive == "" {
			firstActive = projectID
		}
		if firstDefault == "" && (strings.Contains(lowerID, "default") || strings.Contains(lowerName, "default")) {
			firstDefault = projectID
		}
	}

	switch {
	case firstDefault != "":
		return firstDefault, nil
	case firstActive != "":
		return firstActive, nil
	}
	return "", errors.New("no ACTIVE projects found from resource manager")
}