Files
sub2api/backend/internal/service/gemini_oauth_service.go
ianshaw b2d71da2a2 feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持
- gemini_token_provider.go: Token 提供器增强
- gemini_messages_compat_service.go: 支持 AI Studio 端点
- account_test_service.go: Gemini 账户可用性检测
- gateway_service.go: 网关服务适配
- openai_gateway_service.go: OpenAI 兼容层调整
2025-12-26 00:11:03 -08:00

484 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// GeminiOAuthService drives the Gemini OAuth flows (Code Assist and AI Studio):
// building authorization URLs, exchanging authorization codes, refreshing
// tokens, and resolving the GCP project_id that Code Assist requires.
type GeminiOAuthService struct {
	sessionStore *geminicli.SessionStore         // pending OAuth sessions, keyed by session ID
	proxyRepo    ports.ProxyRepository           // resolves proxy IDs to proxy URLs
	oauthClient  ports.GeminiOAuthClient         // performs code exchange and token refresh
	codeAssist   ports.GeminiCliCodeAssistClient // loadCodeAssist / onboardUser calls for project discovery
	cfg          *config.Config                  // supplies OAuth client ID, secret and scopes
}
// NewGeminiOAuthService wires a GeminiOAuthService with its collaborators and a
// newly created OAuth session store. Call Stop when the service is discarded.
func NewGeminiOAuthService(
	proxyRepo ports.ProxyRepository,
	oauthClient ports.GeminiOAuthClient,
	codeAssist ports.GeminiCliCodeAssistClient,
	cfg *config.Config,
) *GeminiOAuthService {
	svc := new(GeminiOAuthService)
	svc.sessionStore = geminicli.NewSessionStore()
	svc.proxyRepo = proxyRepo
	svc.oauthClient = oauthClient
	svc.codeAssist = codeAssist
	svc.cfg = cfg
	return svc
}
// GeminiAuthURLResult is returned by GenerateAuthURL. The caller sends the user
// to AuthURL and must present SessionID and State again to ExchangeCode.
type GeminiAuthURLResult struct {
	AuthURL   string `json:"auth_url"`   // Google authorization URL to open in a browser
	SessionID string `json:"session_id"` // key of the pending server-side OAuth session
	State     string `json:"state"`      // state value ExchangeCode compares against the session
}
// GenerateAuthURL starts an OAuth authorization flow.
//
// It generates a random state, a PKCE code verifier/challenge pair and a
// session ID, builds the Google authorization URL, and records a session
// (state, verifier, resolved proxy URL, redirect URI, trimmed project ID and
// OAuth type) that ExchangeCode later consumes.
//
// proxyID is optional; if the proxy cannot be loaded the flow continues without
// one. For oauthType "code_assist" the supplied redirectURI is replaced by
// geminicli.GeminiCLIRedirectURI, and that replacement is what gets stored.
//
// The session is stored only after every fallible step has succeeded, so an
// error return never leaves an orphaned session in the store.
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
	state, err := geminicli.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("failed to generate state: %w", err)
	}
	codeVerifier, err := geminicli.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("failed to generate code verifier: %w", err)
	}
	codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier)
	sessionID, err := geminicli.GenerateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	var proxyURL string
	if proxyID != nil {
		// Best effort: a missing or unreadable proxy falls back to no proxy.
		if proxy, perr := s.proxyRepo.GetByID(ctx, *proxyID); perr == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	// Both OAuth modes share the same client configuration; the scopes are
	// selected by EffectiveOAuthConfig according to oauthType.
	oauthCfg := geminicli.OAuthConfig{
		ClientID:     s.cfg.Gemini.OAuth.ClientID,
		ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
		Scopes:       s.cfg.Gemini.OAuth.Scopes,
	}
	effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
	if err != nil {
		return nil, err
	}

	// For Code Assist with Gemini CLI credentials, use the CLI's redirect URI.
	if oauthType == "code_assist" {
		redirectURI = geminicli.GeminiCLIRedirectURI
	}

	trimmedProjectID := strings.TrimSpace(projectID)
	authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, trimmedProjectID, oauthType)
	if err != nil {
		return nil, err
	}

	s.sessionStore.Set(sessionID, &geminicli.OAuthSession{
		State:        state,
		CodeVerifier: codeVerifier,
		ProxyURL:     proxyURL,
		RedirectURI:  redirectURI,
		ProjectID:    trimmedProjectID,
		OAuthType:    oauthType,
		CreatedAt:    time.Now(),
	})

	return &GeminiAuthURLResult{
		AuthURL:   authURL,
		SessionID: sessionID,
		State:     state,
	}, nil
}
// GeminiExchangeCodeInput carries what ExchangeCode needs to finish a flow
// started by GenerateAuthURL.
type GeminiExchangeCodeInput struct {
	SessionID string // session ID returned by GenerateAuthURL
	State     string // must equal the state stored in the session
	Code      string // authorization code returned by Google
	ProxyID   *int64 // optional; overrides the proxy recorded in the session when it resolves
	// OAuthType is "code_assist" or "ai_studio".
	// NOTE: ExchangeCode uses the oauth type stored in the session, not this field.
	OAuthType string
}
// GeminiTokenInfo is the normalized result of a code exchange or token refresh,
// as produced by ExchangeCode, RefreshToken and RefreshAccountToken.
type GeminiTokenInfo struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int64  `json:"expires_in"` // lifetime in seconds reported by the token endpoint
	ExpiresAt    int64  `json:"expires_at"` // Unix seconds, set 5 minutes before the reported expiry
	TokenType    string `json:"token_type"`
	Scope        string `json:"scope,omitempty"`
	ProjectID    string `json:"project_id,omitempty"` // GCP project; required for "code_assist"
	OAuthType    string `json:"oauth_type,omitempty"` // "code_assist" or "ai_studio"
}
// ExchangeCode completes an OAuth flow started by GenerateAuthURL.
//
// It looks up the session, checks the state, and exchanges the authorization
// code for tokens. The exchange uses input.ProxyID when that proxy resolves,
// otherwise the proxy recorded in the session. After a successful exchange the
// session is deleted.
//
// The OAuth type comes from the session; an empty value (older sessions) is
// treated as "code_assist". For "code_assist" a project_id is mandatory: the
// one captured in the session is used, otherwise it is auto-detected. If none
// can be obtained an error is returned that wraps the detection error, when
// there was one. For "ai_studio" the project_id is optional.
//
// ExpiresAt is set 5 minutes earlier than the reported expiry to absorb network
// latency and clock skew.
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
	session, ok := s.sessionStore.Get(input.SessionID)
	if !ok {
		return nil, fmt.Errorf("session not found or expired")
	}
	if strings.TrimSpace(input.State) == "" || input.State != session.State {
		return nil, fmt.Errorf("invalid state")
	}

	proxyURL := session.ProxyURL
	if input.ProxyID != nil {
		if proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID); err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, session.RedirectURI, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code: %w", err)
	}

	oauthType := session.OAuthType
	if oauthType == "" {
		oauthType = "code_assist" // default for sessions created before oauth_type existed
	}
	projectID := strings.TrimSpace(session.ProjectID)

	// The code has been redeemed; drop the session so it cannot be reused.
	s.sessionStore.Delete(input.SessionID)

	// Subtract a 5-minute safety margin for network latency and clock skew.
	expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300

	// code_assist requires a project_id; ai_studio works without one.
	if oauthType == "code_assist" && projectID == "" {
		var fetchErr error
		projectID, fetchErr = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
		projectID = strings.TrimSpace(projectID)
		if projectID == "" {
			const msg = "missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project"
			if fetchErr != nil {
				// Keep the detection failure as the cause instead of only logging it.
				return nil, fmt.Errorf("%s: %w", msg, fetchErr)
			}
			return nil, errors.New(msg)
		}
	}

	return &GeminiTokenInfo{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		ExpiresAt:    expiresAt,
		Scope:        tokenResp.Scope,
		ProjectID:    projectID,
		OAuthType:    oauthType,
	}, nil
}
// RefreshToken exchanges refreshToken for a new access token, retrying
// transient failures.
//
// It makes up to 4 attempts, waiting 1s, 2s and 4s between them (each wait is
// capped at 30s). Errors classified by isNonRetryableGeminiOAuthError are
// returned immediately without retrying. The wait between attempts honours ctx:
// on cancellation the context error is returned, annotated with the last
// attempt's error.
//
// If the token endpoint response carries no refresh token, the supplied
// refreshToken is carried over. Callers that persist the result via
// BuildAccountCredentials therefore do not silently lose it.
//
// ExpiresAt is set 5 minutes earlier than the reported expiry.
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
	const maxRetries = 3
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoff > 30*time.Second {
				backoff = 30 * time.Second
			}
			// Wait for the backoff, but stop promptly if the caller gives up.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("token refresh canceled: %w (last error: %v)", ctx.Err(), lastErr)
			case <-timer.C:
			}
		}

		tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
		if err == nil {
			newRefreshToken := tokenResp.RefreshToken
			if newRefreshToken == "" {
				newRefreshToken = refreshToken
			}
			// Subtract a 5-minute safety margin for network latency and clock skew.
			expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
			return &GeminiTokenInfo{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: newRefreshToken,
				TokenType:    tokenResp.TokenType,
				ExpiresIn:    tokenResp.ExpiresIn,
				ExpiresAt:    expiresAt,
				Scope:        tokenResp.Scope,
			}, nil
		}
		if isNonRetryableGeminiOAuthError(err) {
			return nil, err
		}
		lastErr = err
	}
	return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr)
}
// isNonRetryableGeminiOAuthError reports whether err's message contains one of
// the OAuth error codes invalid_grant, invalid_client, unauthorized_client or
// access_denied. Those codes are treated as permanent, so retrying is skipped.
// A nil error yields false instead of panicking.
func isNonRetryableGeminiOAuthError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, code := range []string{
		"invalid_grant",
		"invalid_client",
		"unauthorized_client",
		"access_denied",
	} {
		if strings.Contains(msg, code) {
			return true
		}
	}
	return false
}
// RefreshAccountToken refreshes the OAuth token of a Gemini OAuth account.
//
// The account's proxy is used when it can be loaded. The returned token info
// carries the account's oauth_type, which defaults to "code_assist" for older
// accounts. It also carries the account's stored project_id when present.
// Code Assist accounts without a project_id have one auto-detected, and failure
// to do so is an error. AI Studio accounts may legitimately have no project_id.
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) {
	isGeminiOAuth := account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth
	if !isGeminiOAuth {
		return nil, fmt.Errorf("account is not a Gemini OAuth account")
	}

	refreshToken := account.GetCredential("refresh_token")
	if strings.TrimSpace(refreshToken) == "" {
		return nil, fmt.Errorf("no refresh token available")
	}

	proxyURL := ""
	if id := account.ProxyID; id != nil {
		if proxy, err := s.proxyRepo.GetByID(ctx, *id); err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	info, err := s.RefreshToken(ctx, refreshToken, proxyURL)
	if err != nil {
		return nil, err
	}

	// Keep the account's oauth_type; accounts created before it existed are Code Assist.
	info.OAuthType = strings.TrimSpace(account.GetCredential("oauth_type"))
	if info.OAuthType == "" {
		info.OAuthType = "code_assist"
	}

	// A project_id already stored on the account wins.
	if stored := strings.TrimSpace(account.GetCredential("project_id")); stored != "" {
		info.ProjectID = stored
	}

	// Only Code Assist needs a project_id; AI Studio refreshes must not be blocked.
	needsProject := info.OAuthType == "code_assist" && strings.TrimSpace(info.ProjectID) == ""
	if !needsProject {
		return info, nil
	}

	detected, err := s.fetchProjectID(ctx, info.AccessToken, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
	}
	detected = strings.TrimSpace(detected)
	if detected == "" {
		return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
	}
	info.ProjectID = detected
	return info, nil
}
// BuildAccountCredentials converts token info into the credential map stored on
// an account. "access_token" and "expires_at" (ExpiresAt formatted as a decimal
// string) are always present. "refresh_token", "token_type", "scope",
// "project_id" and "oauth_type" are included only when non-empty. All values
// are strings.
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
	creds := make(map[string]any, 7)
	creds["access_token"] = tokenInfo.AccessToken
	creds["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	optional := [...]struct{ key, value string }{
		{"refresh_token", tokenInfo.RefreshToken},
		{"token_type", tokenInfo.TokenType},
		{"scope", tokenInfo.Scope},
		{"project_id", tokenInfo.ProjectID},
		{"oauth_type", tokenInfo.OAuthType},
	}
	for _, kv := range optional {
		if kv.value != "" {
			creds[kv.key] = kv.value
		}
	}
	return creds
}
// Stop calls Stop on the session store created by NewGeminiOAuthService.
// NOTE(review): presumably this halts the store's background expiry work;
// confirm in geminicli.SessionStore.
func (s *GeminiOAuthService) Stop() {
	store := s.sessionStore
	store.Stop()
}
// fetchProjectID discovers the GCP project to use for Code Assist requests.
//
// Resolution order:
//  1. loadCodeAssist: a non-empty CloudAICompanionProject is used directly.
//  2. onboardUser: the tier comes from loadCodeAssist's AllowedTiers. It is the
//     tier flagged IsDefault, else the first non-empty tier ID, else "LEGACY".
//     The call is polled up to maxAttempts times, 2s apart, until it reports
//     Done with a non-empty project.
//  3. Cloud Resource Manager (fetchProjectIDFromResourceManager) as a fallback
//     when onboarding errors, completes without a project, or never completes.
//
// The wait between polls honours ctx cancellation.
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
	if s.codeAssist == nil {
		return "", errors.New("code assist client not configured")
	}

	loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
	if loadErr == nil && loadResp != nil {
		if project := strings.TrimSpace(loadResp.CloudAICompanionProject); project != "" {
			return project, nil
		}
	}

	// Prefer the tier marked as default (even if that default is "LEGACY"). If
	// no tier is marked default, take the first non-empty tier ID. With no tiers
	// at all, fall back to "LEGACY".
	tierID := "LEGACY"
	if loadResp != nil {
		foundDefault := false
		for _, tier := range loadResp.AllowedTiers {
			if id := strings.TrimSpace(tier.ID); tier.IsDefault && id != "" {
				tierID = id
				foundDefault = true
				break
			}
		}
		if !foundDefault {
			for _, tier := range loadResp.AllowedTiers {
				if id := strings.TrimSpace(tier.ID); id != "" {
					tierID = id
					break
				}
			}
		}
	}

	req := &geminicli.OnboardUserRequest{
		TierID: tierID,
		Metadata: geminicli.LoadCodeAssistMetadata{
			IDEType:    "ANTIGRAVITY",
			Platform:   "PLATFORM_UNSPECIFIED",
			PluginType: "GEMINI",
		},
	}

	// resourceManagerFallback reports a usable project from Cloud Resource Manager, if any.
	resourceManagerFallback := func() (string, bool) {
		id, err := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
		id = strings.TrimSpace(id)
		return id, err == nil && id != ""
	}

	const pollInterval = 2 * time.Second
	maxAttempts := 5
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
		if err != nil {
			// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fall back to Cloud Resource Manager projects.
			if id, ok := resourceManagerFallback(); ok {
				return id, nil
			}
			return "", err
		}

		// A nil response is treated like "not done yet" rather than dereferenced.
		if resp != nil && resp.Done {
			if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
				switch v := resp.Response.CloudAICompanionProject.(type) {
				case string:
					if id := strings.TrimSpace(v); id != "" {
						return id, nil
					}
				case map[string]any:
					if raw, ok := v["id"].(string); ok {
						if id := strings.TrimSpace(raw); id != "" {
							return id, nil
						}
					}
				}
			}
			if id, ok := resourceManagerFallback(); ok {
				return id, nil
			}
			return "", errors.New("onboardUser completed but no project_id returned")
		}

		// No point waiting after the final attempt.
		if attempt == maxAttempts {
			break
		}
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return "", fmt.Errorf("onboardUser polling canceled: %w", ctx.Err())
		case <-timer.C:
		}
	}

	if id, ok := resourceManagerFallback(); ok {
		return id, nil
	}
	if loadErr != nil {
		return "", fmt.Errorf("loadCodeAssist failed (%w) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
	}
	return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
// googleCloudProject holds the fields of a Cloud Resource Manager v1 project
// that fetchProjectIDFromResourceManager inspects.
type googleCloudProject struct {
	ProjectID      string `json:"projectId"`
	DisplayName    string `json:"name"` // v1 "name" is the human-readable display name
	LifecycleState string `json:"lifecycleState"`
}

// googleCloudProjectsResponse is the decoded body of a v1 projects list call.
type googleCloudProjectsResponse struct {
	Projects []googleCloudProject `json:"projects"`
}
// fetchProjectIDFromResourceManager lists projects through the Cloud Resource
// Manager v1 API and picks one to use as the Code Assist project.
//
// Only ACTIVE projects with a non-empty projectId are considered. Preference:
//  1. a project whose ID contains "cloud-ai-companion", or whose display name
//     contains "cloud ai companion" or "code assist";
//  2. a project whose ID or display name contains "default";
//  3. otherwise the first ACTIVE project.
//
// No pagination is performed; only projects in the single response are
// considered. NOTE(review): confirm whether nextPageToken handling is needed.
//
// When proxyURL is set it must parse. A malformed proxy returns an error rather
// than silently sending the bearer token over a direct connection.
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
	// Upper bound on the response body we are willing to buffer.
	const maxBodyBytes = 8 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
	if err != nil {
		return "", fmt.Errorf("failed to create resource manager request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)

	client := &http.Client{Timeout: 30 * time.Second}
	if trimmed := strings.TrimSpace(proxyURL); trimmed != "" {
		parsed, perr := url.Parse(trimmed)
		if perr != nil {
			// Deliberately not wrapping perr: its text echoes the URL, which may embed proxy credentials.
			return "", errors.New("invalid proxy URL for resource manager request")
		}
		// Start from the default transport so dial/TLS/idle timeouts stay sane.
		var transport *http.Transport
		if base, ok := http.DefaultTransport.(*http.Transport); ok {
			transport = base.Clone()
		} else {
			transport = &http.Transport{}
		}
		transport.Proxy = http.ProxyURL(parsed)
		// The transport is private to this call; release its pooled connections on return.
		defer transport.CloseIdleConnections()
		client.Transport = transport
	}

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("resource manager request failed: %w", err)
	}
	defer resp.Body.Close()

	bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
	if err != nil {
		return "", fmt.Errorf("failed to read resource manager response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var projectsResp googleCloudProjectsResponse
	if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
		return "", fmt.Errorf("failed to parse resource manager response: %w", err)
	}

	active := make([]googleCloudProject, 0, len(projectsResp.Projects))
	for _, p := range projectsResp.Projects {
		if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
			active = append(active, p)
		}
	}
	if len(active) == 0 {
		return "", errors.New("no ACTIVE projects found from resource manager")
	}

	// Prefer likely companion projects first.
	for _, p := range active {
		id := strings.ToLower(strings.TrimSpace(p.ProjectID))
		name := strings.ToLower(strings.TrimSpace(p.DisplayName))
		if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
			return strings.TrimSpace(p.ProjectID), nil
		}
	}
	// Then prefer "default".
	for _, p := range active {
		id := strings.ToLower(strings.TrimSpace(p.ProjectID))
		name := strings.ToLower(strings.TrimSpace(p.DisplayName))
		if strings.Contains(id, "default") || strings.Contains(name, "default") {
			return strings.TrimSpace(p.ProjectID), nil
		}
	}
	return strings.TrimSpace(active[0].ProjectID), nil
}