diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 782aa95c..7ddf4355 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -3,6 +3,7 @@ package service import ( "bufio" "bytes" + "context" "crypto/rand" "encoding/hex" "encoding/json" @@ -16,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -38,19 +40,30 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { - accountRepo AccountRepository - oauthService *OAuthService - openaiOAuthService *OpenAIOAuthService - httpUpstream HTTPUpstream + accountRepo AccountRepository + oauthService *OAuthService + openaiOAuthService *OpenAIOAuthService + geminiOAuthService *GeminiOAuthService + geminiTokenProvider *GeminiTokenProvider + httpUpstream HTTPUpstream } // NewAccountTestService creates a new AccountTestService -func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService { +func NewAccountTestService( + accountRepo AccountRepository, + oauthService *OAuthService, + openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, + geminiTokenProvider *GeminiTokenProvider, + httpUpstream HTTPUpstream, +) *AccountTestService { return &AccountTestService{ - accountRepo: accountRepo, - oauthService: oauthService, - openaiOAuthService: openaiOAuthService, - httpUpstream: httpUpstream, + accountRepo: accountRepo, + oauthService: oauthService, + openaiOAuthService: openaiOAuthService, + geminiOAuthService: geminiOAuthService, + geminiTokenProvider: geminiTokenProvider, + httpUpstream: httpUpstream, } } @@ -123,6 +136,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.testOpenAIAccountConnection(c, account, modelID) } + if account.IsGemini() { + return s.testGeminiAccountConnection(c, account, modelID) + } + return s.testClaudeAccountConnection(c, account, modelID) } @@ -368,6 +385,247 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.processOpenAIStream(c, resp.Body) } +// testGeminiAccountConnection tests a Gemini account's connection +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error { + ctx := c.Request.Context() + + // Determine the model to use + testModelID := modelID + if testModelID == "" { + testModelID = geminicli.DefaultTestModel + } + + // For API Key accounts with model mapping, map the model + if account.Type == model.AccountTypeApiKey { + mapping := account.GetModelMapping() + if len(mapping) > 0 { + if mappedModel, exists := mapping[testModelID]; exists { + testModelID = mappedModel + } + } + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create test payload (Gemini format) + payload := createGeminiTestPayload() + + // Build request based on account type + var req *http.Request + var err error + + switch account.Type { + case model.AccountTypeApiKey: + req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) + case model.AccountTypeOAuth: + req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) + default: + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) + } + + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error())) + } + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + // Get proxy and execute request + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Process SSE stream + return s.processGeminiStream(c, resp.Body) +} + +// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts +func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, fmt.Errorf("No API key available") + } + + baseURL := account.GetCredential("base_url") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + // Use streamGenerateContent for real-time feedback + fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", + strings.TrimRight(baseURL, "/"), modelID) + + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", apiKey) + + return req, nil +} + +// buildGeminiOAuthRequest builds request for Gemini OAuth accounts +func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { + if s.geminiTokenProvider == nil { + return nil, fmt.Errorf("Gemini token provider not configured") + } + + // Get access token (auto-refreshes if needed) + accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("Failed to get access token: %w", err) + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + // AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token. + baseURL := account.GetCredential("base_url") + if strings.TrimSpace(baseURL) == "" { + baseURL = geminicli.AIStudioBaseURL + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + return req, nil + } + + // Wrap payload in Code Assist format + var inner map[string]any + if err := json.Unmarshal(payload, &inner); err != nil { + return nil, err + } + + wrapped := map[string]any{ + "model": modelID, + "project": projectID, + "request": inner, + } + wrappedBytes, _ := json.Marshal(wrapped) + + fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) + + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + + return req, nil +} + +// createGeminiTestPayload creates a minimal test payload for Gemini API +func createGeminiTestPayload() []byte { + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": "hi"}, + }, + }, + }, + "systemInstruction": map[string]any{ + "parts": []map[string]any{ + {"text": "You are a helpful AI assistant."}, + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes +} + +// processGeminiStream processes SSE stream from Gemini API +func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error { + reader := bufio.NewReader(body) + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) + } + + line = strings.TrimSpace(line) + if line == "" || !strings.HasPrefix(line, "data: ") { + continue + } + + jsonStr := strings.TrimPrefix(line, "data: ") + if jsonStr == "[DONE]" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + var data map[string]any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + continue + } + + // Extract text from candidates[0].content.parts[].text + if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { + if candidate, ok := candidates[0].(map[string]any); ok { + // Check for completion + if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + // Extract content + if content, ok := candidate["content"].(map[string]any); ok { + if parts, ok := content["parts"].([]any); ok { + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok && text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + } + } + } + } + } + + // Handle errors + if errData, ok := data["error"].(map[string]any); ok { + errorMsg := "Unknown error" + if msg, ok := errData["message"].(string); ok { + errorMsg = msg + } + return s.sendErrorAndEnd(c, errorMsg) + } + } +} + // createOpenAITestPayload creates a test payload for OpenAI Responses API func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any { payload := map[string]any{ diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5b5ad7c8..3c3066d9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -317,8 +317,17 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int selected = acc } else if acc.Priority == selected.Priority { // 优先级相同时,选最久未用的 - if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + // keep selected (both never used) + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } } } } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 49fe7135..3412270c 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -19,6 +19,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/service/ports" @@ -57,6 +58,11 @@ func NewGeminiMessagesCompatService( } } +// GetTokenProvider returns the token provider for OAuth accounts +func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { + return s.tokenProvider +} + func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { cacheKey := "gemini:" + sessionHash if sessionHash != "" { @@ -94,8 +100,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { - if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + // Prefer OAuth accounts when both are unused (more compatible for Code Assist flows). + if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } } } } @@ -114,6 +132,96 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, return selected, nil } +// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against +// generativelanguage.googleapis.com (e.g. GET /v1beta/models). +// +// Preference order: +// 1) API key accounts (AI Studio) +// 2) OAuth accounts without project_id (AI Studio OAuth) +// 3) OAuth accounts explicitly marked as ai_studio +// 4) Any remaining Gemini accounts (fallback) +func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) { + var accounts []model.Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + if len(accounts) == 0 { + return nil, errors.New("no available Gemini accounts") + } + + rank := func(a *model.Account) int { + if a == nil { + return 999 + } + switch a.Type { + case model.AccountTypeApiKey: + if strings.TrimSpace(a.GetCredential("api_key")) != "" { + return 0 + } + return 9 + case model.AccountTypeOAuth: + if strings.TrimSpace(a.GetCredential("project_id")) == "" { + return 1 + } + if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" { + return 2 + } + // Code Assist OAuth tokens often lack AI Studio scopes for models listing. + return 3 + default: + return 10 + } + } + + var selected *model.Account + for i := range accounts { + acc := &accounts[i] + if selected == nil { + selected = acc + continue + } + + r1, r2 := rank(acc), rank(selected) + if r1 < r2 { + selected = acc + continue + } + if r1 > r2 { + continue + } + + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + return nil, errors.New("no available Gemini accounts") + } + return selected, nil +} + func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { startTime := time.Now() @@ -146,6 +254,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex var requestIDHeader string var buildReq func(ctx context.Context) (*http.Request, string, error) + useUpstreamStream := req.Stream + if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" { + // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. + useUpstreamStream = true + } switch account.Type { case model.AccountTypeApiKey: @@ -190,38 +303,61 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } projectID := strings.TrimSpace(account.GetCredential("project_id")) - if projectID == "" { - return nil, "", errors.New("missing project_id in account credentials") - } action := "generateContent" - if req.Stream { + if useUpstreamStream { action = "streamGenerateContent" } - fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) - if req.Stream { - fullURL += "?alt=sse" - } - wrapped := map[string]any{ - "model": mappedModel, - "project": projectID, - } - var inner any - if err := json.Unmarshal(geminiReq, &inner); err != nil { - return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) - } - wrapped["request"] = inner - wrappedBytes, _ := json.Marshal(wrapped) + // Two modes for OAuth: + // 1. With project_id -> Code Assist API (wrapped request) + // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) + if projectID != "" { + // Mode 1: Code Assist API + fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) + if useUpstreamStream { + fullURL += "?alt=sse" + } - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) - if err != nil { - return nil, "", err + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } else { + // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil } - upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) - upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) - return upstreamReq, "x-request-id", nil } requestIDHeader = "x-request-id" @@ -301,9 +437,22 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(c, resp, originalModel) - if err != nil { - return nil, err + if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, true) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") + } + claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel) + c.JSON(http.StatusOK, claudeResp) + usage = usageObj2 + if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) { + usage = usageObj + } + } else { + usage, err = s.handleNonStreamingResponse(c, resp, originalModel) + if err != nil { + return nil, err + } } } @@ -317,6 +466,291 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex }, nil } +func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + if strings.TrimSpace(originalModel) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") + } + if strings.TrimSpace(action) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") + } + if len(body) == 0 { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") + } + + switch action { + case "generateContent", "streamGenerateContent", "countTokens": + // ok + default: + return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) + } + + mappedModel := originalModel + if account.Type == model.AccountTypeApiKey { + mappedModel = account.GetMappedModel(originalModel) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + useUpstreamStream := stream + upstreamAction := action + if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" { + // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. + useUpstreamStream = true + upstreamAction = "streamGenerateContent" + } + forceAIStudio := action == "countTokens" + + var requestIDHeader string + var buildReq func(ctx context.Context) (*http.Request, string, error) + + switch account.Type { + case model.AccountTypeApiKey: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("Gemini api_key not configured") + } + + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + case model.AccountTypeOAuth: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("Gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + + // Two modes for OAuth: + // 1. With project_id -> Code Assist API (wrapped request) + // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) + if projectID != "" && !forceAIStudio { + // Mode 1: Code Assist API + fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(body, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } else { + // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + } + requestIDHeader = "x-request-id" + + default: + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type) + } + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + // Local build error: don't retry. + if strings.Contains(err.Error(), "missing project_id") { + return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error()) + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error()) + } + requestIDHeader = idHeader + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + if err != nil { + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if resp.StatusCode == 429 { + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") + } + + break + } + defer func() { _ = resp.Body.Close() }() + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + isOAuth := account.Type == model.AccountTypeOAuth + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. + // This avoids Gemini SDKs failing hard during preflight token counting. + if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: requestID, + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode) + } + + var usage *ClaudeUsage + var firstTokenMs *int + + if stream { + streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream") + } + b, _ := json.Marshal(collected) + c.Data(http.StatusOK, "application/json", b) + usage = usageObj + } else { + usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth) + if err != nil { + return nil, err + } + usage = usageResp + } + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool { switch statusCode { case 429, 500, 502, 503, 504, 529: @@ -590,22 +1024,29 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re openBlockIndex := -1 openBlockType := "" seenText := "" + openToolIndex := -1 + openToolID := "" + openToolName := "" + seenToolJSON := "" reader := bufio.NewReader(resp.Body) for { line, err := reader.ReadString('\n') - if err != nil { - if errors.Is(err, io.EOF) { - break - } + if err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("stream read error: %w", err) } if !strings.HasPrefix(line, "data:") { + if errors.Is(err, io.EOF) { + break + } continue } payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) if payload == "" || payload == "[DONE]" { + if errors.Is(err, io.EOF) { + break + } continue } @@ -670,7 +1111,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re name = "tool" } - // Close any open block before tool_use. + // Close any open text block before tool_use. if openBlockIndex >= 0 { writeSSE(c.Writer, "content_block_stop", map[string]any{ "type": "content_block_stop", @@ -680,40 +1121,63 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re openBlockType = "" } - toolID := "toolu_" + randomHex(8) - toolIndex := nextBlockIndex - nextBlockIndex++ - sawToolUse = true + // If we receive streamed tool args in pieces, keep a single tool block open and emit deltas. + if openToolIndex >= 0 && openToolName != name { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openToolIndex, + }) + openToolIndex = -1 + openToolID = "" + openToolName = "" + seenToolJSON = "" + } - writeSSE(c.Writer, "content_block_start", map[string]any{ - "type": "content_block_start", - "index": toolIndex, - "content_block": map[string]any{ - "type": "tool_use", - "id": toolID, - "name": name, - "input": map[string]any{}, - }, - }) + if openToolIndex < 0 { + openToolID = "toolu_" + randomHex(8) + openToolIndex = nextBlockIndex + openToolName = name + nextBlockIndex++ + sawToolUse = true - argsJSON := "{}" - if args != nil { - if b, err := json.Marshal(args); err == nil { - argsJSON = string(b) + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": openToolIndex, + "content_block": map[string]any{ + "type": "tool_use", + "id": openToolID, + "name": name, + "input": map[string]any{}, + }, + }) + } + + argsJSONText := "{}" + switch v := args.(type) { + case nil: + // keep default "{}" + case string: + if strings.TrimSpace(v) != "" { + argsJSONText = v + } + default: + if b, err := json.Marshal(args); err == nil && len(b) > 0 { + argsJSONText = string(b) } } - writeSSE(c.Writer, "content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": toolIndex, - "delta": map[string]any{ - "type": "input_json_delta", - "partial_json": argsJSON, - }, - }) - writeSSE(c.Writer, "content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": toolIndex, - }) + + delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText) + seenToolJSON = newSeen + if delta != "" { + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": openToolIndex, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": delta, + }, + }) + } flusher.Flush() } } @@ -721,6 +1185,11 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re if u := extractGeminiUsage(geminiResp); u != nil { usage = *u } + + // Process the final unterminated line at EOF as well. + if errors.Is(err, io.EOF) { + break + } } if openBlockIndex >= 0 { @@ -729,6 +1198,12 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re "index": openBlockIndex, }) } + if openToolIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openToolIndex, + }) + } stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) if sawToolUse { @@ -779,6 +1254,369 @@ func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status in return fmt.Errorf("%s", message) } +func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": googleapi.HTTPStatusToGoogleStatus(status), + }, + }) + return fmt.Errorf("%s", message) +} + +func unwrapIfNeeded(isOAuth bool, raw []byte) []byte { + if !isOAuth { + return raw + } + inner, err := unwrapGeminiResponse(raw) + if err != nil { + return raw + } + b, err := json.Marshal(inner) + if err != nil { + return raw + } + return b +} + +func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) { + reader := bufio.NewReader(body) + + var last map[string]any + var lastWithParts map[string]any + usage := &ClaudeUsage{} + + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + switch payload { + case "", "[DONE]": + if payload == "[DONE]" { + return pickGeminiCollectResult(last, lastWithParts), usage, nil + } + default: + var parsed map[string]any + if isOAuth { + inner, err := unwrapGeminiResponse([]byte(payload)) + if err == nil && inner != nil { + parsed = inner + } + } else { + _ = json.Unmarshal([]byte(payload), &parsed) + } + if parsed != nil { + last = parsed + if u := extractGeminiUsage(parsed); u != nil { + usage = u + } + if parts := extractGeminiParts(parsed); len(parts) > 0 { + lastWithParts = parsed + } + } + } + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, nil, err + } + } + + return pickGeminiCollectResult(last, lastWithParts), usage, nil +} + +func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any { + if lastWithParts != nil { + return lastWithParts + } + if last != nil { + return last + } + return map[string]any{} +} + +type geminiNativeStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func isGeminiInsufficientScope(headers http.Header, body []byte) bool { + if strings.Contains(strings.ToLower(headers.Get("Www-Authenticate")), "insufficient_scope") { + return true + } + lower := strings.ToLower(string(body)) + return strings.Contains(lower, "insufficient authentication scopes") || strings.Contains(lower, "access_token_scope_insufficient") +} + +func estimateGeminiCountTokens(reqBody []byte) int { + var obj map[string]any + if err := json.Unmarshal(reqBody, &obj); err != nil { + return 0 + } + + var texts []string + + // systemInstruction.parts[].text + if si, ok := obj["systemInstruction"].(map[string]any); ok { + if parts, ok := si["parts"].([]any); ok { + for _, p := range parts { + if pm, ok := p.(map[string]any); ok { + if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { + texts = append(texts, t) + } + } + } + } + } + + // contents[].parts[].text + if contents, ok := obj["contents"].([]any); ok { + for _, c := range contents { + cm, ok := c.(map[string]any) + if !ok { + continue + } + parts, ok := cm["parts"].([]any) + if !ok { + continue + } + for _, p := range parts { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { + texts = append(texts, t) + } + } + } + } + + total := 0 + for _, t := range texts { + total += estimateTokensForText(t) + } + if total < 0 { + return 0 + } + return total +} + +func estimateTokensForText(s string) int { + s = strings.TrimSpace(s) + if s == "" { + return 0 + } + runes := []rune(s) + if len(runes) == 0 { + return 0 + } + ascii := 0 + for _, r := range runes { + if r <= 0x7f { + ascii++ + } + } + asciiRatio := float64(ascii) / float64(len(runes)) + if asciiRatio >= 0.8 { + // Roughly 4 chars per token for English-like text. + return (len(runes) + 3) / 4 + } + // For CJK-heavy text, approximate 1 rune per token. + return len(runes) +} + +type UpstreamHTTPResult struct { + StatusCode int + Headers http.Header + Body []byte +} + +func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var parsed map[string]any + if isOAuth { + parsed, err = unwrapGeminiResponse(respBody) + if err == nil && parsed != nil { + respBody, _ = json.Marshal(parsed) + } + } else { + _ = json.Unmarshal(respBody, &parsed) + } + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + + if parsed != nil { + if u := extractGeminiUsage(parsed); u != nil { + return u, nil + } + } + return &ClaudeUsage{}, nil +} + +func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { + c.Status(resp.StatusCode) + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream; charset=utf-8" + } + c.Header("Content-Type", contentType) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + reader := bufio.NewReader(resp.Body) + usage := &ClaudeUsage{} + var firstTokenMs *int + + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + // Keepalive / done markers + if payload == "" || payload == "[DONE]" { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } else { + var rawToWrite string + rawToWrite = payload + + var parsed map[string]any + if isOAuth { + inner, err := unwrapGeminiResponse([]byte(payload)) + if err == nil && inner != nil { + parsed = inner + if b, err := json.Marshal(inner); err == nil { + rawToWrite = string(b) + } + } + } else { + _ = json.Unmarshal([]byte(payload), &parsed) + } + + if parsed != nil { + if u := extractGeminiUsage(parsed); u != nil { + usage = u + } + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + if isOAuth { + // SSE format requires double newline (\n\n) to separate events + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite) + } else { + // Pass-through for AI Studio responses. + _, _ = io.WriteString(c.Writer, line) + } + flusher.Flush() + } + } else { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + } + + return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for +// endpoints like /v1beta/models and /v1beta/models/{model}. +// +// This is used to support Gemini SDKs that call models listing endpoints before generation. +func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) { + if account == nil { + return nil, errors.New("account is nil") + } + path = strings.TrimSpace(path) + if path == "" || !strings.HasPrefix(path, "/") { + return nil, errors.New("invalid path") + } + + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + fullURL := strings.TrimRight(baseURL, "/") + path + + var proxyURL string + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return nil, err + } + + switch account.Type { + case model.AccountTypeApiKey: + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, errors.New("Gemini api_key not configured") + } + req.Header.Set("x-goog-api-key", apiKey) + case model.AccountTypeOAuth: + if s.tokenProvider == nil { + return nil, errors.New("Gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } + + resp, err := s.httpUpstream.Do(req, proxyURL) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + return &UpstreamHTTPResult{ + StatusCode: resp.StatusCode, + Headers: resp.Header.Clone(), + Body: body, + }, nil +} + func unwrapGeminiResponse(raw []byte) (map[string]any, error) { var outer map[string]any if err := json.Unmarshal(raw, &outer); err != nil { diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index b3dc3f09..5b418130 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -2,8 +2,12 @@ package service import ( "context" + "encoding/json" "errors" "fmt" + "io" + "net/http" + "net/url" "strconv" "strings" "time" @@ -43,7 +47,7 @@ type GeminiAuthURLResult struct { State string `json:"state"` } -func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) { +func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) { state, err := geminicli.GenerateState() if err != nil { return nil, fmt.Errorf("failed to generate state: %w", err) @@ -66,22 +70,38 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 } } - session := &geminicli.OAuthSession{ - State: state, - CodeVerifier: codeVerifier, - ProxyURL: proxyURL, - RedirectURI: redirectURI, - CreatedAt: time.Now(), - } - s.sessionStore.Set(sessionID, session) - + // 两种 OAuth 模式都使用相同的配置,只是 scopes 不同 + // scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择 oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI) + session := &geminicli.OAuthSession{ + State: state, + CodeVerifier: codeVerifier, + ProxyURL: proxyURL, + RedirectURI: redirectURI, + ProjectID: strings.TrimSpace(projectID), + OAuthType: oauthType, + CreatedAt: time.Now(), + } + s.sessionStore.Set(sessionID, session) + + effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType) + if err != nil { + return nil, err + } + + // For Code Assist with Gemini CLI credentials, use the CLI's redirect URI + if oauthType == "code_assist" { + redirectURI = geminicli.GeminiCLIRedirectURI + session.RedirectURI = redirectURI + s.sessionStore.Set(sessionID, session) + } + + authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType) if err != nil { return nil, err } @@ -94,11 +114,11 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 } type GeminiExchangeCodeInput struct { - SessionID string - State string - Code string - RedirectURI string - ProxyID *int64 + SessionID string + State string + Code string + ProxyID *int64 + OAuthType string // "code_assist" 或 "ai_studio" } type GeminiTokenInfo struct { @@ -109,6 +129,7 @@ type GeminiTokenInfo struct { TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` + OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -129,19 +150,38 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch } redirectURI := session.RedirectURI - if strings.TrimSpace(input.RedirectURI) != "" { - redirectURI = input.RedirectURI - } tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) if err != nil { return nil, fmt.Errorf("failed to exchange code: %w", err) } + sessionProjectID := strings.TrimSpace(session.ProjectID) + oauthType := session.OAuthType + if oauthType == "" { + oauthType = "code_assist" // 默认为 code_assist 以兼容旧 session + } s.sessionStore.Delete(input.SessionID) // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 - projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + + projectID := sessionProjectID + + // 对于 code_assist 模式,project_id 是必需的 + // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) + if oauthType == "code_assist" { + if projectID == "" { + var err error + projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + // 记录警告但不阻断流程,允许后续补充 project_id + fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) + } + } + if strings.TrimSpace(projectID) == "" { + return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") + } + } return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -151,6 +191,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, + OAuthType: oauthType, }, nil } @@ -223,7 +264,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *m } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) + if err != nil { + return nil, err + } + + // Preserve oauth_type from the account (defaults to code_assist for backward compatibility). + oauthType := strings.TrimSpace(account.GetCredential("oauth_type")) + if oauthType == "" { + oauthType = "code_assist" + } + tokenInfo.OAuthType = oauthType + + // Preserve account's project_id when present. + existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) + if existingProjectID != "" { + tokenInfo.ProjectID = existingProjectID + } + + // For Code Assist, project_id is required. Auto-detect if missing. + // For AI Studio OAuth, project_id is optional and should not block refresh. + if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { + projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + if err != nil { + return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) + } + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return nil, fmt.Errorf("failed to auto-detect project_id: empty result") + } + tokenInfo.ProjectID = projectID + } + + return tokenInfo, nil } func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any { @@ -243,6 +316,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } + if tokenInfo.OAuthType != "" { + creds["oauth_type"] = tokenInfo.OAuthType + } return creds } @@ -255,20 +331,28 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return "", errors.New("code assist client not configured") } - loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) - if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { + loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { return strings.TrimSpace(loadResp.CloudAICompanionProject), nil } - // pick default tier from allowedTiers, fallback to LEGACY. + // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. tierID := "LEGACY" if loadResp != nil { for _, tier := range loadResp.AllowedTiers { if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = tier.ID + tierID = strings.TrimSpace(tier.ID) break } } + if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { + for _, tier := range loadResp.AllowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } } req := &geminicli.OnboardUserRequest{ @@ -284,24 +368,116 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr for attempt := 1; attempt <= maxAttempts; attempt++ { resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req) if err != nil { + // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + return strings.TrimSpace(fallback), nil + } return "", err } if resp.Done { - if resp.Response == nil || resp.Response.CloudAICompanionProject == nil { - return "", errors.New("onboardUser completed but no project_id returned") - } - switch v := resp.Response.CloudAICompanionProject.(type) { - case string: - return strings.TrimSpace(v), nil - case map[string]any: - if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), nil + if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { + switch v := resp.Response.CloudAICompanionProject.(type) { + case string: + return strings.TrimSpace(v), nil + case map[string]any: + if id, ok := v["id"].(string); ok { + return strings.TrimSpace(id), nil + } } } - return "", errors.New("onboardUser returned unsupported project_id format") + + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + return strings.TrimSpace(fallback), nil + } + return "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + return strings.TrimSpace(fallback), nil + } + if loadErr != nil { + return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + } return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } + +type googleCloudProject struct { + ProjectID string `json:"projectId"` + DisplayName string `json:"name"` + LifecycleState string `json:"lifecycleState"` +} + +type googleCloudProjectsResponse struct { + Projects []googleCloudProject `json:"projects"` +} + +func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil) + if err != nil { + return "", fmt.Errorf("failed to create resource manager request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + + client := &http.Client{Timeout: 30 * time.Second} + if strings.TrimSpace(proxyURL) != "" { + if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil { + client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)} + } + } + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("resource manager request failed: %w", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read resource manager response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var projectsResp googleCloudProjectsResponse + if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil { + return "", fmt.Errorf("failed to parse resource manager response: %w", err) + } + + active := make([]googleCloudProject, 0, len(projectsResp.Projects)) + for _, p := range projectsResp.Projects { + if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" { + active = append(active, p) + } + } + if len(active) == 0 { + return "", errors.New("no ACTIVE projects found from resource manager") + } + + // Prefer likely companion projects first. + for _, p := range active { + id := strings.ToLower(strings.TrimSpace(p.ProjectID)) + name := strings.ToLower(strings.TrimSpace(p.DisplayName)) + if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") { + return strings.TrimSpace(p.ProjectID), nil + } + } + // Then prefer "default". + for _, p := range active { + id := strings.ToLower(strings.TrimSpace(p.ProjectID)) + name := strings.ToLower(strings.TrimSpace(p.DisplayName)) + if strings.Contains(id, "default") || strings.Contains(name, "default") { + return strings.TrimSpace(p.ProjectID), nil + } + } + + return strings.TrimSpace(active[0].ProjectID), nil +} diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 51a2f54a..8ed7133e 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "log" "strconv" "strings" "time" @@ -95,6 +96,40 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model return "", errors.New("access_token not found in credentials") } + // project_id is optional now: + // - If present: will use Code Assist API (requires project_id) + // - If absent: will use AI Studio API with OAuth token (like regular API key mode) + // Auto-detect project_id only if explicitly enabled via a credential flag + projectID := strings.TrimSpace(account.GetCredential("project_id")) + autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true" + + if projectID == "" && autoDetectProjectID { + if p.geminiOAuthService == nil { + return accessToken, nil // Fallback to AI Studio API mode + } + + var proxyURL string + if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil { + if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + if err != nil { + log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) + return accessToken, nil + } + detected = strings.TrimSpace(detected) + if detected != "" { + if account.Credentials == nil { + account.Credentials = model.JSONB{} + } + account.Credentials["project_id"] = detected + _ = p.accountRepo.Update(ctx, account) + } + } + // 3) Populate cache with TTL. if p.tokenCache != nil { ttl := 30 * time.Minute diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ca3c2c36..377b2417 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -166,9 +166,18 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { - // Same priority, select least recently used - if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + // keep selected (both never used) + default: + // Same priority, select least recently used + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } } } }