package service import ( "context" "errors" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/stretchr/testify/require" ) type openaiOAuthClientStateStub struct { exchangeCalled int32 } func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { atomic.AddInt32(&s.exchangeCalled, 1) return &openai.TokenResponse{ AccessToken: "at", RefreshToken: "rt", ExpiresIn: 3600, }, nil } func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { return nil, errors.New("not implemented") } func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { return s.RefreshToken(ctx, refreshToken, proxyURL) } func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { client := &openaiOAuthClientStateStub{} svc := NewOpenAIOAuthService(nil, client) defer svc.Stop() svc.sessionStore.Set("sid", &openai.OAuthSession{ State: "expected-state", CodeVerifier: "verifier", RedirectURI: openai.DefaultRedirectURI, CreatedAt: time.Now(), }) _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ SessionID: "sid", Code: "auth-code", }) require.Error(t, err) require.Contains(t, err.Error(), "oauth state is required") require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) } func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { client := &openaiOAuthClientStateStub{} svc := NewOpenAIOAuthService(nil, client) defer svc.Stop() svc.sessionStore.Set("sid", &openai.OAuthSession{ State: "expected-state", CodeVerifier: "verifier", RedirectURI: openai.DefaultRedirectURI, CreatedAt: time.Now(), }) _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ SessionID: "sid", Code: "auth-code", State: "wrong-state", }) require.Error(t, err) require.Contains(t, err.Error(), "invalid oauth state") require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) } func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { client := &openaiOAuthClientStateStub{} svc := NewOpenAIOAuthService(nil, client) defer svc.Stop() svc.sessionStore.Set("sid", &openai.OAuthSession{ State: "expected-state", CodeVerifier: "verifier", RedirectURI: openai.DefaultRedirectURI, CreatedAt: time.Now(), }) info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ SessionID: "sid", Code: "auth-code", State: "expected-state", }) require.NoError(t, err) require.NotNil(t, info) require.Equal(t, "at", info.AccessToken) require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) _, ok := svc.sessionStore.Get("sid") require.False(t, ok) }