394 lines
13 KiB
Go
394 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/subtle"
|
||
"log/slog"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||
)
|
||
|
||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||
type OpenAIOAuthService struct {
|
||
sessionStore *openai.SessionStore
|
||
proxyRepo ProxyRepository
|
||
oauthClient OpenAIOAuthClient
|
||
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-api(ImpersonateChrome)
|
||
}
|
||
|
||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||
func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthClient) *OpenAIOAuthService {
|
||
return &OpenAIOAuthService{
|
||
sessionStore: openai.NewSessionStore(),
|
||
proxyRepo: proxyRepo,
|
||
oauthClient: oauthClient,
|
||
}
|
||
}
|
||
|
||
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
|
||
// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。
|
||
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
|
||
s.privacyClientFactory = factory
|
||
}
|
||
|
||
// OpenAIAuthURLResult is what GenerateAuthURL hands back to the caller.
type OpenAIAuthURLResult struct {
	// AuthURL is the authorization URL the user should be sent to.
	AuthURL string `json:"auth_url"`
	// SessionID identifies the stored OAuth session; it is required again
	// when exchanging the authorization code.
	SessionID string `json:"session_id"`
}
|
||
|
||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
|
||
// Generate PKCE values
|
||
state, err := openai.GenerateState()
|
||
if err != nil {
|
||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
|
||
}
|
||
|
||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||
if err != nil {
|
||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
|
||
}
|
||
|
||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||
|
||
// Generate session ID
|
||
sessionID, err := openai.GenerateSessionID()
|
||
if err != nil {
|
||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
|
||
}
|
||
|
||
// Get proxy URL if specified
|
||
var proxyURL string
|
||
if proxyID != nil {
|
||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||
if err != nil {
|
||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||
}
|
||
if proxy != nil {
|
||
proxyURL = proxy.URL()
|
||
}
|
||
}
|
||
|
||
// Use default redirect URI if not specified
|
||
if redirectURI == "" {
|
||
redirectURI = openai.DefaultRedirectURI
|
||
}
|
||
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
|
||
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
|
||
|
||
// Store session
|
||
session := &openai.OAuthSession{
|
||
State: state,
|
||
CodeVerifier: codeVerifier,
|
||
ClientID: clientID,
|
||
RedirectURI: redirectURI,
|
||
ProxyURL: proxyURL,
|
||
CreatedAt: time.Now(),
|
||
}
|
||
s.sessionStore.Set(sessionID, session)
|
||
|
||
// Build authorization URL
|
||
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
|
||
|
||
return &OpenAIAuthURLResult{
|
||
AuthURL: authURL,
|
||
SessionID: sessionID,
|
||
}, nil
|
||
}
|
||
|
||
// OpenAIExchangeCodeInput carries the parameters for ExchangeCode.
type OpenAIExchangeCodeInput struct {
	// SessionID is the ID returned by GenerateAuthURL.
	SessionID string
	// Code is the authorization code to exchange.
	Code string
	// State must match the state stored on the session.
	State string
	// RedirectURI, when non-empty, overrides the redirect URI stored on the session.
	RedirectURI string
	// ProxyID, when non-nil, overrides the proxy URL stored on the session.
	ProxyID *int64
}
|
||
|
||
// OpenAITokenInfo is the token bundle produced by the exchange and refresh
// paths, together with the account metadata gathered along the way.
type OpenAITokenInfo struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	IDToken      string `json:"id_token,omitempty"`
	// ExpiresIn is the token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// ExpiresAt is the expiry as a Unix timestamp in seconds.
	ExpiresAt int64 `json:"expires_at"`
	// ClientID is the OAuth client ID associated with the token.
	ClientID string `json:"client_id,omitempty"`

	// Account metadata copied from the ID token's user info when available.
	Email            string `json:"email,omitempty"`
	ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
	ChatGPTUserID    string `json:"chatgpt_user_id,omitempty"`
	OrganizationID   string `json:"organization_id,omitempty"`
	PlanType         string `json:"plan_type,omitempty"`

	// Filled in on a best-effort basis by enrichTokenInfo.
	SubscriptionExpiresAt string `json:"subscription_expires_at,omitempty"`
	PrivacyMode           string `json:"privacy_mode,omitempty"`
}
|
||
|
||
// ExchangeCode exchanges authorization code for tokens
|
||
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
||
// Get session
|
||
session, ok := s.sessionStore.Get(input.SessionID)
|
||
if !ok {
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||
}
|
||
if input.State == "" {
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
|
||
}
|
||
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
|
||
}
|
||
|
||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||
proxyURL := session.ProxyURL
|
||
if input.ProxyID != nil {
|
||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||
if err != nil {
|
||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||
}
|
||
if proxy != nil {
|
||
proxyURL = proxy.URL()
|
||
}
|
||
}
|
||
|
||
// Use redirect URI from session or input
|
||
redirectURI := session.RedirectURI
|
||
if input.RedirectURI != "" {
|
||
redirectURI = input.RedirectURI
|
||
}
|
||
clientID := strings.TrimSpace(session.ClientID)
|
||
if clientID == "" {
|
||
clientID = openai.ClientID
|
||
}
|
||
|
||
// Exchange code for token
|
||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Parse ID token to get user info
|
||
var userInfo *openai.UserInfo
|
||
if tokenResp.IDToken != "" {
|
||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||
if parseErr != nil {
|
||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||
} else {
|
||
userInfo = claims.GetUserInfo()
|
||
}
|
||
}
|
||
|
||
// Delete session after successful exchange
|
||
s.sessionStore.Delete(input.SessionID)
|
||
|
||
tokenInfo := &OpenAITokenInfo{
|
||
AccessToken: tokenResp.AccessToken,
|
||
RefreshToken: tokenResp.RefreshToken,
|
||
IDToken: tokenResp.IDToken,
|
||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||
ClientID: clientID,
|
||
}
|
||
|
||
if userInfo != nil {
|
||
tokenInfo.Email = userInfo.Email
|
||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||
tokenInfo.PlanType = userInfo.PlanType
|
||
}
|
||
|
||
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
|
||
|
||
return tokenInfo, nil
|
||
}
|
||
|
||
// RefreshToken refreshes an OpenAI OAuth token
|
||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||
}
|
||
|
||
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
|
||
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Parse ID token to get user info
|
||
var userInfo *openai.UserInfo
|
||
if tokenResp.IDToken != "" {
|
||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||
if parseErr != nil {
|
||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||
} else {
|
||
userInfo = claims.GetUserInfo()
|
||
}
|
||
}
|
||
|
||
tokenInfo := &OpenAITokenInfo{
|
||
AccessToken: tokenResp.AccessToken,
|
||
RefreshToken: tokenResp.RefreshToken,
|
||
IDToken: tokenResp.IDToken,
|
||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||
}
|
||
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
|
||
tokenInfo.ClientID = trimmed
|
||
}
|
||
|
||
if userInfo != nil {
|
||
tokenInfo.Email = userInfo.Email
|
||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||
tokenInfo.PlanType = userInfo.PlanType
|
||
}
|
||
|
||
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
|
||
|
||
return tokenInfo, nil
|
||
}
|
||
|
||
// enrichTokenInfo 通过 ChatGPT backend-api 补全 tokenInfo 并设置隐私(best-effort)。
|
||
// 从 accounts/check 获取最新 plan_type、subscription_expires_at、email,
|
||
// 然后尝试关闭训练数据共享。适用于所有获取/刷新 token 的路径。
|
||
func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *OpenAITokenInfo, proxyURL string) {
|
||
if tokenInfo.AccessToken == "" || s.privacyClientFactory == nil {
|
||
return
|
||
}
|
||
|
||
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号
|
||
orgID := tokenInfo.OrganizationID
|
||
if orgID == "" {
|
||
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
|
||
orgID = atClaims.OpenAIAuth.POID
|
||
}
|
||
}
|
||
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
|
||
if info.PlanType != "" {
|
||
tokenInfo.PlanType = info.PlanType
|
||
}
|
||
if info.SubscriptionExpiresAt != "" {
|
||
tokenInfo.SubscriptionExpiresAt = info.SubscriptionExpiresAt
|
||
}
|
||
if tokenInfo.Email == "" && info.Email != "" {
|
||
tokenInfo.Email = info.Email
|
||
}
|
||
}
|
||
|
||
// 尝试设置隐私(关闭训练数据共享),best-effort
|
||
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
|
||
}
|
||
|
||
// RefreshAccountToken refreshes token for an OpenAI OAuth account
|
||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||
if account.Platform != PlatformOpenAI {
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||
}
|
||
if account.Type != AccountTypeOAuth {
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||
}
|
||
|
||
refreshToken := account.GetCredential("refresh_token")
|
||
if refreshToken == "" {
|
||
accessToken := account.GetCredential("access_token")
|
||
if accessToken != "" {
|
||
tokenInfo := &OpenAITokenInfo{
|
||
AccessToken: accessToken,
|
||
RefreshToken: "",
|
||
IDToken: account.GetCredential("id_token"),
|
||
ClientID: account.GetCredential("client_id"),
|
||
Email: account.GetCredential("email"),
|
||
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
|
||
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
|
||
OrganizationID: account.GetCredential("organization_id"),
|
||
PlanType: account.GetCredential("plan_type"),
|
||
}
|
||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||
tokenInfo.ExpiresAt = expiresAt.Unix()
|
||
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
|
||
}
|
||
return tokenInfo, nil
|
||
}
|
||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||
}
|
||
|
||
var proxyURL string
|
||
if account.ProxyID != nil {
|
||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||
if err == nil && proxy != nil {
|
||
proxyURL = proxy.URL()
|
||
}
|
||
}
|
||
|
||
clientID := account.GetCredential("client_id")
|
||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||
}
|
||
|
||
// BuildAccountCredentials builds credentials map from token info
|
||
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
|
||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||
|
||
creds := map[string]any{
|
||
"access_token": tokenInfo.AccessToken,
|
||
"expires_at": expiresAt,
|
||
}
|
||
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
|
||
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
|
||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||
}
|
||
|
||
if tokenInfo.IDToken != "" {
|
||
creds["id_token"] = tokenInfo.IDToken
|
||
}
|
||
if tokenInfo.Email != "" {
|
||
creds["email"] = tokenInfo.Email
|
||
}
|
||
if tokenInfo.ChatGPTAccountID != "" {
|
||
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
|
||
}
|
||
if tokenInfo.ChatGPTUserID != "" {
|
||
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
|
||
}
|
||
if tokenInfo.OrganizationID != "" {
|
||
creds["organization_id"] = tokenInfo.OrganizationID
|
||
}
|
||
if tokenInfo.PlanType != "" {
|
||
creds["plan_type"] = tokenInfo.PlanType
|
||
}
|
||
if tokenInfo.SubscriptionExpiresAt != "" {
|
||
creds["subscription_expires_at"] = tokenInfo.SubscriptionExpiresAt
|
||
}
|
||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||
}
|
||
|
||
return creds
|
||
}
|
||
|
||
// Stop stops the session store cleanup goroutine
|
||
func (s *OpenAIOAuthService) Stop() {
|
||
s.sessionStore.Stop()
|
||
}
|
||
|
||
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||
if proxyID == nil {
|
||
return "", nil
|
||
}
|
||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||
if err != nil {
|
||
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||
}
|
||
if proxy == nil {
|
||
return "", nil
|
||
}
|
||
return proxy.URL(), nil
|
||
}
|
||
|
||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||
return openai.OAuthPlatformOpenAI
|
||
}
|