Files
sub2api/backend/internal/service/antigravity_oauth_service.go

268 lines
7.1 KiB
Go

package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityOAuthService groups the dependencies used by the Antigravity
// OAuth flows in this file: a store for pending authorization sessions and a
// repository used to look proxies up by ID.
type AntigravityOAuthService struct {
// sessionStore keeps the OAuth session written by GenerateAuthURL until
// ExchangeCode reads and deletes it.
sessionStore *antigravity.SessionStore
// proxyRepo resolves a proxy ID to a proxy whose URL() is passed to the
// antigravity client.
proxyRepo ProxyRepository
}
// NewAntigravityOAuthService returns a service that uses a newly created
// session store together with the supplied proxy repository.
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
	svc := new(AntigravityOAuthService)
	svc.sessionStore = antigravity.NewSessionStore()
	svc.proxyRepo = proxyRepo
	return svc
}
// AntigravityAuthURLResult is the result of generating an authorization URL.
// GenerateAuthURL returns it on success.
type AntigravityAuthURLResult struct {
// AuthURL is the authorization URL built from the state and the code challenge.
AuthURL string `json:"auth_url"`
// SessionID is the key under which the pending OAuth session was stored.
SessionID string `json:"session_id"`
// State is the generated state value; ExchangeCode rejects a request whose
// state differs from the one stored in the session.
State string `json:"state"`
}
// GenerateAuthURL starts a Google OAuth authorization flow and returns the
// URL the user should visit.
//
// It generates a state value, a code verifier and a session ID, stores the
// state, verifier and (optional) proxy URL in the session store under the
// session ID, and builds the authorization URL from the state and the code
// challenge derived from the verifier.
//
// proxyID is optional. When it is non-nil the proxy is looked up; if the
// lookup fails or yields nil, the session is stored with an empty proxy URL
// and no error is reported.
//
// An error is returned only when generating the state, the code verifier or
// the session ID fails.
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
	state, err := antigravity.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("生成 state 失败: %w", err)
	}

	verifier, err := antigravity.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
	}

	sid, err := antigravity.GenerateSessionID()
	if err != nil {
		return nil, fmt.Errorf("生成 session_id 失败: %w", err)
	}

	// Best-effort proxy resolution: a failed lookup leaves the URL empty.
	proxyURL := ""
	if proxyID != nil {
		if p, lookupErr := s.proxyRepo.GetByID(ctx, *proxyID); lookupErr == nil && p != nil {
			proxyURL = p.URL()
		}
	}

	// Persist what ExchangeCode will need later, keyed by the session ID.
	s.sessionStore.Set(sid, &antigravity.OAuthSession{
		State:        state,
		CodeVerifier: verifier,
		ProxyURL:     proxyURL,
		CreatedAt:    time.Now(),
	})

	challenge := antigravity.GenerateCodeChallenge(verifier)
	result := &AntigravityAuthURLResult{
		AuthURL:   antigravity.BuildAuthorizationURL(state, challenge),
		SessionID: sid,
		State:     state,
	}
	return result, nil
}
// AntigravityExchangeCodeInput is the input to ExchangeCode.
type AntigravityExchangeCodeInput struct {
// SessionID identifies the stored OAuth session to look up.
SessionID string
// State must be non-blank and equal to the state saved in that session.
State string
// Code is the authorization code to exchange for tokens.
Code string
// ProxyID is optional; when it is non-nil and the lookup succeeds, that
// proxy's URL replaces the proxy URL saved in the session.
ProxyID *int64
}
// AntigravityTokenInfo holds the token data returned by ExchangeCode,
// RefreshToken and RefreshAccountToken.
type AntigravityTokenInfo struct {
// AccessToken is copied from the token response.
AccessToken string `json:"access_token"`
// RefreshToken is copied from the token response.
RefreshToken string `json:"refresh_token"`
// ExpiresIn is copied from the token response; the code in this file
// treats it as a number of seconds.
ExpiresIn int64 `json:"expires_in"`
// ExpiresAt is a Unix timestamp in seconds: the time the token was obtained
// plus ExpiresIn, minus a 300-second safety window.
ExpiresAt int64 `json:"expires_at"`
// TokenType is copied from the token response.
TokenType string `json:"token_type"`
// Email is set by ExchangeCode when the user-info lookup succeeds, or
// restored from the account's credentials by RefreshAccountToken.
Email string `json:"email,omitempty"`
// ProjectID is set by ExchangeCode when LoadCodeAssist returns a non-empty
// project, or restored from the account's credentials by RefreshAccountToken.
ProjectID string `json:"project_id,omitempty"`
}
// ExchangeCode exchanges an authorization code for tokens.
//
// The session named by input.SessionID must exist and input.State must be
// non-blank and equal to the state stored in it; otherwise an error is
// returned before any request is made. The proxy URL saved in the session is
// used unless input.ProxyID is set and resolves successfully.
//
// The session is deleted only after the token exchange succeeds, so a failed
// exchange leaves it in place. After the exchange, the user's email and the
// project ID are fetched on a best-effort basis: a failure there is printed
// as a warning and the corresponding field is left empty.
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
	session, found := s.sessionStore.Get(input.SessionID)
	if !found {
		return nil, fmt.Errorf("session 不存在或已过期")
	}

	if stateOK := strings.TrimSpace(input.State) != "" && input.State == session.State; !stateOK {
		return nil, fmt.Errorf("state 无效")
	}

	// Decide which proxy URL to use: the session's, unless the caller's
	// proxy ID resolves to a proxy.
	proxyURL := session.ProxyURL
	if input.ProxyID != nil {
		if p, lookupErr := s.proxyRepo.GetByID(ctx, *input.ProxyID); lookupErr == nil && p != nil {
			proxyURL = p.URL()
		}
	}

	client := antigravity.NewClient(proxyURL)

	// Exchange the authorization code for tokens.
	tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
	if err != nil {
		return nil, fmt.Errorf("Token 交换失败: %w", err)
	}

	// The session has served its purpose; remove it.
	s.sessionStore.Delete(input.SessionID)

	// Report expiry five minutes early as a safety window.
	const safetyWindowSeconds = 300
	info := &AntigravityTokenInfo{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		ExpiresIn:    tokenResp.ExpiresIn,
		ExpiresAt:    time.Now().Unix() + tokenResp.ExpiresIn - safetyWindowSeconds,
		TokenType:    tokenResp.TokenType,
	}

	// Fetch the user's email; failure is only a warning.
	if user, userErr := client.GetUserInfo(ctx, tokenResp.AccessToken); userErr != nil {
		fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", userErr)
	} else {
		info.Email = user.Email
	}

	// Fetch the project ID; failure is only a warning.
	project, projErr := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
	switch {
	case projErr != nil:
		fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", projErr)
	case project != nil && project.CloudAICompanionProject != "":
		info.ProjectID = project.CloudAICompanionProject
	}

	return info, nil
}
// RefreshToken obtains a new access token from a refresh token, retrying
// failures that are not classified as permanent.
//
// At most four attempts are made (one initial try plus three retries). Before
// each retry the call waits 1s, 2s and then 4s. The wait observes ctx: if ctx
// is cancelled or its deadline passes while waiting, the call returns
// immediately with an error that wraps ctx.Err() and mentions the last
// refresh error.
//
// An error for which isNonRetryableAntigravityOAuthError reports true is
// returned as-is without further attempts. When every attempt fails, the last
// error is returned wrapped.
//
// On success the returned ExpiresAt is the current Unix time plus the
// response's ExpiresIn, minus a 300-second safety window.
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
	const (
		maxRetries          = 3
		maxBackoff          = 30 * time.Second
		safetyWindowSeconds = 300
	)

	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff: 1s, 2s, 4s, ... capped at maxBackoff.
			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoff > maxBackoff {
				backoff = maxBackoff
			}

			// Wait on a timer instead of time.Sleep so that a cancelled or
			// expired context ends the wait right away.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("Token 刷新已取消: %w (上次错误: %v)", ctx.Err(), lastErr)
			case <-timer.C:
			}
		}

		client := antigravity.NewClient(proxyURL)
		tokenResp, err := client.RefreshToken(ctx, refreshToken)
		if err == nil {
			return &AntigravityTokenInfo{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ExpiresIn:    tokenResp.ExpiresIn,
				ExpiresAt:    time.Now().Unix() + tokenResp.ExpiresIn - safetyWindowSeconds,
				TokenType:    tokenResp.TokenType,
			}, nil
		}

		// Permanent failures are surfaced unchanged; retrying cannot help.
		if isNonRetryableAntigravityOAuthError(err) {
			return nil, err
		}
		lastErr = err
	}

	return nil, fmt.Errorf("Token 刷新失败 (重试后): %w", lastErr)
}
// isNonRetryableAntigravityOAuthError reports whether err should stop the
// refresh retry loop.
//
// The decision is made by substring match on err.Error(): the error is
// treated as permanent when its message contains one of the OAuth error codes
// "invalid_grant", "invalid_client", "unauthorized_client" or "access_denied".
// A nil error is never permanent and yields false.
func isNonRetryableAntigravityOAuthError(err error) bool {
	// Guard against a nil error, which would otherwise panic on err.Error().
	if err == nil {
		return false
	}

	msg := err.Error()
	nonRetryable := []string{
		"invalid_grant",
		"invalid_client",
		"unauthorized_client",
		"access_denied",
	}
	for _, needle := range nonRetryable {
		if strings.Contains(msg, needle) {
			return true
		}
	}
	return false
}
// RefreshAccountToken refreshes the token of an Antigravity OAuth account.
//
// The account must have platform PlatformAntigravity and type
// AccountTypeOAuth, and its "refresh_token" credential must be non-blank;
// otherwise an error is returned. If the account references a proxy, that
// proxy is looked up on a best-effort basis and an empty proxy URL is used
// when the lookup fails.
//
// The refresh itself is delegated to RefreshToken, whose error is returned
// unchanged. On success the account's existing non-blank "project_id" and
// "email" credentials are copied into the result.
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
	if isAntigravityOAuth := account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth; !isAntigravityOAuth {
		return nil, fmt.Errorf("非 Antigravity OAuth 账户")
	}

	storedRefresh := account.GetCredential("refresh_token")
	if strings.TrimSpace(storedRefresh) == "" {
		return nil, fmt.Errorf("无可用的 refresh_token")
	}

	proxyURL := ""
	if account.ProxyID != nil {
		if p, lookupErr := s.proxyRepo.GetByID(ctx, *account.ProxyID); lookupErr == nil && p != nil {
			proxyURL = p.URL()
		}
	}

	info, err := s.RefreshToken(ctx, storedRefresh, proxyURL)
	if err != nil {
		return nil, err
	}

	// Carry over the project ID and email already stored on the account.
	if projectID := strings.TrimSpace(account.GetCredential("project_id")); projectID != "" {
		info.ProjectID = projectID
	}
	if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
		info.Email = email
	}

	return info, nil
}
// BuildAccountCredentials converts token info into an account credentials map.
//
// "access_token" and "expires_at" are always present; "expires_at" is the
// ExpiresAt value formatted as a base-10 string. "refresh_token",
// "token_type", "email" and "project_id" are included only when the
// corresponding field is non-empty. All values are strings.
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
	creds := make(map[string]any)
	creds["access_token"] = tokenInfo.AccessToken
	creds["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	// Optional entries are written only when they carry a value.
	optional := []struct{ key, value string }{
		{"refresh_token", tokenInfo.RefreshToken},
		{"token_type", tokenInfo.TokenType},
		{"email", tokenInfo.Email},
		{"project_id", tokenInfo.ProjectID},
	}
	for _, field := range optional {
		if field.value != "" {
			creds[field.key] = field.value
		}
	}

	return creds
}
// Stop stops the service by calling Stop on its session store.
func (s *AntigravityOAuthService) Stop() {
	store := s.sessionStore
	store.Stop()
}