From c75c6b6858f47e5052bdf522c644020b0bd7755e Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 26 Feb 2026 10:53:04 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=B0=86=20DriveClient=20=E6=B3=A8?= =?UTF-8?q?=E5=85=A5=20GeminiOAuthService=EF=BC=8C=E6=B6=88=E9=99=A4?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=E4=B8=AD=E7=9A=84=E7=9C=9F?= =?UTF-8?q?=E5=AE=9E=20HTTP=20=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FetchGoogleOneTier 原先在方法内部直接创建 DriveClient 实例, 导致单元测试中对 googleapis.com 发起真实 HTTP 请求,在 CI 环境 产生 401 错误。 将 DriveClient 作为依赖注入到 GeminiOAuthService,遵循项目 端口与适配器架构规范: - 新增 repository/gemini_drive_client.go 作为 Provider - 注册到 repository Wire ProviderSet - 测试中使用 mockDriveClient 替代真实调用 --- backend/cmd/server/wire_gen.go | 3 +- .../repository/gemini_drive_client.go | 9 +++ backend/internal/repository/wire.go | 1 + .../internal/service/gemini_oauth_service.go | 6 +- .../service/gemini_oauth_service_test.go | 56 +++++++++++-------- 5 files changed, 50 insertions(+), 25 deletions(-) create mode 100644 backend/internal/repository/gemini_drive_client.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 287f8176..888de4d3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -113,7 +113,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) + driveClient := repository.NewGeminiDriveClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) diff --git a/backend/internal/repository/gemini_drive_client.go b/backend/internal/repository/gemini_drive_client.go new file mode 100644 index 00000000..2e383595 --- /dev/null +++ b/backend/internal/repository/gemini_drive_client.go @@ -0,0 +1,9 @@ +package repository + +import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + +// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations. +// Returned as geminicli.DriveClient interface for DI (Strategy A). +func NewGeminiDriveClient() geminicli.DriveClient { + return geminicli.NewDriveClient() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 0878c43d..eb8ce3fb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -106,6 +106,7 @@ var ProviderSet = wire.NewSet( NewOpenAIOAuthClient, NewGeminiOAuthClient, NewGeminiCliCodeAssistClient, + NewGeminiDriveClient, ProvideEnt, ProvideSQLDB, diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 0b9734f6..e866bdc3 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -54,6 +54,7 @@ type GeminiOAuthService struct { proxyRepo ProxyRepository oauthClient GeminiOAuthClient codeAssist GeminiCliCodeAssistClient + driveClient geminicli.DriveClient cfg *config.Config } @@ -66,6 +67,7 @@ func NewGeminiOAuthService( proxyRepo ProxyRepository, oauthClient GeminiOAuthClient, codeAssist GeminiCliCodeAssistClient, + driveClient geminicli.DriveClient, cfg *config.Config, ) *GeminiOAuthService { return &GeminiOAuthService{ @@ -73,6 +75,7 @@ func NewGeminiOAuthService( proxyRepo: proxyRepo, oauthClient: oauthClient, codeAssist: codeAssist, + driveClient: driveClient, cfg: cfg, } } @@ -362,9 +365,8 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken // Use Drive API to infer tier from storage quota (requires drive.readonly scope) logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...") - driveClient := geminicli.NewDriveClient() - storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) + storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL) if err != nil { // Check if it's a 403 (scope not granted) if strings.Contains(err.Error(), "status 403") { diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index c58a5930..397b581d 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -101,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "") if tt.wantErrSubstr != "" { if err == nil { @@ -487,7 +487,7 @@ func TestIsNonRetryableGeminiOAuthError(t *testing.T) { func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() t.Run("完整字段", func(t *testing.T) { @@ -687,7 +687,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) defer svc.Stop() result := svc.GetOAuthConfig() @@ -709,7 +709,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) // 调用 Stop 不应 panic svc.Stop() @@ -806,6 +806,18 @@ func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, panic("not impl") } +// mockDriveClient implements geminicli.DriveClient for tests. +type mockDriveClient struct { + getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) +} + +func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) { + if m.getStorageQuotaFunc != nil { + return m.getStorageQuotaFunc(ctx, accessToken, proxyURL) + } + return nil, fmt.Errorf("drive API not available in test") +} + // ===================== // 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑) // ===================== @@ -825,7 +837,7 @@ func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) { }, } - svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) defer svc.Stop() info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") @@ -852,7 +864,7 @@ func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) { }, } - svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) defer svc.Stop() _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") @@ -881,7 +893,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { }, } - svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) defer svc.Stop() info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") @@ -903,7 +915,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -923,7 +935,7 @@ func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -958,7 +970,7 @@ func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) { }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -997,7 +1009,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *test }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1042,7 +1054,7 @@ func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) { }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) defer svc.Stop() // 无 oauth_type 凭据的旧账号 @@ -1090,7 +1102,7 @@ func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { }, } - svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{}) defer svc.Stop() proxyID := int64(5) @@ -1132,7 +1144,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetec }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1181,7 +1193,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpt }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1214,7 +1226,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing. }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1254,7 +1266,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree( }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1308,7 +1320,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *t }, } - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg) defer svc.Stop() account := &Account{ @@ -1341,7 +1353,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t } // 无自定义 OAuth 客户端,无法 fallback - svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) defer svc.Stop() account := &Account{ @@ -1370,7 +1382,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ @@ -1389,7 +1401,7 @@ func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() // 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝) @@ -1416,7 +1428,7 @@ func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) defer svc.Stop() svc.sessionStore.Set("test-session", &geminicli.OAuthSession{