From 7568dc8500e81985da447ca1d5dde98442a67695 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 1 Jan 2026 19:23:27 -0800 Subject: [PATCH] =?UTF-8?q?Reapply=20"feat(gateway):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E8=B4=9F=E8=BD=BD=E6=84=9F=E7=9F=A5=E7=9A=84=E8=B4=A6=E5=8F=B7?= =?UTF-8?q?=E8=B0=83=E5=BA=A6=E4=BC=98=E5=8C=96=20(#114)"=20(#117)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit c5c12d4c8b44cbfecf2ee22ae3fd7810f724c638. --- .../antigravity/request_transformer_test.go | 6 +- .../internal/service/gemini_oauth_service.go | 299 ++---------------- 2 files changed, 34 insertions(+), 271 deletions(-) diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index ba07893f..56eebad0 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "mcp_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "MCP tool description", InputSchema: map[string]any{ "type": "object", @@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "custom_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Custom tool", InputSchema: map[string]any{"type": "object"}, }, @@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "invalid_custom", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Invalid", // InputSchema 为 nil }, diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e9ccae34..221bd0f2 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -17,26 +17,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) -const ( - TierAIPremium = "AI_PREMIUM" - TierGoogleOneStandard = "GOOGLE_ONE_STANDARD" - TierGoogleOneBasic = "GOOGLE_ONE_BASIC" - TierFree = "FREE" - TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN" - TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED" -) - -const ( - GB = 1024 * 1024 * 1024 - TB = 1024 * GB - - StorageTierUnlimited = 100 * TB // 100TB - StorageTierAIPremium = 2 * TB // 2TB - StorageTierStandard = 200 * GB // 200GB - StorageTierBasic = 100 * GB // 100GB - StorageTierFree = 15 * GB // 15GB -) - type GeminiOAuthService struct { sessionStore *geminicli.SessionStore proxyRepo ProxyRepository @@ -109,14 +89,13 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 // OAuth client selection: // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. - // - google_one: same as code_assist, uses built-in client for personal Google accounts. // - ai_studio: requires a user-provided OAuth client. oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" || oauthType == "google_one" { + if oauthType == "code_assist" { oauthCfg.ClientID = "" oauthCfg.ClientSecret = "" } @@ -177,16 +156,15 @@ type GeminiExchangeCodeInput struct { } type GeminiTokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - ExpiresAt int64 `json:"expires_at"` - TokenType string `json:"token_type"` - Scope string `json:"scope,omitempty"` - ProjectID string `json:"project_id,omitempty"` - OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" - TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA - Extra map[string]any `json:"extra,omitempty"` // Drive metadata + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` + ProjectID string `json:"project_id,omitempty"` + OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" + TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA } // validateTierID validates tier_id format and length @@ -227,104 +205,6 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string return tierID } -// inferGoogleOneTier infers Google One tier from Drive storage limit -func inferGoogleOneTier(storageBytes int64) string { - if storageBytes <= 0 { - return TierGoogleOneUnknown - } - - if storageBytes > StorageTierUnlimited { - return TierGoogleOneUnlimited - } - if storageBytes >= StorageTierAIPremium { - return TierAIPremium - } - if storageBytes >= StorageTierStandard { - return TierGoogleOneStandard - } - if storageBytes >= StorageTierBasic { - return TierGoogleOneBasic - } - if storageBytes >= StorageTierFree { - return TierFree - } - return TierGoogleOneUnknown -} - -// fetchGoogleOneTier fetches Google One tier from Drive API -func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { - driveClient := geminicli.NewDriveClient() - - storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) - if err != nil { - // Check if it's a 403 (scope not granted) - if strings.Contains(err.Error(), "status 403") { - fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err) - return TierGoogleOneUnknown, nil, err - } - // Other errors - fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err) - return TierGoogleOneUnknown, nil, err - } - - tierID := inferGoogleOneTier(storageInfo.Limit) - return tierID, storageInfo, nil -} - -// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier -func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( - ctx context.Context, - account *Account, -) (tierID string, extra map[string]any, credentials map[string]any, err error) { - if account == nil { - return "", nil, nil, fmt.Errorf("account is nil") - } - - // 验证账号类型 - oauthType, ok := account.Credentials["oauth_type"].(string) - if !ok || oauthType != "google_one" { - return "", nil, nil, fmt.Errorf("not a google_one OAuth account") - } - - // 获取 access_token - accessToken, ok := account.Credentials["access_token"].(string) - if !ok || accessToken == "" { - return "", nil, nil, fmt.Errorf("missing access_token") - } - - // 获取 proxy URL - var proxyURL string - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 调用 Drive API - tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL) - if err != nil { - return "", nil, nil, err - } - - // 构建 extra 数据(保留原有 extra 字段) - extra = make(map[string]any) - for k, v := range account.Extra { - extra[k] = v - } - if storageInfo != nil { - extra["drive_storage_limit"] = storageInfo.Limit - extra["drive_storage_usage"] = storageInfo.Usage - extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339) - } - - // 构建 credentials 数据 - credentials = make(map[string]any) - for k, v := range account.Credentials { - credentials[k] = v - } - credentials["tier_id"] = tierID - - return tierID, extra, credentials, nil -} - func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { session, ok := s.sessionStore.Get(input.SessionID) if !ok { @@ -379,24 +259,15 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) - // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) - // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) - const safetyWindow = 300 // 5 minutes - const minTTL = 30 // minimum 30 seconds - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow - minExpiresAt := time.Now().Unix() + minTTL - if expiresAt < minExpiresAt { - expiresAt = minExpiresAt - } + // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID var tierID string - // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API - // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 + // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) - switch oauthType { - case "code_assist": + if oauthType == "code_assist" { if projectID == "" { var err error projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) @@ -404,53 +275,11 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) } - } else { - // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID - _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) - if err != nil { - fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) - } else { - tierID = fetchedTierID - } } if strings.TrimSpace(projectID) == "" { return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } - // tierID 缺失时使用默认值 - if tierID == "" { - tierID = "LEGACY" - } - case "google_one": - // Attempt to fetch Drive storage tier - tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL) - if err != nil { - // Log warning but don't block - use fallback - fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) - tierID = TierGoogleOneUnknown - } - - // Store Drive info in extra field for caching - if storageInfo != nil { - tokenInfo := &GeminiTokenInfo{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - TokenType: tokenResp.TokenType, - ExpiresIn: tokenResp.ExpiresIn, - ExpiresAt: expiresAt, - Scope: tokenResp.Scope, - ProjectID: projectID, - TierID: tierID, - OAuthType: oauthType, - Extra: map[string]any{ - "drive_storage_limit": storageInfo.Limit, - "drive_storage_usage": storageInfo.Usage, - "drive_tier_updated_at": time.Now().Format(time.RFC3339), - }, - } - return tokenInfo, nil - } } - // ai_studio 模式不设置 tierID,保持为空 return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -479,15 +308,8 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL) if err == nil { - // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) - // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) - const safetyWindow = 300 // 5 minutes - const minTTL = 30 // minimum 30 seconds - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow - minExpiresAt := time.Now().Unix() + minTTL - if expiresAt < minExpiresAt { - expiresAt = minExpiresAt - } + // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, @@ -574,75 +396,19 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A tokenInfo.ProjectID = existingProjectID } - // 尝试从账号凭证获取 tierID(向后兼容) - existingTierID := strings.TrimSpace(account.GetCredential("tier_id")) - // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. - switch oauthType { - case "code_assist": - // 先设置默认值或保留旧值,确保 tier_id 始终有值 - if existingTierID != "" { - tokenInfo.TierID = existingTierID - } else { - tokenInfo.TierID = "LEGACY" // 默认值 + if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + if err != nil { + return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } - - // 尝试自动探测 project_id 和 tier_id - needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == "" - if needDetect { - projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) - if err != nil { - fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err) - } else { - if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" { - tokenInfo.ProjectID = projectID - } - // 只有当原来没有 tier_id 且探测成功时才更新 - if existingTierID == "" && tierID != "" { - tokenInfo.TierID = tierID - } - } - } - - if strings.TrimSpace(tokenInfo.ProjectID) == "" { + projectID = strings.TrimSpace(projectID) + if projectID == "" { return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } - case "google_one": - // Check if tier cache is stale (> 24 hours) - needsRefresh := true - if account.Extra != nil { - if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok { - if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil { - if time.Since(updatedAt) <= 24*time.Hour { - needsRefresh = false - // Use cached tier - if existingTierID != "" { - tokenInfo.TierID = existingTierID - } - } - } - } - } - - if needsRefresh { - tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL) - if err == nil && storageInfo != nil { - tokenInfo.TierID = tierID - tokenInfo.Extra = map[string]any{ - "drive_storage_limit": storageInfo.Limit, - "drive_storage_usage": storageInfo.Usage, - "drive_tier_updated_at": time.Now().Format(time.RFC3339), - } - } else { - // Fallback to cached or unknown - if existingTierID != "" { - tokenInfo.TierID = existingTierID - } else { - tokenInfo.TierID = TierGoogleOneUnknown - } - } - } + tokenInfo.ProjectID = projectID + tokenInfo.TierID = tierID } return tokenInfo, nil @@ -675,12 +441,6 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } - // Store extra metadata (Drive info) if present - if len(tokenInfo.Extra) > 0 { - for k, v := range tokenInfo.Extra { - creds[k] = v - } - } return creds } @@ -706,6 +466,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } + // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. + // (tierID already extracted above, reuse it) + req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ @@ -724,7 +487,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", tierID, err + return "", "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { @@ -742,7 +505,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", tierID, errors.New("onboardUser completed but no project_id returned") + return "", "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } @@ -752,9 +515,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct {