feat(sync): full code sync from release
This commit is contained in:
@@ -5,8 +5,12 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +20,13 @@ import (
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
|
||||
|
||||
type soraSessionChunk struct {
|
||||
index int
|
||||
value string
|
||||
}
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
|
||||
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
clientID := strings.TrimSpace(session.ClientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
ClientID: clientID,
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
|
||||
tokenInfo.ClientID = trimmed
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
@@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
ClientID: openai.SoraClientID,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeSoraSessionTokenInput(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
|
||||
if len(matches) == 0 {
|
||||
return sanitizeSessionToken(trimmed)
|
||||
}
|
||||
|
||||
chunkMatches := make([]soraSessionChunk, 0, len(matches))
|
||||
singleValues := make([]string, 0, len(matches))
|
||||
|
||||
for _, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := sanitizeSessionToken(match[2])
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(match[1]) == "" {
|
||||
singleValues = append(singleValues, value)
|
||||
continue
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
|
||||
if err != nil || idx < 0 {
|
||||
continue
|
||||
}
|
||||
chunkMatches = append(chunkMatches, soraSessionChunk{
|
||||
index: idx,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
|
||||
return merged
|
||||
}
|
||||
|
||||
if len(singleValues) > 0 {
|
||||
return singleValues[len(singleValues)-1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
byIndex := make(map[int]string, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
byIndex[chunk.index] = chunk.value
|
||||
}
|
||||
|
||||
if _, ok := byIndex[0]; !ok {
|
||||
return ""
|
||||
}
|
||||
if requireComplete {
|
||||
for idx := 0; idx <= requiredMaxIndex; idx++ {
|
||||
if _, ok := byIndex[idx]; !ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orderedIndexes := make([]int, 0, len(byIndex))
|
||||
for idx := range byIndex {
|
||||
orderedIndexes = append(orderedIndexes, idx)
|
||||
}
|
||||
sort.Ints(orderedIndexes)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, idx := range orderedIndexes {
|
||||
if _, err := builder.WriteString(byIndex[idx]); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return sanitizeSessionToken(builder.String())
|
||||
}
|
||||
|
||||
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requiredMaxIndex := 0
|
||||
for _, chunk := range chunks {
|
||||
if chunk.index > requiredMaxIndex {
|
||||
requiredMaxIndex = chunk.index
|
||||
}
|
||||
}
|
||||
|
||||
groupStarts := make([]int, 0, len(chunks))
|
||||
for idx, chunk := range chunks {
|
||||
if chunk.index == 0 {
|
||||
groupStarts = append(groupStarts, idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupStarts) == 0 {
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
for i := len(groupStarts) - 1; i >= 0; i-- {
|
||||
start := groupStarts[i]
|
||||
end := len(chunks)
|
||||
if i+1 < len(groupStarts) {
|
||||
end = groupStarts[i+1]
|
||||
}
|
||||
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
|
||||
return merged
|
||||
}
|
||||
}
|
||||
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
func sanitizeSessionToken(raw string) string {
|
||||
token := strings.TrimSpace(raw)
|
||||
token = strings.Trim(token, "\"'`")
|
||||
token = strings.TrimSuffix(token, ";")
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
@@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
|
||||
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
@@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case PlatformSora:
|
||||
return openai.OAuthPlatformSora
|
||||
default:
|
||||
return openai.OAuthPlatformOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user