feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持 - gemini_token_provider.go: Token 提供器增强 - gemini_messages_compat_service.go: 支持 AI Studio 端点 - account_test_service.go: Gemini 账户可用性检测 - gateway_service.go: 网关服务适配 - openai_gateway_service.go: OpenAI 兼容层调整
This commit is contained in:
@@ -2,8 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -43,7 +47,7 @@ type GeminiAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) {
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
|
||||
state, err := geminicli.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
@@ -66,22 +70,38 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
}
|
||||
|
||||
session := &geminicli.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// 两种 OAuth 模式都使用相同的配置,只是 scopes 不同
|
||||
// scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
|
||||
authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI)
|
||||
session := &geminicli.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
OAuthType: oauthType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For Code Assist with Gemini CLI credentials, use the CLI's redirect URI
|
||||
if oauthType == "code_assist" {
|
||||
redirectURI = geminicli.GeminiCLIRedirectURI
|
||||
session.RedirectURI = redirectURI
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
}
|
||||
|
||||
authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,11 +114,11 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
type GeminiExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
OAuthType string // "code_assist" 或 "ai_studio"
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
@@ -109,6 +129,7 @@ type GeminiTokenInfo struct {
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
@@ -129,19 +150,38 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
}
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
if strings.TrimSpace(input.RedirectURI) != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
oauthType := session.OAuthType
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist" // 默认为 code_assist 以兼容旧 session
|
||||
}
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
|
||||
projectID := sessionProjectID
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
if oauthType == "code_assist" {
|
||||
if projectID == "" {
|
||||
var err error
|
||||
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
}
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
@@ -151,6 +191,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
ProjectID: projectID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -223,7 +264,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *m
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
|
||||
oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
tokenInfo.OAuthType = oauthType
|
||||
|
||||
// Preserve account's project_id when present.
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
|
||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
|
||||
@@ -243,6 +316,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
if tokenInfo.OAuthType != "" {
|
||||
creds["oauth_type"] = tokenInfo.OAuthType
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
@@ -255,20 +331,28 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return "", errors.New("code assist client not configured")
|
||||
}
|
||||
|
||||
loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
|
||||
}
|
||||
|
||||
// pick default tier from allowedTiers, fallback to LEGACY.
|
||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = tier.ID
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
@@ -284,24 +368,116 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
|
||||
if err != nil {
|
||||
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if resp.Done {
|
||||
if resp.Response == nil || resp.Response.CloudAICompanionProject == nil {
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), nil
|
||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", errors.New("onboardUser returned unsupported project_id format")
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
if loadErr != nil {
|
||||
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
}
|
||||
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
type googleCloudProject struct {
|
||||
ProjectID string `json:"projectId"`
|
||||
DisplayName string `json:"name"`
|
||||
LifecycleState string `json:"lifecycleState"`
|
||||
}
|
||||
|
||||
type googleCloudProjectsResponse struct {
|
||||
Projects []googleCloudProject `json:"projects"`
|
||||
}
|
||||
|
||||
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create resource manager request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
|
||||
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resource manager request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read resource manager response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var projectsResp googleCloudProjectsResponse
|
||||
if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse resource manager response: %w", err)
|
||||
}
|
||||
|
||||
active := make([]googleCloudProject, 0, len(projectsResp.Projects))
|
||||
for _, p := range projectsResp.Projects {
|
||||
if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
|
||||
active = append(active, p)
|
||||
}
|
||||
}
|
||||
if len(active) == 0 {
|
||||
return "", errors.New("no ACTIVE projects found from resource manager")
|
||||
}
|
||||
|
||||
// Prefer likely companion projects first.
|
||||
for _, p := range active {
|
||||
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
|
||||
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
|
||||
if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
|
||||
return strings.TrimSpace(p.ProjectID), nil
|
||||
}
|
||||
}
|
||||
// Then prefer "default".
|
||||
for _, p := range active {
|
||||
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
|
||||
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
|
||||
if strings.Contains(id, "default") || strings.Contains(name, "default") {
|
||||
return strings.TrimSpace(p.ProjectID), nil
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(active[0].ProjectID), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user