fix: avoid temp unsched when refresh token is missing
This commit is contained in:
@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
|||||||
|
|
||||||
refreshToken := account.GetCredential("refresh_token")
|
refreshToken := account.GetCredential("refresh_token")
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
|
accessToken := account.GetCredential("access_token")
|
||||||
|
if accessToken != "" {
|
||||||
|
tokenInfo := &OpenAITokenInfo{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: "",
|
||||||
|
IDToken: account.GetCredential("id_token"),
|
||||||
|
ClientID: account.GetCredential("client_id"),
|
||||||
|
Email: account.GetCredential("email"),
|
||||||
|
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
|
||||||
|
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
|
||||||
|
OrganizationID: account.GetCredential("organization_id"),
|
||||||
|
PlanType: account.GetCredential("plan_type"),
|
||||||
|
}
|
||||||
|
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||||
|
tokenInfo.ExpiresAt = expiresAt.Unix()
|
||||||
|
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
|
||||||
|
}
|
||||||
|
return tokenInfo, nil
|
||||||
|
}
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthClientRefreshStub struct {
|
||||||
|
refreshCalls int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
atomic.AddInt32(&s.refreshCalls, 1)
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
atomic.AddInt32(&s.refreshCalls, 1)
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientRefreshStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 77,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "existing-access-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
"client_id": "client-id-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := svc.RefreshAccountToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Equal(t, "existing-access-token", info.AccessToken)
|
||||||
|
require.Equal(t, "client-id-1", info.ClientID)
|
||||||
|
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
||||||
|
}
|
||||||
@@ -430,6 +430,7 @@ func isNonRetryableRefreshError(err error) bool {
|
|||||||
"unauthorized_client", // 客户端未授权
|
"unauthorized_client", // 客户端未授权
|
||||||
"access_denied", // 访问被拒绝
|
"access_denied", // 访问被拒绝
|
||||||
"missing_project_id", // 缺少 project_id
|
"missing_project_id", // 缺少 project_id
|
||||||
|
"no refresh token available",
|
||||||
}
|
}
|
||||||
for _, needle := range nonRetryable {
|
for _, needle := range nonRetryable {
|
||||||
if strings.Contains(msg, needle) {
|
if strings.Contains(msg, needle) {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
|
|||||||
updateCredentialsCalls int
|
updateCredentialsCalls int
|
||||||
setErrorCalls int
|
setErrorCalls int
|
||||||
clearTempCalls int
|
clearTempCalls int
|
||||||
|
setTempUnschedCalls int
|
||||||
lastAccount *Account
|
lastAccount *Account
|
||||||
updateErr error
|
updateErr error
|
||||||
}
|
}
|
||||||
@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
r.setTempUnschedCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type tokenCacheInvalidatorStub struct {
|
type tokenCacheInvalidatorStub struct {
|
||||||
calls int
|
calls int
|
||||||
err error
|
err error
|
||||||
@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 18,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("no refresh token available"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable")
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state")
|
||||||
|
}
|
||||||
|
|
||||||
// TestIsNonRetryableRefreshError 测试不可重试错误判断
|
// TestIsNonRetryableRefreshError 测试不可重试错误判断
|
||||||
func TestIsNonRetryableRefreshError(t *testing.T) {
|
func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
|||||||
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
||||||
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
||||||
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
||||||
|
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
|
||||||
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
|
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
|
||||||
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
|
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user