Files
sub2api/backend/internal/service/gemini_oauth_service.go
yangjianbo 25e1632628 fix(安全): 修复上游校验与 URL 清理问题
增加请求阶段 DNS 解析校验,阻断重绑定到私网
补充默认透传 WWW-Authenticate 头,保留认证挑战
前端相对 URL 过滤拒绝 // 协议相对路径

测试: go test ./internal/repository -run TestGitHubReleaseServiceSuite
测试: go test ./internal/repository -run TestTurnstileServiceSuite
测试: go test ./internal/repository -run TestProxyProbeServiceSuite
测试: go test ./internal/repository -run TestClaudeUsageServiceSuite
2026-01-03 10:52:24 +08:00

558 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository
oauthClient GeminiOAuthClient
codeAssist GeminiCliCodeAssistClient
cfg *config.Config
}
type GeminiOAuthCapabilities struct {
AIStudioOAuthEnabled bool `json:"ai_studio_oauth_enabled"`
RequiredRedirectURIs []string `json:"required_redirect_uris"`
}
func NewGeminiOAuthService(
proxyRepo ProxyRepository,
oauthClient GeminiOAuthClient,
codeAssist GeminiCliCodeAssistClient,
cfg *config.Config,
) *GeminiOAuthService {
return &GeminiOAuthService{
sessionStore: geminicli.NewSessionStore(),
proxyRepo: proxyRepo,
oauthClient: oauthClient,
codeAssist: codeAssist,
cfg: cfg,
}
}
func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities {
// AI Studio OAuth is only enabled when the operator configures a custom OAuth client.
clientID := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID)
clientSecret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret)
enabled := clientID != "" && clientSecret != "" &&
(clientID != geminicli.GeminiCLIOAuthClientID || clientSecret != geminicli.GeminiCLIOAuthClientSecret)
return &GeminiOAuthCapabilities{
AIStudioOAuthEnabled: enabled,
RequiredRedirectURIs: []string{geminicli.AIStudioOAuthRedirectURI},
}
}
type GeminiAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
codeVerifier, err := geminicli.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier)
sessionID, err := geminicli.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
}
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
session := &geminicli.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID),
OAuthType: oauthType,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
if err != nil {
return nil, err
}
isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret
// AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted).
if oauthType == "ai_studio" && isBuiltinClient {
return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client (GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET). If you don't want to configure an OAuth client, please use an AI Studio API Key account instead")
}
// Redirect URI strategy:
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
// - ai_studio: use localhost callback for manual copy/paste flow
if oauthType == "code_assist" {
redirectURI = geminicli.GeminiCLIRedirectURI
} else {
redirectURI = geminicli.AIStudioOAuthRedirectURI
}
session.RedirectURI = redirectURI
s.sessionStore.Set(sessionID, session)
authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
if err != nil {
return nil, err
}
return &GeminiAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
State: state,
}, nil
}
type GeminiExchangeCodeInput struct {
SessionID string
State string
Code string
ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio"
}
type GeminiTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("invalid state")
}
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
redirectURI := session.RedirectURI
// Resolve oauth_type early (defaults to code_assist for backward compatibility).
oauthType := session.OAuthType
if oauthType == "" {
oauthType = "code_assist"
}
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" {
effectiveCfg, err := geminicli.EffectiveOAuthConfig(geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}, "ai_studio")
if err != nil {
return nil, err
}
isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret
if isBuiltinClient {
return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize")
}
}
// code_assist always uses the built-in client and its fixed redirect URI.
if oauthType == "code_assist" {
redirectURI = geminicli.GeminiCLIRedirectURI
}
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
projectID := sessionProjectID
// 对于 code_assist 模式project_id 是必需的
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" {
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
OAuthType: oauthType,
}, nil
}
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
var lastErr error
for attempt := 0; attempt <= 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
time.Sleep(backoff)
}
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
}, nil
}
if isNonRetryableGeminiOAuthError(err) {
return nil, err
}
lastErr = err
}
return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr)
}
func isNonRetryableGeminiOAuthError(err error) bool {
msg := err.Error()
nonRetryable := []string{
"invalid_grant",
"invalid_client",
"unauthorized_client",
"access_denied",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
return true
}
}
return false
}
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("account is not a Gemini OAuth account")
}
refreshToken := account.GetCredential("refresh_token")
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("no refresh token available")
}
// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
if oauthType == "" {
oauthType = "code_assist"
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := s.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
// Backward compatibility:
// Older versions could refresh Code Assist tokens using a user-provided OAuth client when configured.
// If the refresh token was originally issued to that custom client, forcing the built-in client will
// fail with "unauthorized_client". In that case, retry with the custom client (ai_studio path) when available.
if err != nil && oauthType == "code_assist" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
if alt, altErr := s.RefreshToken(ctx, "ai_studio", refreshToken, proxyURL); altErr == nil {
tokenInfo = alt
err = nil
}
}
if err != nil {
// Provide a more actionable error for common OAuth client mismatch issues.
if strings.Contains(err.Error(), "unauthorized_client") {
return nil, fmt.Errorf("%w (OAuth client mismatch: the refresh_token is bound to the OAuth client used during authorization; please re-authorize this account or restore the original GEMINI_OAUTH_CLIENT_ID/SECRET)", err)
}
return nil, err
}
tokenInfo.OAuthType = oauthType
// Preserve account's project_id when present.
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.TokenType != "" {
creds["token_type"] = tokenInfo.TokenType
}
if tokenInfo.Scope != "" {
creds["scope"] = tokenInfo.Scope
}
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
return creds
}
func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
if s.codeAssist == nil {
return "", errors.New("code assist client not configured")
}
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
}
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
if err != nil {
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
}
}
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {
ProjectID string `json:"projectId"`
DisplayName string `json:"name"`
LifecycleState string `json:"lifecycleState"`
}
type googleCloudProjectsResponse struct {
Projects []googleCloudProject `json:"projects"`
}
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
if err != nil {
return "", fmt.Errorf("failed to create resource manager request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("resource manager request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read resource manager response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
}
var projectsResp googleCloudProjectsResponse
if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
return "", fmt.Errorf("failed to parse resource manager response: %w", err)
}
active := make([]googleCloudProject, 0, len(projectsResp.Projects))
for _, p := range projectsResp.Projects {
if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
active = append(active, p)
}
}
if len(active) == 0 {
return "", errors.New("no ACTIVE projects found from resource manager")
}
// Prefer likely companion projects first.
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
return strings.TrimSpace(p.ProjectID), nil
}
}
// Then prefer "default".
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "default") || strings.Contains(name, "default") {
return strings.TrimSpace(p.ProjectID), nil
}
}
return strings.TrimSpace(active[0].ProjectID), nil
}