Files
sub2api/backend/internal/service/antigravity_oauth_service.go
huangzhenpc d274c8cb14
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
feat: 品牌重命名 Sub2API -> TianShuAPI
- 前端: 所有界面显示、i18n 文本、组件中的品牌名称
- 后端: 服务层、设置默认值、邮件模板、安装向导
- 数据库: 迁移脚本注释
- 保持功能完全一致,仅更改品牌名称

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-04 17:50:29 +08:00

277 lines
7.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityOAuthService implements the Antigravity OAuth flow: building
// authorization URLs, exchanging authorization codes and refreshing tokens.
type AntigravityOAuthService struct {
// sessionStore holds pending OAuth sessions by session ID; GenerateAuthURL
// writes to it and ExchangeCode reads from and deletes from it.
sessionStore *antigravity.SessionStore
// proxyRepo is used to look up a proxy by ID in order to obtain its URL.
proxyRepo ProxyRepository
}
// NewAntigravityOAuthService returns a service that uses proxyRepo for proxy
// lookups and a freshly created antigravity session store.
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
	svc := new(AntigravityOAuthService)
	svc.sessionStore = antigravity.NewSessionStore()
	svc.proxyRepo = proxyRepo
	return svc
}
// AntigravityAuthURLResult is the result of generating an authorization URL.
type AntigravityAuthURLResult struct {
// AuthURL is the authorization URL built from the state and code challenge.
AuthURL string `json:"auth_url"`
// SessionID is the key under which the pending session was stored; it is
// needed again when exchanging the authorization code.
SessionID string `json:"session_id"`
// State is the generated state value that was also saved in the session.
State string `json:"state"`
}
// GenerateAuthURL starts a Google OAuth authorization flow.
//
// It generates a state value, a code verifier and a session ID, stores a
// pending session under that session ID, and returns the authorization URL
// together with the session ID and state.
//
// proxyID is optional. When it is non-nil the proxy is looked up and its URL
// is saved in the session; a failed lookup (or a nil proxy) is not treated as
// an error and simply leaves the session's proxy URL empty.
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
	oauthState, err := antigravity.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("生成 state 失败: %w", err)
	}

	verifier, err := antigravity.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
	}

	id, err := antigravity.GenerateSessionID()
	if err != nil {
		return nil, fmt.Errorf("生成 session_id 失败: %w", err)
	}

	pending := &antigravity.OAuthSession{
		State:        oauthState,
		CodeVerifier: verifier,
	}
	if proxyID != nil {
		if p, lookupErr := s.proxyRepo.GetByID(ctx, *proxyID); lookupErr == nil && p != nil {
			pending.ProxyURL = p.URL()
		}
	}
	pending.CreatedAt = time.Now()
	s.sessionStore.Set(id, pending)

	// The URL carries the challenge derived from the verifier; the verifier
	// itself stays in the stored session.
	challenge := antigravity.GenerateCodeChallenge(verifier)

	result := new(AntigravityAuthURLResult)
	result.AuthURL = antigravity.BuildAuthorizationURL(oauthState, challenge)
	result.SessionID = id
	result.State = oauthState
	return result, nil
}
// AntigravityExchangeCodeInput is the input for exchanging an authorization
// code for tokens.
type AntigravityExchangeCodeInput struct {
// SessionID identifies the pending session to look up in the session store.
SessionID string
// State must be non-blank and equal to the state saved in that session.
State string
// Code is the authorization code to exchange.
Code string
// ProxyID is optional; when it resolves to a proxy, that proxy's URL is used
// instead of the one saved in the session.
ProxyID *int64
}
// AntigravityTokenInfo holds token information returned by a code exchange or
// a token refresh.
type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime copied from the token response.
ExpiresIn int64 `json:"expires_in"`
// ExpiresAt is a Unix timestamp in seconds, computed in this file as
// now + ExpiresIn - 300 (a 5-minute safety window).
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
// Email is omitted from JSON when empty.
Email string `json:"email,omitempty"`
// ProjectID is omitted from JSON when empty.
ProjectID string `json:"project_id,omitempty"`
}
// ExchangeCode completes the OAuth flow by trading an authorization code for
// tokens.
//
// The session named by input.SessionID must exist in the session store, and
// input.State must be non-blank and equal to the state saved in it; otherwise
// an error is returned before the token exchange is attempted. The proxy URL
// saved in the session is used unless input.ProxyID resolves to a proxy, which
// then takes precedence. The session is deleted only after the exchange
// succeeds.
//
// Fetching the user's email and the project ID is best-effort: failures are
// printed as warnings and do not fail the call. When no project ID is
// obtained, one is generated with antigravity.GenerateMockProjectID.
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
	pending, found := s.sessionStore.Get(input.SessionID)
	if !found {
		return nil, fmt.Errorf("session 不存在或已过期")
	}

	stateMatches := strings.TrimSpace(input.State) != "" && input.State == pending.State
	if !stateMatches {
		return nil, fmt.Errorf("state 无效")
	}

	// Decide which proxy URL to use: the one stored with the session, unless
	// the caller supplied a proxy ID that can be resolved.
	effectiveProxy := pending.ProxyURL
	if input.ProxyID != nil {
		if p, lookupErr := s.proxyRepo.GetByID(ctx, *input.ProxyID); lookupErr == nil && p != nil {
			effectiveProxy = p.URL()
		}
	}
	client := antigravity.NewClient(effectiveProxy)

	// Exchange the code for tokens.
	tokens, err := client.ExchangeCode(ctx, input.Code, pending.CodeVerifier)
	if err != nil {
		return nil, fmt.Errorf("token 交换失败: %w", err)
	}

	// The session has served its purpose once the exchange succeeded.
	s.sessionStore.Delete(input.SessionID)

	info := new(AntigravityTokenInfo)
	info.AccessToken = tokens.AccessToken
	info.RefreshToken = tokens.RefreshToken
	info.ExpiresIn = tokens.ExpiresIn
	// Report expiry 300 seconds (5 minutes) early as a safety window.
	info.ExpiresAt = time.Now().Unix() + tokens.ExpiresIn - 300
	info.TokenType = tokens.TokenType

	// Fetch the user's profile to record the email address.
	if profile, profileErr := client.GetUserInfo(ctx, tokens.AccessToken); profileErr == nil {
		info.Email = profile.Email
	} else {
		fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", profileErr)
	}

	// Fetch the project ID (some account types may not have one).
	assist, _, assistErr := client.LoadCodeAssist(ctx, tokens.AccessToken)
	switch {
	case assistErr != nil:
		fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", assistErr)
	case assist != nil && assist.CloudAICompanionProject != "":
		info.ProjectID = assist.CloudAICompanionProject
	}

	// Fallback: use a randomly generated project ID.
	if info.ProjectID == "" {
		info.ProjectID = antigravity.GenerateMockProjectID()
		fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", info.ProjectID)
	}

	return info, nil
}
// RefreshToken exchanges refreshToken for a new access token, retrying
// failures that are not classified as permanent.
//
// At most four attempts are made (the initial call plus three retries) with an
// exponential backoff of 1s, 2s and 4s between them. Errors for which
// isNonRetryableAntigravityOAuthError reports true are returned immediately
// and unwrapped. The wait between attempts is abandoned as soon as ctx is
// done; in that case the returned error wraps ctx.Err() and mentions the last
// refresh error. If every attempt fails, the last error is returned wrapped.
//
// proxyURL is passed to antigravity.NewClient unchanged. On success the
// returned ExpiresAt is a Unix timestamp set 300 seconds before
// now + ExpiresIn, as a safety window.
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
	const (
		maxRetries          = 3
		maxBackoff          = 30 * time.Second
		expirySafetySeconds = 300
	)

	// The client depends only on proxyURL, so build it once for all attempts.
	client := antigravity.NewClient(proxyURL)

	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoff > maxBackoff {
				backoff = maxBackoff
			}

			// Wait with a timer instead of time.Sleep so that a cancelled or
			// expired context ends the retry loop promptly.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("token 刷新已取消: %w (上次错误: %v)", ctx.Err(), lastErr)
			case <-timer.C:
			}
		}

		tokenResp, err := client.RefreshToken(ctx, refreshToken)
		if err == nil {
			expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - expirySafetySeconds
			fmt.Printf("[AntigravityOAuth] Token refreshed: expires_in=%d, expires_at=%d (%s)\n",
				tokenResp.ExpiresIn, expiresAt, time.Unix(expiresAt, 0).Format("2006-01-02 15:04:05"))
			return &AntigravityTokenInfo{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ExpiresIn:    tokenResp.ExpiresIn,
				ExpiresAt:    expiresAt,
				TokenType:    tokenResp.TokenType,
			}, nil
		}

		// Permanent OAuth failures will not be fixed by retrying.
		if isNonRetryableAntigravityOAuthError(err) {
			return nil, err
		}
		lastErr = err
	}

	return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
}
// isNonRetryableAntigravityOAuthError reports whether err's message contains
// one of the OAuth error codes that this file treats as permanent
// (invalid_grant, invalid_client, unauthorized_client, access_denied).
//
// Matching is a plain substring check on err.Error(). A nil err is not a
// failure at all, so it reports false instead of panicking.
func isNonRetryableAntigravityOAuthError(err error) bool {
	if err == nil {
		return false
	}

	msg := err.Error()
	nonRetryable := []string{
		"invalid_grant",
		"invalid_client",
		"unauthorized_client",
		"access_denied",
	}
	for _, needle := range nonRetryable {
		if strings.Contains(msg, needle) {
			return true
		}
	}
	return false
}
// RefreshAccountToken refreshes the OAuth token of an Antigravity account.
//
// It returns an error when the account is not an Antigravity OAuth account or
// has no non-blank refresh_token credential. When the account has a proxy ID
// that resolves to a proxy, that proxy's URL is used for the refresh; a failed
// lookup is not treated as an error and the refresh proceeds without a proxy
// URL.
//
// Values the refresh response does not supply are carried over from the
// account's existing credentials: project_id and email are always preserved
// when present, and the existing refresh_token is kept when the response does
// not include a new one.
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
	if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
		return nil, fmt.Errorf("非 Antigravity OAuth 账户")
	}

	refreshToken := account.GetCredential("refresh_token")
	if strings.TrimSpace(refreshToken) == "" {
		return nil, fmt.Errorf("无可用的 refresh_token")
	}

	var proxyURL string
	if account.ProxyID != nil {
		proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
	if err != nil {
		return nil, err
	}

	// A refresh response is not required to carry a new refresh_token
	// (RFC 6749 section 6). Keep the one that was just used so the result
	// never ends up with an empty refresh token.
	if strings.TrimSpace(tokenInfo.RefreshToken) == "" {
		tokenInfo.RefreshToken = refreshToken
	}

	// Preserve the existing project_id and email.
	if existingProjectID := strings.TrimSpace(account.GetCredential("project_id")); existingProjectID != "" {
		tokenInfo.ProjectID = existingProjectID
	}
	if existingEmail := strings.TrimSpace(account.GetCredential("email")); existingEmail != "" {
		tokenInfo.Email = existingEmail
	}

	return tokenInfo, nil
}
// BuildAccountCredentials converts tokenInfo into an account credentials map.
//
// "access_token" and "expires_at" (the Unix timestamp formatted as a base-10
// string) are always present. "refresh_token", "token_type", "email" and
// "project_id" are included only when the corresponding field is non-empty.
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
	credentials := make(map[string]any)
	credentials["access_token"] = tokenInfo.AccessToken
	credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	// Optional entries: skipped entirely when the value is empty.
	optional := []struct {
		key   string
		value string
	}{
		{"refresh_token", tokenInfo.RefreshToken},
		{"token_type", tokenInfo.TokenType},
		{"email", tokenInfo.Email},
		{"project_id", tokenInfo.ProjectID},
	}
	for _, entry := range optional {
		if entry.value == "" {
			continue
		}
		credentials[entry.key] = entry.value
	}

	return credentials
}
// Stop stops the service by stopping its session store.
func (s *AntigravityOAuthService) Stop() {
s.sessionStore.Stop()
}