feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持 - gemini_token_provider.go: Token 提供器增强 - gemini_messages_compat_service.go: 支持 AI Studio 端点 - account_test_service.go: Gemini 账户可用性检测 - gateway_service.go: 网关服务适配 - openai_gateway_service.go: OpenAI 兼容层调整
This commit is contained in:
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -38,19 +40,30 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream HTTPUpstream
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService {
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
httpUpstream: httpUpstream,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,6 +136,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
@@ -368,6 +385,247 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = geminicli.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test payload (Gemini format)
|
||||
payload := createGeminiTestPayload()
|
||||
|
||||
// Build request based on account type
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case model.AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
// Get proxy and execute request
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
|
||||
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(baseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
|
||||
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("Gemini token provider not configured")
|
||||
}
|
||||
|
||||
// Get access token (auto-refreshes if needed)
|
||||
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
// AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token.
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Wrap payload in Code Assist format
|
||||
var inner map[string]any
|
||||
if err := json.Unmarshal(payload, &inner); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapped := map[string]any{
|
||||
"model": modelID,
|
||||
"project": projectID,
|
||||
"request": inner,
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
||||
func createGeminiTestPayload() []byte {
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"systemInstruction": map[string]any{
|
||||
"parts": []map[string]any{
|
||||
{"text": "You are a helpful AI assistant."},
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(payload)
|
||||
return bytes
|
||||
}
|
||||
|
||||
// processGeminiStream processes SSE stream from Gemini API
|
||||
func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract text from candidates[0].content.parts[].text
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Check for completion
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle errors
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
errorMsg := "Unknown error"
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
|
||||
@@ -317,8 +317,17 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// 优先级相同时,选最久未用的
|
||||
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,8 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -43,7 +47,7 @@ type GeminiAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) {
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
|
||||
state, err := geminicli.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
@@ -66,22 +70,38 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
}
|
||||
|
||||
session := &geminicli.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// 两种 OAuth 模式都使用相同的配置,只是 scopes 不同
|
||||
// scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
|
||||
authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI)
|
||||
session := &geminicli.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
OAuthType: oauthType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For Code Assist with Gemini CLI credentials, use the CLI's redirect URI
|
||||
if oauthType == "code_assist" {
|
||||
redirectURI = geminicli.GeminiCLIRedirectURI
|
||||
session.RedirectURI = redirectURI
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
}
|
||||
|
||||
authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,11 +114,11 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
type GeminiExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
OAuthType string // "code_assist" 或 "ai_studio"
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
@@ -109,6 +129,7 @@ type GeminiTokenInfo struct {
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
@@ -129,19 +150,38 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
}
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
if strings.TrimSpace(input.RedirectURI) != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
oauthType := session.OAuthType
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist" // 默认为 code_assist 以兼容旧 session
|
||||
}
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
|
||||
projectID := sessionProjectID
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
if oauthType == "code_assist" {
|
||||
if projectID == "" {
|
||||
var err error
|
||||
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
}
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
@@ -151,6 +191,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
ProjectID: projectID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -223,7 +264,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *m
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
|
||||
oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
tokenInfo.OAuthType = oauthType
|
||||
|
||||
// Preserve account's project_id when present.
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
|
||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
|
||||
@@ -243,6 +316,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
if tokenInfo.OAuthType != "" {
|
||||
creds["oauth_type"] = tokenInfo.OAuthType
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
@@ -255,20 +331,28 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return "", errors.New("code assist client not configured")
|
||||
}
|
||||
|
||||
loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
|
||||
}
|
||||
|
||||
// pick default tier from allowedTiers, fallback to LEGACY.
|
||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = tier.ID
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
@@ -284,24 +368,116 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
|
||||
if err != nil {
|
||||
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if resp.Done {
|
||||
if resp.Response == nil || resp.Response.CloudAICompanionProject == nil {
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), nil
|
||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", errors.New("onboardUser returned unsupported project_id format")
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
if loadErr != nil {
|
||||
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
}
|
||||
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
type googleCloudProject struct {
|
||||
ProjectID string `json:"projectId"`
|
||||
DisplayName string `json:"name"`
|
||||
LifecycleState string `json:"lifecycleState"`
|
||||
}
|
||||
|
||||
type googleCloudProjectsResponse struct {
|
||||
Projects []googleCloudProject `json:"projects"`
|
||||
}
|
||||
|
||||
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create resource manager request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
|
||||
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resource manager request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read resource manager response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var projectsResp googleCloudProjectsResponse
|
||||
if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse resource manager response: %w", err)
|
||||
}
|
||||
|
||||
active := make([]googleCloudProject, 0, len(projectsResp.Projects))
|
||||
for _, p := range projectsResp.Projects {
|
||||
if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
|
||||
active = append(active, p)
|
||||
}
|
||||
}
|
||||
if len(active) == 0 {
|
||||
return "", errors.New("no ACTIVE projects found from resource manager")
|
||||
}
|
||||
|
||||
// Prefer likely companion projects first.
|
||||
for _, p := range active {
|
||||
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
|
||||
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
|
||||
if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
|
||||
return strings.TrimSpace(p.ProjectID), nil
|
||||
}
|
||||
}
|
||||
// Then prefer "default".
|
||||
for _, p := range active {
|
||||
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
|
||||
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
|
||||
if strings.Contains(id, "default") || strings.Contains(name, "default") {
|
||||
return strings.TrimSpace(p.ProjectID), nil
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(active[0].ProjectID), nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -95,6 +96,40 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// project_id is optional now:
|
||||
// - If present: will use Code Assist API (requires project_id)
|
||||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
||||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||||
|
||||
if projectID == "" && autoDetectProjectID {
|
||||
if p.geminiOAuthService == nil {
|
||||
return accessToken, nil // Fallback to AI Studio API mode
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
|
||||
if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||||
return accessToken, nil
|
||||
}
|
||||
detected = strings.TrimSpace(detected)
|
||||
if detected != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = model.JSONB{}
|
||||
}
|
||||
account.Credentials["project_id"] = detected
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
|
||||
@@ -166,9 +166,18 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user