diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 221bd0f2..e9ccae34 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -17,6 +17,26 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) +const ( + TierAIPremium = "AI_PREMIUM" + TierGoogleOneStandard = "GOOGLE_ONE_STANDARD" + TierGoogleOneBasic = "GOOGLE_ONE_BASIC" + TierFree = "FREE" + TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN" + TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED" +) + +const ( + GB = 1024 * 1024 * 1024 + TB = 1024 * GB + + StorageTierUnlimited = 100 * TB // 100TB + StorageTierAIPremium = 2 * TB // 2TB + StorageTierStandard = 200 * GB // 200GB + StorageTierBasic = 100 * GB // 100GB + StorageTierFree = 15 * GB // 15GB +) + type GeminiOAuthService struct { sessionStore *geminicli.SessionStore proxyRepo ProxyRepository @@ -89,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 // OAuth client selection: // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. + // - google_one: same as code_assist, uses built-in client for personal Google accounts. // - ai_studio: requires a user-provided OAuth client. oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { oauthCfg.ClientID = "" oauthCfg.ClientSecret = "" } @@ -156,15 +177,16 @@ type GeminiExchangeCodeInput struct { } type GeminiTokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - ExpiresAt int64 `json:"expires_at"` - TokenType string `json:"token_type"` - Scope string `json:"scope,omitempty"` - ProjectID string `json:"project_id,omitempty"` - OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" - TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` + ProjectID string `json:"project_id,omitempty"` + OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" + TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA + Extra map[string]any `json:"extra,omitempty"` // Drive metadata } // validateTierID validates tier_id format and length @@ -205,6 +227,104 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string return tierID } +// inferGoogleOneTier infers Google One tier from Drive storage limit +func inferGoogleOneTier(storageBytes int64) string { + if storageBytes <= 0 { + return TierGoogleOneUnknown + } + + if storageBytes > StorageTierUnlimited { + return TierGoogleOneUnlimited + } + if storageBytes >= StorageTierAIPremium { + return TierAIPremium + } + if storageBytes >= StorageTierStandard { + return TierGoogleOneStandard + } + if storageBytes >= StorageTierBasic { + return TierGoogleOneBasic + } + if storageBytes >= StorageTierFree { + return TierFree + } + return TierGoogleOneUnknown +} + +// fetchGoogleOneTier fetches Google One tier from Drive API +func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { + driveClient := geminicli.NewDriveClient() + + storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) + if err != nil { + // Check if it's a 403 (scope not granted) + if strings.Contains(err.Error(), "status 403") { + fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err) + return TierGoogleOneUnknown, nil, err + } + // Other errors + fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err) + return TierGoogleOneUnknown, nil, err + } + + tierID := inferGoogleOneTier(storageInfo.Limit) + return tierID, storageInfo, nil +} + +// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier +func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( + ctx context.Context, + account *Account, +) (tierID string, extra map[string]any, credentials map[string]any, err error) { + if account == nil { + return "", nil, nil, fmt.Errorf("account is nil") + } + + // 验证账号类型 + oauthType, ok := account.Credentials["oauth_type"].(string) + if !ok || oauthType != "google_one" { + return "", nil, nil, fmt.Errorf("not a google_one OAuth account") + } + + // 获取 access_token + accessToken, ok := account.Credentials["access_token"].(string) + if !ok || accessToken == "" { + return "", nil, nil, fmt.Errorf("missing access_token") + } + + // 获取 proxy URL + var proxyURL string + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 调用 Drive API + tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL) + if err != nil { + return "", nil, nil, err + } + + // 构建 extra 数据(保留原有 extra 字段) + extra = make(map[string]any) + for k, v := range account.Extra { + extra[k] = v + } + if storageInfo != nil { + extra["drive_storage_limit"] = storageInfo.Limit + extra["drive_storage_usage"] = storageInfo.Usage + extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339) + } + + // 构建 credentials 数据 + credentials = make(map[string]any) + for k, v := range account.Credentials { + credentials[k] = v + } + credentials["tier_id"] = tierID + + return tierID, extra, credentials, nil +} + func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { session, ok := s.sessionStore.Get(input.SessionID) if !ok { @@ -259,15 +379,24 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) - // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 + // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) + // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) + const safetyWindow = 300 // 5 minutes + const minTTL = 30 // minimum 30 seconds + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow + minExpiresAt := time.Now().Unix() + minTTL + if expiresAt < minExpiresAt { + expiresAt = minExpiresAt + } projectID := sessionProjectID var tierID string - // 对于 code_assist 模式,project_id 是必需的 + // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API + // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) - if oauthType == "code_assist" { + switch oauthType { + case "code_assist": if projectID == "" { var err error projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) @@ -275,11 +404,53 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) } + } else { + // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID + _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) + } else { + tierID = fetchedTierID + } } if strings.TrimSpace(projectID) == "" { return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } + // tierID 缺失时使用默认值 + if tierID == "" { + tierID = "LEGACY" + } + case "google_one": + // Attempt to fetch Drive storage tier + tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + // Log warning but don't block - use fallback + fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) + tierID = TierGoogleOneUnknown + } + + // Store Drive info in extra field for caching + if storageInfo != nil { + tokenInfo := &GeminiTokenInfo{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + ExpiresAt: expiresAt, + Scope: tokenResp.Scope, + ProjectID: projectID, + TierID: tierID, + OAuthType: oauthType, + Extra: map[string]any{ + "drive_storage_limit": storageInfo.Limit, + "drive_storage_usage": storageInfo.Usage, + "drive_tier_updated_at": time.Now().Format(time.RFC3339), + }, + } + return tokenInfo, nil + } } + // ai_studio 模式不设置 tierID,保持为空 return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -308,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL) if err == nil { - // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 + // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) + // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) + const safetyWindow = 300 // 5 minutes + const minTTL = 30 // minimum 30 seconds + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow + minExpiresAt := time.Now().Unix() + minTTL + if expiresAt < minExpiresAt { + expiresAt = minExpiresAt + } return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, @@ -396,19 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A tokenInfo.ProjectID = existingProjectID } + // 尝试从账号凭证获取 tierID(向后兼容) + existingTierID := strings.TrimSpace(account.GetCredential("tier_id")) + // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. - if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) - if err != nil { - return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) + switch oauthType { + case "code_assist": + // 先设置默认值或保留旧值,确保 tier_id 始终有值 + if existingTierID != "" { + tokenInfo.TierID = existingTierID + } else { + tokenInfo.TierID = "LEGACY" // 默认值 } - projectID = strings.TrimSpace(projectID) - if projectID == "" { + + // 尝试自动探测 project_id 和 tier_id + needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == "" + if needDetect { + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + if err != nil { + fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err) + } else { + if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" { + tokenInfo.ProjectID = projectID + } + // 只有当原来没有 tier_id 且探测成功时才更新 + if existingTierID == "" && tierID != "" { + tokenInfo.TierID = tierID + } + } + } + + if strings.TrimSpace(tokenInfo.ProjectID) == "" { return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } - tokenInfo.ProjectID = projectID - tokenInfo.TierID = tierID + case "google_one": + // Check if tier cache is stale (> 24 hours) + needsRefresh := true + if account.Extra != nil { + if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok { + if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil { + if time.Since(updatedAt) <= 24*time.Hour { + needsRefresh = false + // Use cached tier + if existingTierID != "" { + tokenInfo.TierID = existingTierID + } + } + } + } + } + + if needsRefresh { + tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL) + if err == nil && storageInfo != nil { + tokenInfo.TierID = tierID + tokenInfo.Extra = map[string]any{ + "drive_storage_limit": storageInfo.Limit, + "drive_storage_usage": storageInfo.Usage, + "drive_tier_updated_at": time.Now().Format(time.RFC3339), + } + } else { + // Fallback to cached or unknown + if existingTierID != "" { + tokenInfo.TierID = existingTierID + } else { + tokenInfo.TierID = TierGoogleOneUnknown + } + } + } } return tokenInfo, nil @@ -441,6 +675,12 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } + // Store extra metadata (Drive info) if present + if len(tokenInfo.Extra) > 0 { + for k, v := range tokenInfo.Extra { + creds[k] = v + } + } return creds } @@ -466,9 +706,6 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } - // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - // (tierID already extracted above, reuse it) - req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ @@ -487,7 +724,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", "", err + return "", tierID, err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { @@ -505,7 +742,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", "", errors.New("onboardUser completed but no project_id returned") + return "", tierID, errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } @@ -515,9 +752,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct {