Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -18,12 +19,23 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TierAIPremium = "AI_PREMIUM"
|
||||
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
TierFree = "FREE"
|
||||
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
// Canonical tier IDs used by sub2api (2026-aligned).
|
||||
GeminiTierGoogleOneFree = "google_one_free"
|
||||
GeminiTierGoogleAIPro = "google_ai_pro"
|
||||
GeminiTierGoogleAIUltra = "google_ai_ultra"
|
||||
GeminiTierGCPStandard = "gcp_standard"
|
||||
GeminiTierGCPEnterprise = "gcp_enterprise"
|
||||
GeminiTierAIStudioFree = "aistudio_free"
|
||||
GeminiTierAIStudioPaid = "aistudio_paid"
|
||||
GeminiTierGoogleOneUnknown = "google_one_unknown"
|
||||
|
||||
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
|
||||
legacyTierAIPremium = "AI_PREMIUM"
|
||||
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
legacyTierFree = "FREE"
|
||||
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
|
||||
state, err := geminicli.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
if oauthType == "code_assist" {
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
|
||||
OAuthType: oauthType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// Redirect URI strategy:
|
||||
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
|
||||
// - ai_studio: use localhost callback for manual copy/paste flow
|
||||
if oauthType == "code_assist" {
|
||||
// - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
|
||||
// - custom OAuth client: use localhost callback for manual copy/paste flow
|
||||
if isBuiltinClient {
|
||||
redirectURI = geminicli.GeminiCLIRedirectURI
|
||||
} else {
|
||||
redirectURI = geminicli.AIStudioOAuthRedirectURI
|
||||
@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
|
||||
Code string
|
||||
ProxyID *int64
|
||||
OAuthType string // "code_assist" 或 "ai_studio"
|
||||
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
|
||||
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
|
||||
TierID string
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
|
||||
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
|
||||
}
|
||||
|
||||
@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func canonicalGeminiTierID(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lower := strings.ToLower(raw)
|
||||
switch lower {
|
||||
case GeminiTierGoogleOneFree,
|
||||
GeminiTierGoogleAIPro,
|
||||
GeminiTierGoogleAIUltra,
|
||||
GeminiTierGCPStandard,
|
||||
GeminiTierGCPEnterprise,
|
||||
GeminiTierAIStudioFree,
|
||||
GeminiTierAIStudioPaid,
|
||||
GeminiTierGoogleOneUnknown:
|
||||
return lower
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(raw)
|
||||
switch upper {
|
||||
// Google One legacy tiers
|
||||
case legacyTierAIPremium:
|
||||
return GeminiTierGoogleAIPro
|
||||
case legacyTierGoogleOneUnlimited:
|
||||
return GeminiTierGoogleAIUltra
|
||||
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
|
||||
return GeminiTierGoogleOneFree
|
||||
case legacyTierGoogleOneUnknown:
|
||||
return GeminiTierGoogleOneUnknown
|
||||
|
||||
// Code Assist legacy tiers
|
||||
case "STANDARD", "PRO", "LEGACY":
|
||||
return GeminiTierGCPStandard
|
||||
case "ENTERPRISE", "ULTRA":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
// Some Code Assist responses use kebab-case tier identifiers.
|
||||
switch lower {
|
||||
case "standard-tier", "pro-tier":
|
||||
return GeminiTierGCPStandard
|
||||
case "ultra-tier":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
|
||||
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
|
||||
canonical := canonicalGeminiTierID(tierID)
|
||||
if canonical == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
switch canonical {
|
||||
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "code_assist":
|
||||
switch canonical {
|
||||
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "ai_studio":
|
||||
switch canonical {
|
||||
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
// Unknown oauth type: accept canonical tier.
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
|
||||
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||
@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
|
||||
|
||||
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||
func inferGoogleOneTier(storageBytes int64) string {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
|
||||
|
||||
if storageBytes <= 0 {
|
||||
return TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
if storageBytes > StorageTierUnlimited {
|
||||
return TierGoogleOneUnlimited
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
|
||||
return GeminiTierGoogleAIUltra
|
||||
}
|
||||
if storageBytes >= StorageTierAIPremium {
|
||||
return TierAIPremium
|
||||
}
|
||||
if storageBytes >= StorageTierStandard {
|
||||
return TierGoogleOneStandard
|
||||
}
|
||||
if storageBytes >= StorageTierBasic {
|
||||
return TierGoogleOneBasic
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
|
||||
return GeminiTierGoogleAIPro
|
||||
}
|
||||
if storageBytes >= StorageTierFree {
|
||||
return TierFree
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
|
||||
return GeminiTierGoogleOneFree
|
||||
}
|
||||
return TierGoogleOneUnknown
|
||||
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||
// FetchGoogleOneTier fetches Google One tier from Drive API.
|
||||
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
|
||||
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
|
||||
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
|
||||
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
|
||||
|
||||
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
|
||||
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
// Other errors
|
||||
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
|
||||
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
|
||||
|
||||
return tierID, storageInfo, nil
|
||||
}
|
||||
|
||||
@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
|
||||
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
|
||||
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
log.Printf("[GeminiOAuth] ERROR: Invalid state")
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
|
||||
|
||||
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
|
||||
if oauthType == "ai_studio" {
|
||||
@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Token exchange successful")
|
||||
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
|
||||
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
|
||||
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
|
||||
if fallbackTierID == "" {
|
||||
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// tierID 缺失时使用默认值
|
||||
// Prefer auto-detected tier; fall back to user-selected tier.
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" {
|
||||
tierID = "LEGACY"
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGCPStandard
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
var err error
|
||||
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
tierID = TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
|
||||
tierID = ""
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
|
||||
if storageInfo != nil {
|
||||
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
}
|
||||
}
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGoogleOneFree
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
|
||||
|
||||
// Store Drive info in extra field for caching
|
||||
if storageInfo != nil {
|
||||
@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
|
||||
return tokenInfo, nil
|
||||
}
|
||||
}
|
||||
// ai_studio 模式不设置 tierID,保持为空
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
case "ai_studio":
|
||||
// No automatic tier detection for AI Studio OAuth; rely on user selection.
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
} else {
|
||||
tierID = GeminiTierAIStudioFree
|
||||
}
|
||||
|
||||
default:
|
||||
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
|
||||
|
||||
result := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
|
||||
@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
// Backward compatibility for google_one:
|
||||
// - New behavior: when a custom OAuth client is configured, google_one will use it.
|
||||
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
|
||||
// If an existing account was authorized with the built-in client, refreshing with the custom client
|
||||
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
|
||||
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
|
||||
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
|
||||
tokenInfo = alt
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Provide a more actionable error for common OAuth client mismatch issues.
|
||||
if strings.Contains(err.Error(), "unauthorized_client") {
|
||||
@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
case "code_assist":
|
||||
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = "LEGACY" // 默认值
|
||||
tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
}
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = GeminiTierGCPStandard
|
||||
}
|
||||
|
||||
// 尝试自动探测 project_id 和 tier_id
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
|
||||
if needDetect {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||
if existingTierID == "" && tierID != "" {
|
||||
tokenInfo.TierID = tierID
|
||||
if tierID != "" {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
case "google_one":
|
||||
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
// Check if tier cache is stale (> 24 hours)
|
||||
needsRefresh := true
|
||||
if account.Extra != nil {
|
||||
@@ -617,30 +824,37 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if time.Since(updatedAt) <= 24*time.Hour {
|
||||
needsRefresh = false
|
||||
// Use cached tier
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
}
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err == nil && storageInfo != nil {
|
||||
tokenInfo.TierID = tierID
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
if err == nil {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
if storageInfo != nil {
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
|
||||
if canonicalExistingTier != "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
} else {
|
||||
// Fallback to cached or unknown
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = TierGoogleOneUnknown
|
||||
}
|
||||
tokenInfo.TierID = GeminiTierGoogleOneFree
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
|
||||
} else {
|
||||
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
// First try to get tier from currentTier/paidTier fields
|
||||
if tier := loadResp.GetTier(); tier != "" {
|
||||
tierID = tier
|
||||
} else {
|
||||
// Fallback to extracting from allowedTiers
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
|
||||
Reference in New Issue
Block a user