fix: 将 DriveClient 注入 GeminiOAuthService,消除单元测试中的真实 HTTP 调用

FetchGoogleOneTier 原先在方法内部直接创建 DriveClient 实例,
导致单元测试中对 googleapis.com 发起真实 HTTP 请求,在 CI 环境
产生 401 错误。

将 DriveClient 作为依赖注入到 GeminiOAuthService,遵循项目
端口与适配器架构规范:
- 新增 repository/gemini_drive_client.go 作为 Provider
- 注册到 repository Wire ProviderSet
- 测试中使用 mockDriveClient 替代真实调用
This commit is contained in:
shaw
2026-02-26 10:53:04 +08:00
parent de61745bb2
commit c75c6b6858
5 changed files with 50 additions and 25 deletions

View File

@@ -113,7 +113,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)

View File

@@ -0,0 +1,9 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
// Returned as geminicli.DriveClient interface for DI (Strategy A).
func NewGeminiDriveClient() geminicli.DriveClient {
return geminicli.NewDriveClient()
}

View File

@@ -106,6 +106,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient,
NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient,
NewGeminiDriveClient,
ProvideEnt,
ProvideSQLDB,

View File

@@ -54,6 +54,7 @@ type GeminiOAuthService struct {
proxyRepo ProxyRepository
oauthClient GeminiOAuthClient
codeAssist GeminiCliCodeAssistClient
driveClient geminicli.DriveClient
cfg *config.Config
}
@@ -66,6 +67,7 @@ func NewGeminiOAuthService(
proxyRepo ProxyRepository,
oauthClient GeminiOAuthClient,
codeAssist GeminiCliCodeAssistClient,
driveClient geminicli.DriveClient,
cfg *config.Config,
) *GeminiOAuthService {
return &GeminiOAuthService{
@@ -73,6 +75,7 @@ func NewGeminiOAuthService(
proxyRepo: proxyRepo,
oauthClient: oauthClient,
codeAssist: codeAssist,
driveClient: driveClient,
cfg: cfg,
}
}
@@ -362,9 +365,8 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {

View File

@@ -101,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
@@ -487,7 +487,7 @@ func TestIsNonRetryableGeminiOAuthError(t *testing.T) {
func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
t.Run("完整字段", func(t *testing.T) {
@@ -687,7 +687,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
defer svc.Stop()
result := svc.GetOAuthConfig()
@@ -709,7 +709,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
// 调用 Stop 不应 panic
svc.Stop()
@@ -806,6 +806,18 @@ func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context,
panic("not impl")
}
// mockDriveClient implements geminicli.DriveClient for tests.
type mockDriveClient struct {
getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error)
}
func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) {
if m.getStorageQuotaFunc != nil {
return m.getStorageQuotaFunc(ctx, accessToken, proxyURL)
}
return nil, fmt.Errorf("drive API not available in test")
}
// =====================
// 新增测试GeminiOAuthService.RefreshToken含重试逻辑
// =====================
@@ -825,7 +837,7 @@ func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) {
},
}
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "")
@@ -852,7 +864,7 @@ func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) {
},
}
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
_, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "")
@@ -881,7 +893,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
},
}
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "")
@@ -903,7 +915,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -923,7 +935,7 @@ func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -958,7 +970,7 @@ func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) {
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -997,7 +1009,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *test
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1042,7 +1054,7 @@ func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) {
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
// 无 oauth_type 凭据的旧账号
@@ -1090,7 +1102,7 @@ func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
},
}
svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{})
svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{})
defer svc.Stop()
proxyID := int64(5)
@@ -1132,7 +1144,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetec
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1181,7 +1193,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpt
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1214,7 +1226,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1254,7 +1266,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1308,7 +1320,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *t
},
}
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg)
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg)
defer svc.Stop()
account := &Account{
@@ -1341,7 +1353,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
}
// 无自定义 OAuth 客户端,无法 fallback
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1370,7 +1382,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
_, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{
@@ -1389,7 +1401,7 @@ func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
// 手动创建 session必须设置 CreatedAt否则会因 TTL 过期被拒绝)
@@ -1416,7 +1428,7 @@ func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) {
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
svc.sessionStore.Set("test-session", &geminicli.OAuthSession{