feat(gemini): 优化 OAuth 和配额展示
主要改进: - 修复 google_one OAuth scopes 配置问题 - 添加 Gemini 账号配额展示组件 - 优化 Code Assist 类型检测逻辑 - 添加 OAuth 测试用例
This commit is contained in:
@@ -18,7 +18,6 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GetCapabilities retrieves OAuth configuration capabilities.
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
@@ -11,12 +16,51 @@ type LoadCodeAssistMetadata struct {
|
||||
PluginType string `json:"pluginType"`
|
||||
}
|
||||
|
||||
type TierInfo struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON supports both legacy string tiers and object tiers.
|
||||
func (t *TierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
t.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias TierInfo
|
||||
var decoded alias
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = TierInfo(decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
type LoadCodeAssistResponse struct {
|
||||
CurrentTier string `json:"currentTier,omitempty"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
|
||||
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier extracts tier ID, prioritizing paidTier over currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
if r.PaidTier != nil && r.PaidTier.ID != "" {
|
||||
return r.PaidTier.ID
|
||||
}
|
||||
if r.CurrentTier != nil {
|
||||
return r.CurrentTier.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type AllowedTier struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"isDefault,omitempty"`
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// Package geminicli provides OAuth authentication and API client functionality
|
||||
// for Google's Gemini AI services, supporting both AI Studio and Code Assist endpoints.
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
@@ -29,7 +27,9 @@ const (
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// DefaultScopes for Google One (personal Google accounts with Gemini access)
|
||||
// Includes generative-language for Gemini API access and drive.readonly for storage tier detection
|
||||
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
|
||||
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
|
||||
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
|
||||
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
|
||||
@@ -181,19 +181,23 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
case "google_one":
|
||||
// Google One accounts need generative-language scope for Gemini API access
|
||||
// and drive.readonly scope for storage tier detection
|
||||
effective.Scopes = DefaultGoogleOneScopes
|
||||
// Google One uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultGoogleOneScopes
|
||||
}
|
||||
default:
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
}
|
||||
} else if oauthType == "ai_studio" && isBuiltinClient {
|
||||
} else if (oauthType == "ai_studio" || oauthType == "google_one") && isBuiltinClient {
|
||||
// If user overrides scopes while still using the built-in client, strip restricted scopes.
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, s := range parts {
|
||||
if strings.Contains(s, "generative-language") {
|
||||
if hasRestrictedScope(s) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, s)
|
||||
@@ -219,6 +223,11 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
func hasRestrictedScope(scope string) bool {
|
||||
return strings.HasPrefix(scope, "https://www.googleapis.com/auth/generative-language") ||
|
||||
strings.HasPrefix(scope, "https://www.googleapis.com/auth/drive")
|
||||
}
|
||||
|
||||
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
||||
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
||||
if err != nil {
|
||||
|
||||
113
backend/internal/pkg/geminicli/oauth_test.go
Normal file
113
backend/internal/pkg/geminicli/oauth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input OAuthConfig
|
||||
oauthType string
|
||||
wantClientID string
|
||||
wantScopes string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Google One with built-in client (empty config)",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with custom client",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultGoogleOneScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: "https://www.googleapis.com/auth/cloud-platform",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Code Assist with built-in client",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "code_assist",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := EffectiveOAuthConfig(tt.input, tt.oauthType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("EffectiveOAuthConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if got.ClientID != tt.wantClientID {
|
||||
t.Errorf("EffectiveOAuthConfig() ClientID = %v, want %v", got.ClientID, tt.wantClientID)
|
||||
}
|
||||
if got.Scopes != tt.wantScopes {
|
||||
t.Errorf("EffectiveOAuthConfig() Scopes = %v, want %v", got.Scopes, tt.wantScopes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
// Test that Google One with built-in client filters out restricted scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "google_one")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
|
||||
// Should NOT contain generative-language or drive scopes
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "drive") {
|
||||
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
|
||||
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -1758,7 +1772,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
@@ -2177,10 +2191,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
parts := make([]any, 0)
|
||||
switch content := mm["content"].(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(content) != "" {
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
}
|
||||
// 字符串形式的 content,保留所有内容(包括空白)
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
case []any:
|
||||
// 如果只有一个 block,不过滤空白(让上游 API 报错)
|
||||
singleBlock := len(content) == 1
|
||||
|
||||
for _, block := range content {
|
||||
bm, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
@@ -2189,8 +2205,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
bt, _ := bm["type"].(string)
|
||||
switch bt {
|
||||
case "text":
|
||||
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
if text, ok := bm["text"].(string); ok {
|
||||
// 单个 block 时保留所有内容(包括空白)
|
||||
// 多个 blocks 时过滤掉空白
|
||||
if singleBlock || strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
id, _ := bm["id"].(string)
|
||||
|
||||
@@ -251,8 +251,20 @@ func inferGoogleOneTier(storageBytes int64) string {
|
||||
return TierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// FetchGoogleOneTier fetches Google One tier from Drive API
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API or LoadCodeAssist API
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
// First try LoadCodeAssist API (works for accounts with GCP projects)
|
||||
if s.codeAssist != nil {
|
||||
loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if err == nil && loadResp != nil {
|
||||
if tier := loadResp.GetTier(); tier != "" {
|
||||
fmt.Printf("[GeminiOAuth] Got tier from LoadCodeAssist: %s\n", tier)
|
||||
return tier, nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to Drive API (requires drive.readonly scope)
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
@@ -422,12 +434,15 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
}
|
||||
case "google_one":
|
||||
// Attempt to fetch Drive storage tier
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
var err error
|
||||
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
tierID = TierGoogleOneUnknown
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] Google One tierID after fetch: %s\n", tierID)
|
||||
|
||||
// Store Drive info in extra field for caching
|
||||
if storageInfo != nil {
|
||||
@@ -452,7 +467,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
}
|
||||
// ai_studio 模式不设置 tierID,保持为空
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
result := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -462,7 +477,9 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] ExchangeCode returning tierID: %s\n", result.TierID)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
|
||||
@@ -669,6 +686,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
|
||||
} else {
|
||||
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
@@ -698,7 +718,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
// First try to get tier from currentTier/paidTier fields
|
||||
if tier := loadResp.GetTier(); tier != "" {
|
||||
tierID = tier
|
||||
} else {
|
||||
// Fallback to extracting from allowedTiers
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
|
||||
Reference in New Issue
Block a user