package service import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/service/ports" ) type GeminiOAuthService struct { sessionStore *geminicli.SessionStore proxyRepo ports.ProxyRepository oauthClient ports.GeminiOAuthClient codeAssist ports.GeminiCliCodeAssistClient cfg *config.Config } func NewGeminiOAuthService( proxyRepo ports.ProxyRepository, oauthClient ports.GeminiOAuthClient, codeAssist ports.GeminiCliCodeAssistClient, cfg *config.Config, ) *GeminiOAuthService { return &GeminiOAuthService{ sessionStore: geminicli.NewSessionStore(), proxyRepo: proxyRepo, oauthClient: oauthClient, codeAssist: codeAssist, cfg: cfg, } } type GeminiAuthURLResult struct { AuthURL string `json:"auth_url"` SessionID string `json:"session_id"` State string `json:"state"` } func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) { state, err := geminicli.GenerateState() if err != nil { return nil, fmt.Errorf("failed to generate state: %w", err) } codeVerifier, err := geminicli.GenerateCodeVerifier() if err != nil { return nil, fmt.Errorf("failed to generate code verifier: %w", err) } codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier) sessionID, err := geminicli.GenerateSessionID() if err != nil { return nil, fmt.Errorf("failed to generate session ID: %w", err) } var proxyURL string if proxyID != nil { proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) if err == nil && proxy != nil { proxyURL = proxy.URL() } } // 两种 OAuth 模式都使用相同的配置,只是 scopes 不同 // scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择 oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } session := &geminicli.OAuthSession{ State: state, CodeVerifier: codeVerifier, ProxyURL: proxyURL, RedirectURI: redirectURI, ProjectID: strings.TrimSpace(projectID), OAuthType: oauthType, CreatedAt: time.Now(), } s.sessionStore.Set(sessionID, session) effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType) if err != nil { return nil, err } // For Code Assist with Gemini CLI credentials, use the CLI's redirect URI if oauthType == "code_assist" { redirectURI = geminicli.GeminiCLIRedirectURI session.RedirectURI = redirectURI s.sessionStore.Set(sessionID, session) } authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType) if err != nil { return nil, err } return &GeminiAuthURLResult{ AuthURL: authURL, SessionID: sessionID, State: state, }, nil } type GeminiExchangeCodeInput struct { SessionID string State string Code string ProxyID *int64 OAuthType string // "code_assist" 或 "ai_studio" } type GeminiTokenInfo struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` ExpiresAt int64 `json:"expires_at"` TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { session, ok := s.sessionStore.Get(input.SessionID) if !ok { return nil, fmt.Errorf("session not found or expired") } if strings.TrimSpace(input.State) == "" || input.State != session.State { return nil, fmt.Errorf("invalid state") } proxyURL := session.ProxyURL if input.ProxyID != nil { proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID) if err == nil && proxy != nil { proxyURL = proxy.URL() } } redirectURI := session.RedirectURI tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) if err != nil { return nil, fmt.Errorf("failed to exchange code: %w", err) } sessionProjectID := strings.TrimSpace(session.ProjectID) oauthType := session.OAuthType if oauthType == "" { oauthType = "code_assist" // 默认为 code_assist 以兼容旧 session } s.sessionStore.Delete(input.SessionID) // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) } } if strings.TrimSpace(projectID) == "" { return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } } return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, TokenType: tokenResp.TokenType, ExpiresIn: tokenResp.ExpiresIn, ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, OAuthType: oauthType, }, nil } func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*GeminiTokenInfo, error) { var lastErr error for attempt := 0; attempt <= 3; attempt++ { if attempt > 0 { backoff := time.Duration(1< 30*time.Second { backoff = 30 * time.Second } time.Sleep(backoff) } tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) if err == nil { // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, TokenType: tokenResp.TokenType, ExpiresIn: tokenResp.ExpiresIn, ExpiresAt: expiresAt, Scope: tokenResp.Scope, }, nil } if isNonRetryableGeminiOAuthError(err) { return nil, err } lastErr = err } return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr) } func isNonRetryableGeminiOAuthError(err error) bool { msg := err.Error() nonRetryable := []string{ "invalid_grant", "invalid_client", "unauthorized_client", "access_denied", } for _, needle := range nonRetryable { if strings.Contains(msg, needle) { return true } } return false } func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) { if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { return nil, fmt.Errorf("account is not a Gemini OAuth account") } refreshToken := account.GetCredential("refresh_token") if strings.TrimSpace(refreshToken) == "" { return nil, fmt.Errorf("no refresh token available") } var proxyURL string if account.ProxyID != nil { proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) if err == nil && proxy != nil { proxyURL = proxy.URL() } } tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) if err != nil { return nil, err } // Preserve oauth_type from the account (defaults to code_assist for backward compatibility). oauthType := strings.TrimSpace(account.GetCredential("oauth_type")) if oauthType == "" { oauthType = "code_assist" } tokenInfo.OAuthType = oauthType // Preserve account's project_id when present. existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) if existingProjectID != "" { tokenInfo.ProjectID = existingProjectID } // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } projectID = strings.TrimSpace(projectID) if projectID == "" { return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID } return tokenInfo, nil } func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any { creds := map[string]any{ "access_token": tokenInfo.AccessToken, "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), } if tokenInfo.RefreshToken != "" { creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.TokenType != "" { creds["token_type"] = tokenInfo.TokenType } if tokenInfo.Scope != "" { creds["scope"] = tokenInfo.Scope } if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } return creds } func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { if s.codeAssist == nil { return "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { return strings.TrimSpace(loadResp.CloudAICompanionProject), nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. tierID := "LEGACY" if loadResp != nil { for _, tier := range loadResp.AllowedTiers { if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { tierID = strings.TrimSpace(tier.ID) break } } if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { for _, tier := range loadResp.AllowedTiers { if strings.TrimSpace(tier.ID) != "" { tierID = strings.TrimSpace(tier.ID) break } } } } req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ IDEType: "ANTIGRAVITY", Platform: "PLATFORM_UNSPECIFIED", PluginType: "GEMINI", }, } maxAttempts := 5 for attempt := 1; attempt <= maxAttempts; attempt++ { resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req) if err != nil { // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), nil } return "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: return strings.TrimSpace(v), nil case map[string]any: if id, ok := v["id"].(string); ok { return strings.TrimSpace(id), nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), nil } return "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), nil } if loadErr != nil { return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { ProjectID string `json:"projectId"` DisplayName string `json:"name"` LifecycleState string `json:"lifecycleState"` } type googleCloudProjectsResponse struct { Projects []googleCloudProject `json:"projects"` } func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil) if err != nil { return "", fmt.Errorf("failed to create resource manager request: %w", err) } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) client := &http.Client{Timeout: 30 * time.Second} if strings.TrimSpace(proxyURL) != "" { if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil { client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)} } } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("resource manager request failed: %w", err) } defer func() { _ = resp.Body.Close() }() bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read resource manager response: %w", err) } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes)) } var projectsResp googleCloudProjectsResponse if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil { return "", fmt.Errorf("failed to parse resource manager response: %w", err) } active := make([]googleCloudProject, 0, len(projectsResp.Projects)) for _, p := range projectsResp.Projects { if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" { active = append(active, p) } } if len(active) == 0 { return "", errors.New("no ACTIVE projects found from resource manager") } // Prefer likely companion projects first. for _, p := range active { id := strings.ToLower(strings.TrimSpace(p.ProjectID)) name := strings.ToLower(strings.TrimSpace(p.DisplayName)) if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") { return strings.TrimSpace(p.ProjectID), nil } } // Then prefer "default". for _, p := range active { id := strings.ToLower(strings.TrimSpace(p.ProjectID)) name := strings.ToLower(strings.TrimSpace(p.DisplayName)) if strings.Contains(id, "default") || strings.Contains(name, "default") { return strings.TrimSpace(p.ProjectID), nil } } return strings.TrimSpace(active[0].ProjectID), nil }