fix: 将 DriveClient 注入 GeminiOAuthService,消除单元测试中的真实 HTTP 调用

FetchGoogleOneTier 原先在方法内部直接创建 DriveClient 实例,
导致单元测试中对 googleapis.com 发起真实 HTTP 请求,在 CI 环境
产生 401 错误。

将 DriveClient 作为依赖注入到 GeminiOAuthService,遵循项目
端口与适配器架构规范:
- 新增 repository/gemini_drive_client.go 作为 Provider
- 注册到 repository Wire ProviderSet
- 测试中使用 mockDriveClient 替代真实调用
This commit is contained in:
shaw
2026-02-26 10:53:04 +08:00
parent de61745bb2
commit c75c6b6858
5 changed files with 50 additions and 25 deletions

View File

@@ -113,7 +113,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)

View File

@@ -0,0 +1,9 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
// Returned as geminicli.DriveClient interface for DI (Strategy A).
func NewGeminiDriveClient() geminicli.DriveClient {
return geminicli.NewDriveClient()
}

View File

@@ -106,6 +106,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
NewGeminiOAuthClient, NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient, NewGeminiCliCodeAssistClient,
NewGeminiDriveClient,
ProvideEnt, ProvideEnt,
ProvideSQLDB, ProvideSQLDB,

View File

@@ -54,6 +54,7 @@ type GeminiOAuthService struct {
proxyRepo ProxyRepository proxyRepo ProxyRepository
oauthClient GeminiOAuthClient oauthClient GeminiOAuthClient
codeAssist GeminiCliCodeAssistClient codeAssist GeminiCliCodeAssistClient
driveClient geminicli.DriveClient
cfg *config.Config cfg *config.Config
} }
@@ -66,6 +67,7 @@ func NewGeminiOAuthService(
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
oauthClient GeminiOAuthClient, oauthClient GeminiOAuthClient,
codeAssist GeminiCliCodeAssistClient, codeAssist GeminiCliCodeAssistClient,
driveClient geminicli.DriveClient,
cfg *config.Config, cfg *config.Config,
) *GeminiOAuthService { ) *GeminiOAuthService {
return &GeminiOAuthService{ return &GeminiOAuthService{
@@ -73,6 +75,7 @@ func NewGeminiOAuthService(
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
oauthClient: oauthClient, oauthClient: oauthClient,
codeAssist: codeAssist, codeAssist: codeAssist,
driveClient: driveClient,
cfg: cfg, cfg: cfg,
} }
} }
@@ -362,9 +365,8 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
// Use Drive API to infer tier from storage quota (requires drive.readonly scope) // Use Drive API to infer tier from storage quota (requires drive.readonly scope)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...") logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil { if err != nil {
// Check if it's a 403 (scope not granted) // Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") { if strings.Contains(err.Error(), "status 403") {

View File

@@ -101,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "") got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" { if tt.wantErrSubstr != "" {
if err == nil { if err == nil {
@@ -487,7 +487,7 @@ func TestIsNonRetryableGeminiOAuthError(t *testing.T) {
func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
t.Run("完整字段", func(t *testing.T) { t.Run("完整字段", func(t *testing.T) {
@@ -687,7 +687,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
defer svc.Stop() defer svc.Stop()
result := svc.GetOAuthConfig() result := svc.GetOAuthConfig()
@@ -709,7 +709,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
// 调用 Stop 不应 panic // 调用 Stop 不应 panic
svc.Stop() svc.Stop()
@@ -806,6 +806,18 @@ func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context,
panic("not impl") panic("not impl")
} }
// mockDriveClient implements geminicli.DriveClient for tests.
type mockDriveClient struct {
getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error)
}
func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) {
if m.getStorageQuotaFunc != nil {
return m.getStorageQuotaFunc(ctx, accessToken, proxyURL)
}
return nil, fmt.Errorf("drive API not available in test")
}
// ===================== // =====================
// 新增测试GeminiOAuthService.RefreshToken含重试逻辑 // 新增测试GeminiOAuthService.RefreshToken含重试逻辑
// ===================== // =====================
@@ -825,7 +837,7 @@ func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "")
@@ -852,7 +864,7 @@ func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
_, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "")
@@ -881,7 +893,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "")
@@ -903,7 +915,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -923,7 +935,7 @@ func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -958,7 +970,7 @@ func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -997,7 +1009,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *test
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1042,7 +1054,7 @@ func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
// 无 oauth_type 凭据的旧账号 // 无 oauth_type 凭据的旧账号
@@ -1090,7 +1102,7 @@ func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
}, },
} }
svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{}) svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
proxyID := int64(5) proxyID := int64(5)
@@ -1132,7 +1144,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetec
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1181,7 +1193,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpt
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1214,7 +1226,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1254,7 +1266,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1308,7 +1320,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *t
}, },
} }
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg)
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1341,7 +1353,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
} }
// 无自定义 OAuth 客户端,无法 fallback // 无自定义 OAuth 客户端,无法 fallback
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
account := &Account{ account := &Account{
@@ -1370,7 +1382,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
_, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{
@@ -1389,7 +1401,7 @@ func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
// 手动创建 session必须设置 CreatedAt否则会因 TTL 过期被拒绝) // 手动创建 session必须设置 CreatedAt否则会因 TTL 过期被拒绝)
@@ -1416,7 +1428,7 @@ func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) {
t.Parallel() t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop() defer svc.Stop()
svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ svc.sessionStore.Set("test-session", &geminicli.OAuthSession{