fix: 恢复 Google One 功能兼容性
恢复 main 分支的 gemini_oauth_service.go 以保持与 Google One 功能的兼容性。 变更: - 添加 Google One tier 常量定义 - 添加存储空间 tier 阈值常量 - 支持 google_one OAuth 类型 - 包含 RefreshAccountGoogleOneTier 等 Google One 相关方法 原因: - atomic-scheduling 恢复时使用了旧版本的文件 - 需要保持与 main 分支 Google One 功能(PR #118)的兼容性 - 避免编译错误(handler 代码依赖这些方法)
This commit is contained in:
@@ -17,6 +17,26 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
const (
|
||||
TierAIPremium = "AI_PREMIUM"
|
||||
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
TierFree = "FREE"
|
||||
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
)
|
||||
|
||||
const (
|
||||
GB = 1024 * 1024 * 1024
|
||||
TB = 1024 * GB
|
||||
|
||||
StorageTierUnlimited = 100 * TB // 100TB
|
||||
StorageTierAIPremium = 2 * TB // 2TB
|
||||
StorageTierStandard = 200 * GB // 200GB
|
||||
StorageTierBasic = 100 * GB // 100GB
|
||||
StorageTierFree = 15 * GB // 15GB
|
||||
)
|
||||
|
||||
type GeminiOAuthService struct {
|
||||
sessionStore *geminicli.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
@@ -89,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -156,15 +177,16 @@ type GeminiExchangeCodeInput struct {
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
|
||||
}
|
||||
|
||||
// validateTierID validates tier_id format and length
|
||||
@@ -205,6 +227,104 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
|
||||
return tierID
|
||||
}
|
||||
|
||||
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||
func inferGoogleOneTier(storageBytes int64) string {
|
||||
if storageBytes <= 0 {
|
||||
return TierGoogleOneUnknown
|
||||
}
|
||||
|
||||
if storageBytes > StorageTierUnlimited {
|
||||
return TierGoogleOneUnlimited
|
||||
}
|
||||
if storageBytes >= StorageTierAIPremium {
|
||||
return TierAIPremium
|
||||
}
|
||||
if storageBytes >= StorageTierStandard {
|
||||
return TierGoogleOneStandard
|
||||
}
|
||||
if storageBytes >= StorageTierBasic {
|
||||
return TierGoogleOneBasic
|
||||
}
|
||||
if storageBytes >= StorageTierFree {
|
||||
return TierFree
|
||||
}
|
||||
return TierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
}
|
||||
// Other errors
|
||||
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
}
|
||||
|
||||
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||
return tierID, storageInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
|
||||
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
|
||||
if account == nil {
|
||||
return "", nil, nil, fmt.Errorf("account is nil")
|
||||
}
|
||||
|
||||
// 验证账号类型
|
||||
oauthType, ok := account.Credentials["oauth_type"].(string)
|
||||
if !ok || oauthType != "google_one" {
|
||||
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
|
||||
}
|
||||
|
||||
// 获取 access_token
|
||||
accessToken, ok := account.Credentials["access_token"].(string)
|
||||
if !ok || accessToken == "" {
|
||||
return "", nil, nil, fmt.Errorf("missing access_token")
|
||||
}
|
||||
|
||||
// 获取 proxy URL
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 调用 Drive API
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
// 构建 extra 数据(保留原有 extra 字段)
|
||||
extra = make(map[string]any)
|
||||
for k, v := range account.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
if storageInfo != nil {
|
||||
extra["drive_storage_limit"] = storageInfo.Limit
|
||||
extra["drive_storage_usage"] = storageInfo.Usage
|
||||
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 构建 credentials 数据
|
||||
credentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
credentials[k] = v
|
||||
}
|
||||
credentials["tier_id"] = tierID
|
||||
|
||||
return tierID, extra, credentials, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
@@ -259,15 +379,24 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||
const safetyWindow = 300 // 5 minutes
|
||||
const minTTL = 30 // minimum 30 seconds
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||
minExpiresAt := time.Now().Unix() + minTTL
|
||||
if expiresAt < minExpiresAt {
|
||||
expiresAt = minExpiresAt
|
||||
}
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的
|
||||
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
if oauthType == "code_assist" {
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
if projectID == "" {
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
@@ -275,11 +404,53 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// tierID 缺失时使用默认值
|
||||
if tierID == "" {
|
||||
tierID = "LEGACY"
|
||||
}
|
||||
case "google_one":
|
||||
// Attempt to fetch Drive storage tier
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
tierID = TierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// Store Drive info in extra field for caching
|
||||
if storageInfo != nil {
|
||||
tokenInfo := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
Extra: map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
return tokenInfo, nil
|
||||
}
|
||||
}
|
||||
// ai_studio 模式不设置 tierID,保持为空
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
@@ -308,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
|
||||
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
|
||||
if err == nil {
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||
const safetyWindow = 300 // 5 minutes
|
||||
const minTTL = 30 // minimum 30 seconds
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||
minExpiresAt := time.Now().Unix() + minTTL
|
||||
if expiresAt < minExpiresAt {
|
||||
expiresAt = minExpiresAt
|
||||
}
|
||||
return &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
@@ -396,19 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
|
||||
// 尝试从账号凭证获取 tierID(向后兼容)
|
||||
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
|
||||
|
||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = "LEGACY" // 默认值
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
|
||||
// 尝试自动探测 project_id 和 tier_id
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||
if needDetect {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
|
||||
} else {
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||
if existingTierID == "" && tierID != "" {
|
||||
tokenInfo.TierID = tierID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
tokenInfo.ProjectID = projectID
|
||||
tokenInfo.TierID = tierID
|
||||
case "google_one":
|
||||
// Check if tier cache is stale (> 24 hours)
|
||||
needsRefresh := true
|
||||
if account.Extra != nil {
|
||||
if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
|
||||
if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
|
||||
if time.Since(updatedAt) <= 24*time.Hour {
|
||||
needsRefresh = false
|
||||
// Use cached tier
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err == nil && storageInfo != nil {
|
||||
tokenInfo.TierID = tierID
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
} else {
|
||||
// Fallback to cached or unknown
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = TierGoogleOneUnknown
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
@@ -441,6 +675,12 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
if tokenInfo.OAuthType != "" {
|
||||
creds["oauth_type"] = tokenInfo.OAuthType
|
||||
}
|
||||
// Store extra metadata (Drive info) if present
|
||||
if len(tokenInfo.Extra) > 0 {
|
||||
for k, v := range tokenInfo.Extra {
|
||||
creds[k] = v
|
||||
}
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
@@ -466,9 +706,6 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
}
|
||||
|
||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
||||
// (tierID already extracted above, reuse it)
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
@@ -487,7 +724,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
return "", "", err
|
||||
return "", tierID, err
|
||||
}
|
||||
if resp.Done {
|
||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||
@@ -505,7 +742,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
return "", "", errors.New("onboardUser completed but no project_id returned")
|
||||
return "", tierID, errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
@@ -515,9 +752,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
if loadErr != nil {
|
||||
return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
}
|
||||
return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
type googleCloudProject struct {
|
||||
|
||||
Reference in New Issue
Block a user