feat: 优化 OAuth 账号导入流程
This commit is contained in:
@@ -52,3 +52,47 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess
|
||||
require.Equal(t, "client-id-1", info.ClientID)
|
||||
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
||||
}
|
||||
|
||||
func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) {
|
||||
refresher := NewOpenAITokenRefresher(nil, nil)
|
||||
expiresAt := time.Now().Add(time.Minute).UTC().Format(time.RFC3339)
|
||||
|
||||
withoutRT := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
require.False(t, refresher.NeedsRefresh(withoutRT, 5*time.Minute))
|
||||
|
||||
withRT := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
require.True(t, refresher.NeedsRefresh(withRT, 5*time.Minute))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NoRefreshTokenExpiredAccessTokenReturnsError(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "expired-access-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, token)
|
||||
require.Contains(t, err.Error(), "refresh_token is missing")
|
||||
}
|
||||
|
||||
@@ -152,6 +152,12 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||
if expiresAt != nil && !time.Now().Before(*expiresAt) {
|
||||
return "", errors.New("openai access_token expired and refresh_token is missing")
|
||||
}
|
||||
needsRefresh = false
|
||||
}
|
||||
refreshFailed := false
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
|
||||
@@ -424,8 +424,9 @@ func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -650,8 +651,9 @@ func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -819,8 +821,9 @@ func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -848,8 +851,9 @@ func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -875,8 +879,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T)
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
@@ -911,8 +916,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -95,6 +96,9 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return account.IsRateLimited()
|
||||
|
||||
Reference in New Issue
Block a user