- gofmt: 修复代码格式问题 - errcheck: 处理 WriteString 和 Close 返回值 - staticcheck: 错误信息改为小写开头 - staticcheck: 移除无效的 nil 检查 - staticcheck: 使用 append 替换循环 - staticcheck: 使用无条件的 TrimPrefix - ineffassign: 移除无效赋值 - unused: 移除未使用的 geminiOAuthService 字段 - 重新生成 wire_gen.go
198 lines
4.8 KiB
Go
198 lines
4.8 KiB
Go
package geminicli
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type OAuthConfig struct {
|
|
ClientID string
|
|
ClientSecret string
|
|
Scopes string
|
|
}
|
|
|
|
type OAuthSession struct {
|
|
State string `json:"state"`
|
|
CodeVerifier string `json:"code_verifier"`
|
|
ProxyURL string `json:"proxy_url,omitempty"`
|
|
RedirectURI string `json:"redirect_uri"`
|
|
ProjectID string `json:"project_id,omitempty"`
|
|
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
|
CreatedAt time.Time `json:"created_at"`
|
|
}
|
|
|
|
type SessionStore struct {
|
|
mu sync.RWMutex
|
|
sessions map[string]*OAuthSession
|
|
stopCh chan struct{}
|
|
}
|
|
|
|
func NewSessionStore() *SessionStore {
|
|
store := &SessionStore{
|
|
sessions: make(map[string]*OAuthSession),
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
go store.cleanup()
|
|
return store
|
|
}
|
|
|
|
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.sessions[sessionID] = session
|
|
}
|
|
|
|
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
session, ok := s.sessions[sessionID]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if time.Since(session.CreatedAt) > SessionTTL {
|
|
return nil, false
|
|
}
|
|
return session, true
|
|
}
|
|
|
|
func (s *SessionStore) Delete(sessionID string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.sessions, sessionID)
|
|
}
|
|
|
|
func (s *SessionStore) Stop() {
|
|
select {
|
|
case <-s.stopCh:
|
|
return
|
|
default:
|
|
close(s.stopCh)
|
|
}
|
|
}
|
|
|
|
func (s *SessionStore) cleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-s.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
s.mu.Lock()
|
|
for id, session := range s.sessions {
|
|
if time.Since(session.CreatedAt) > SessionTTL {
|
|
delete(s.sessions, id)
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
func GenerateRandomBytes(n int) ([]byte, error) {
|
|
b := make([]byte, n)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func GenerateState() (string, error) {
|
|
bytes, err := GenerateRandomBytes(32)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64URLEncode(bytes), nil
|
|
}
|
|
|
|
func GenerateSessionID() (string, error) {
|
|
bytes, err := GenerateRandomBytes(16)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return hex.EncodeToString(bytes), nil
|
|
}
|
|
|
|
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
|
|
func GenerateCodeVerifier() (string, error) {
|
|
bytes, err := GenerateRandomBytes(32)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64URLEncode(bytes), nil
|
|
}
|
|
|
|
func GenerateCodeChallenge(verifier string) string {
|
|
hash := sha256.Sum256([]byte(verifier))
|
|
return base64URLEncode(hash[:])
|
|
}
|
|
|
|
func base64URLEncode(data []byte) string {
|
|
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
|
}
|
|
|
|
// EffectiveOAuthConfig returns the effective OAuth configuration.
|
|
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty)
|
|
// Returns error if ClientID or ClientSecret is not configured.
|
|
// Configure via GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET environment variables.
|
|
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
|
|
effective := OAuthConfig{
|
|
ClientID: strings.TrimSpace(cfg.ClientID),
|
|
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
|
|
Scopes: strings.TrimSpace(cfg.Scopes),
|
|
}
|
|
|
|
// Require OAuth credentials to be configured
|
|
if effective.ClientID == "" || effective.ClientSecret == "" {
|
|
return OAuthConfig{}, fmt.Errorf("gemini OAuth credentials not configured, set GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET environment variables")
|
|
}
|
|
|
|
if effective.Scopes == "" {
|
|
// Use different default scopes based on OAuth type
|
|
if oauthType == "ai_studio" {
|
|
effective.Scopes = DefaultAIStudioScopes
|
|
} else {
|
|
// Default to Code Assist scopes
|
|
effective.Scopes = DefaultCodeAssistScopes
|
|
}
|
|
}
|
|
|
|
return effective, nil
|
|
}
|
|
|
|
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
|
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
redirectURI = strings.TrimSpace(redirectURI)
|
|
if redirectURI == "" {
|
|
return "", fmt.Errorf("redirect_uri is required")
|
|
}
|
|
|
|
params := url.Values{}
|
|
params.Set("response_type", "code")
|
|
params.Set("client_id", effectiveCfg.ClientID)
|
|
params.Set("redirect_uri", redirectURI)
|
|
params.Set("scope", effectiveCfg.Scopes)
|
|
params.Set("state", state)
|
|
params.Set("code_challenge", codeChallenge)
|
|
params.Set("code_challenge_method", "S256")
|
|
params.Set("access_type", "offline")
|
|
params.Set("prompt", "consent")
|
|
params.Set("include_granted_scopes", "true")
|
|
if strings.TrimSpace(projectID) != "" {
|
|
params.Set("project_id", strings.TrimSpace(projectID))
|
|
}
|
|
|
|
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
|
|
}
|