diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 0a1266d9..0f004b01 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { + accessToken := account.GetCredential("access_token") + if accessToken != "" { + tokenInfo := &OpenAITokenInfo{ + AccessToken: accessToken, + RefreshToken: "", + IDToken: account.GetCredential("id_token"), + ClientID: account.GetCredential("client_id"), + Email: account.GetCredential("email"), + ChatGPTAccountID: account.GetCredential("chatgpt_account_id"), + ChatGPTUserID: account.GetCredential("chatgpt_user_id"), + OrganizationID: account.GetCredential("organization_id"), + PlanType: account.GetCredential("plan_type"), + } + if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { + tokenInfo.ExpiresAt = expiresAt.Unix() + tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds()) + } + return tokenInfo, nil + } return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } diff --git a/backend/internal/service/openai_oauth_service_refresh_test.go b/backend/internal/service/openai_oauth_service_refresh_test.go new file mode 100644 index 00000000..a31eb8cb --- /dev/null +++ b/backend/internal/service/openai_oauth_service_refresh_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientRefreshStub struct { + refreshCalls int32 +} + +func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) { + client := &openaiOAuthClientRefreshStub{} + svc := NewOpenAIOAuthService(nil, client) + + expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339) + account := &Account{ + ID: 77, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "existing-access-token", + "expires_at": expiresAt, + "client_id": "client-id-1", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "existing-access-token", info.AccessToken) + require.Equal(t, "client-id-1", info.ClientID) + require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh") +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index eb3e5592..f521c972 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -430,6 +430,7 @@ func isNonRetryableRefreshError(err error) bool { "unauthorized_client", // 客户端未授权 "access_denied", // 访问被拒绝 "missing_project_id", // 缺少 project_id + "no refresh token available", } for _, needle := range nonRetryable { if strings.Contains(msg, needle) { diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 60ba4a96..2179a85e 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct { updateCredentialsCalls int setErrorCalls int clearTempCalls int + setTempUnschedCalls int lastAccount *Account updateErr error } @@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id return nil } +func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.setTempUnschedCalls++ + return nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t } } +func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + account := &Account{ + ID: 18, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("no refresh token available"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable") + require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state") +} + // TestIsNonRetryableRefreshError 测试不可重试错误判断 func TestIsNonRetryableRefreshError(t *testing.T) { tests := []struct { @@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) { {name: "invalid_client", err: errors.New("invalid_client"), expected: true}, {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, {name: "access_denied", err: errors.New("access_denied"), expected: true}, + {name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true}, {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true}, {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true}, }