256 lines
7.1 KiB
Go
256 lines
7.1 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
|
)
|
|
|
|
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
|
type OpenAIOAuthService struct {
|
|
sessionStore *openai.SessionStore
|
|
proxyRepo ProxyRepository
|
|
oauthClient OpenAIOAuthClient
|
|
}
|
|
|
|
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
|
func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthClient) *OpenAIOAuthService {
|
|
return &OpenAIOAuthService{
|
|
sessionStore: openai.NewSessionStore(),
|
|
proxyRepo: proxyRepo,
|
|
oauthClient: oauthClient,
|
|
}
|
|
}
|
|
|
|
// OpenAIAuthURLResult is what GenerateAuthURL hands back: the authorization
// URL to send the user to, and the ID of the pending session that must be
// passed to ExchangeCode to finish the flow.
type OpenAIAuthURLResult struct {
	AuthURL   string `json:"auth_url"`   // OpenAI authorization URL for the user to open
	SessionID string `json:"session_id"` // key of the stored pending OAuth session
}
|
|
|
|
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
|
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
|
// Generate PKCE values
|
|
state, err := openai.GenerateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate state: %w", err)
|
|
}
|
|
|
|
codeVerifier, err := openai.GenerateCodeVerifier()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
|
}
|
|
|
|
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
|
|
|
// Generate session ID
|
|
sessionID, err := openai.GenerateSessionID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
|
}
|
|
|
|
// Get proxy URL if specified
|
|
var proxyURL string
|
|
if proxyID != nil {
|
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
|
if err == nil && proxy != nil {
|
|
proxyURL = proxy.URL()
|
|
}
|
|
}
|
|
|
|
// Use default redirect URI if not specified
|
|
if redirectURI == "" {
|
|
redirectURI = openai.DefaultRedirectURI
|
|
}
|
|
|
|
// Store session
|
|
session := &openai.OAuthSession{
|
|
State: state,
|
|
CodeVerifier: codeVerifier,
|
|
RedirectURI: redirectURI,
|
|
ProxyURL: proxyURL,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
s.sessionStore.Set(sessionID, session)
|
|
|
|
// Build authorization URL
|
|
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
|
|
|
return &OpenAIAuthURLResult{
|
|
AuthURL: authURL,
|
|
SessionID: sessionID,
|
|
}, nil
|
|
}
|
|
|
|
// OpenAIExchangeCodeInput carries the parameters for ExchangeCode.
type OpenAIExchangeCodeInput struct {
	// SessionID is the session ID returned by GenerateAuthURL, Code is the
	// authorization code received on the OAuth callback, and RedirectURI,
	// when non-empty, overrides the redirect URI saved in the session.
	SessionID, Code, RedirectURI string
	// ProxyID optionally names a proxy whose URL overrides the session's proxy
	// URL when it can be resolved.
	ProxyID *int64
}
|
|
|
|
// OpenAITokenInfo is the token set produced by a code exchange or a refresh,
// together with identity details taken from the ID token when one is present
// and parseable. ExpiresAt is computed as the Unix time (seconds) at which the
// response was processed plus ExpiresIn.
type OpenAITokenInfo struct {
	// Token material returned by the token endpoint.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	IDToken      string `json:"id_token,omitempty"`

	// Lifetime: ExpiresIn is relative, ExpiresAt is an absolute Unix timestamp.
	ExpiresIn int64 `json:"expires_in"`
	ExpiresAt int64 `json:"expires_at"`

	// Identity claims copied from the ID token; empty when unavailable.
	Email            string `json:"email,omitempty"`
	ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
	ChatGPTUserID    string `json:"chatgpt_user_id,omitempty"`
	OrganizationID   string `json:"organization_id,omitempty"`
}
|
|
|
|
// ExchangeCode exchanges authorization code for tokens
|
|
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
|
// Get session
|
|
session, ok := s.sessionStore.Get(input.SessionID)
|
|
if !ok {
|
|
return nil, fmt.Errorf("session not found or expired")
|
|
}
|
|
|
|
// Get proxy URL
|
|
proxyURL := session.ProxyURL
|
|
if input.ProxyID != nil {
|
|
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
|
if err == nil && proxy != nil {
|
|
proxyURL = proxy.URL()
|
|
}
|
|
}
|
|
|
|
// Use redirect URI from session or input
|
|
redirectURI := session.RedirectURI
|
|
if input.RedirectURI != "" {
|
|
redirectURI = input.RedirectURI
|
|
}
|
|
|
|
// Exchange code for token
|
|
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
|
}
|
|
|
|
// Parse ID token to get user info
|
|
var userInfo *openai.UserInfo
|
|
if tokenResp.IDToken != "" {
|
|
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
|
if err == nil {
|
|
userInfo = claims.GetUserInfo()
|
|
}
|
|
}
|
|
|
|
// Delete session after successful exchange
|
|
s.sessionStore.Delete(input.SessionID)
|
|
|
|
tokenInfo := &OpenAITokenInfo{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
IDToken: tokenResp.IDToken,
|
|
ExpiresIn: int64(tokenResp.ExpiresIn),
|
|
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
|
}
|
|
|
|
if userInfo != nil {
|
|
tokenInfo.Email = userInfo.Email
|
|
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
|
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
|
tokenInfo.OrganizationID = userInfo.OrganizationID
|
|
}
|
|
|
|
return tokenInfo, nil
|
|
}
|
|
|
|
// RefreshToken refreshes an OpenAI OAuth token
|
|
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
|
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse ID token to get user info
|
|
var userInfo *openai.UserInfo
|
|
if tokenResp.IDToken != "" {
|
|
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
|
if err == nil {
|
|
userInfo = claims.GetUserInfo()
|
|
}
|
|
}
|
|
|
|
tokenInfo := &OpenAITokenInfo{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
IDToken: tokenResp.IDToken,
|
|
ExpiresIn: int64(tokenResp.ExpiresIn),
|
|
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
|
}
|
|
|
|
if userInfo != nil {
|
|
tokenInfo.Email = userInfo.Email
|
|
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
|
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
|
tokenInfo.OrganizationID = userInfo.OrganizationID
|
|
}
|
|
|
|
return tokenInfo, nil
|
|
}
|
|
|
|
// RefreshAccountToken refreshes token for an OpenAI account
|
|
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
|
if !account.IsOpenAI() {
|
|
return nil, fmt.Errorf("account is not an OpenAI account")
|
|
}
|
|
|
|
refreshToken := account.GetOpenAIRefreshToken()
|
|
if refreshToken == "" {
|
|
return nil, fmt.Errorf("no refresh token available")
|
|
}
|
|
|
|
var proxyURL string
|
|
if account.ProxyID != nil {
|
|
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
|
if err == nil && proxy != nil {
|
|
proxyURL = proxy.URL()
|
|
}
|
|
}
|
|
|
|
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
|
}
|
|
|
|
// BuildAccountCredentials converts tokenInfo into the credentials map stored
// on an account.
//
// access_token, refresh_token and expires_at are always present; expires_at is
// ExpiresAt rendered with time.RFC3339 (time.Unix yields local time, so the
// string carries the local UTC offset). id_token, email, chatgpt_account_id,
// chatgpt_user_id and organization_id are included only when non-empty.
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
	creds := map[string]any{
		"access_token":  tokenInfo.AccessToken,
		"refresh_token": tokenInfo.RefreshToken,
		"expires_at":    time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339),
	}

	optional := []struct {
		key   string
		value string
	}{
		{"id_token", tokenInfo.IDToken},
		{"email", tokenInfo.Email},
		{"chatgpt_account_id", tokenInfo.ChatGPTAccountID},
		{"chatgpt_user_id", tokenInfo.ChatGPTUserID},
		{"organization_id", tokenInfo.OrganizationID},
	}
	for _, field := range optional {
		if field.value != "" {
			creds[field.key] = field.value
		}
	}

	return creds
}
|
|
|
|
// Stop stops the session store cleanup goroutine
|
|
func (s *OpenAIOAuthService) Stop() {
|
|
s.sessionStore.Stop()
|
|
}
|