feat(gemini): 优化 OAuth 和配额展示

主要改进:
- 修复 google_one OAuth scopes 配置问题
- 添加 Gemini 账号配额展示组件
- 优化 Code Assist 类型检测逻辑
- 添加 OAuth 测试用例
This commit is contained in:
ianshaw
2026-01-03 06:32:04 -08:00
parent 26438f7232
commit 26106eb0ac
11 changed files with 363 additions and 109 deletions

View File

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
switch account.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -1758,7 +1772,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
@@ -2177,10 +2191,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
parts := make([]any, 0)
switch content := mm["content"].(type) {
case string:
if strings.TrimSpace(content) != "" {
parts = append(parts, map[string]any{"text": content})
}
// 字符串形式的 content保留所有内容包括空白
parts = append(parts, map[string]any{"text": content})
case []any:
// 如果只有一个 block不过滤空白让上游 API 报错)
singleBlock := len(content) == 1
for _, block := range content {
bm, ok := block.(map[string]any)
if !ok {
@@ -2189,8 +2205,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
bt, _ := bm["type"].(string)
switch bt {
case "text":
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
if text, ok := bm["text"].(string); ok {
// 单个 block 时保留所有内容(包括空白)
// 多个 blocks 时过滤掉空白
if singleBlock || strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
}
case "tool_use":
id, _ := bm["id"].(string)

View File

@@ -251,8 +251,20 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown
}
// FetchGoogleOneTier fetches Google One tier from Drive API
// fetchGoogleOneTier fetches Google One tier from Drive API or LoadCodeAssist API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
// First try LoadCodeAssist API (works for accounts with GCP projects)
if s.codeAssist != nil {
loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if err == nil && loadResp != nil {
if tier := loadResp.GetTier(); tier != "" {
fmt.Printf("[GeminiOAuth] Got tier from LoadCodeAssist: %s\n", tier)
return tier, nil, nil
}
}
}
// Fallback to Drive API (requires drive.readonly scope)
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
@@ -422,12 +434,15 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
}
case "google_one":
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
var storageInfo *geminicli.DriveStorageInfo
var err error
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
}
fmt.Printf("[GeminiOAuth] Google One tierID after fetch: %s\n", tierID)
// Store Drive info in extra field for caching
if storageInfo != nil {
@@ -452,7 +467,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
}
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{
result := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
@@ -462,7 +477,9 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
fmt.Printf("[GeminiOAuth] ExchangeCode returning tierID: %s\n", result.TierID)
return result, nil
}
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
@@ -669,6 +686,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
} else {
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
}
// Silently skip invalid tier_id (don't block account creation)
}
@@ -698,7 +718,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
// First try to get tier from currentTier/paidTier fields
if tier := loadResp.GetTier(); tier != "" {
tierID = tier
} else {
// Fallback to extracting from allowedTiers
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
}
// If LoadCodeAssist returned a project, use it