package repository

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

// OpenAIOAuthServiceSuite exercises openaiOAuthService against a local HTTP
// test server that stands in for the token endpoint.
type OpenAIOAuthServiceSuite struct {
	suite.Suite
	ctx      context.Context
	srv      *httptest.Server
	svc      *openaiOAuthService
	received chan url.Values
}

// SetupTest gives every test a fresh background context and a one-slot
// channel on which a handler can hand the posted form back to the test.
func (s *OpenAIOAuthServiceSuite) SetupTest() {
	s.ctx = context.Background()
	s.received = make(chan url.Values, 1)
}

// TearDownTest closes the test server started by setupServer, if any.
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
	if s.srv != nil {
		s.srv.Close()
		s.srv = nil
	}
}

// setupServer starts a local test server running handler and points a new
// service instance at its URL.
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
	s.srv = newLocalTestServer(s.T(), handler)
	s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}

// requireNoHandlerFailure fails the test with the handler's own message when
// a server-side check queued one on errCh.
//
// Call it before asserting on the client error: a failing handler answers
// with 400, so the client call fails too, and checking that error first would
// abort the test before the more specific handler message is reported.
func (s *OpenAIOAuthServiceSuite) requireNoHandlerFailure(errCh <-chan string) {
	select {
	case msg := <-errCh:
		require.Fail(s.T(), msg)
	default:
	}
}

// openAIOAuthFormField pairs a POST form key with the value a handler expects.
type openAIOAuthFormField struct {
	key, want string
}

// firstOpenAIOAuthFormMismatch parses the request form and describes the
// first problem found: "ParseForm failed" when the body cannot be parsed, or
// "<key> mismatch" for the first field, in the order given, whose posted
// value differs from the expected one. It returns "" when every field matches.
func firstOpenAIOAuthFormMismatch(r *http.Request, fields ...openAIOAuthFormField) string {
	if err := r.ParseForm(); err != nil {
		return "ParseForm failed"
	}
	for _, f := range fields {
		if got := r.PostForm.Get(f.key); got != f.want {
			return f.key + " mismatch"
		}
	}
	return ""
}

// TestExchangeCode_DefaultRedirectURI checks that an empty redirect URI is
// replaced by openai.DefaultRedirectURI and that the remaining
// authorization-code form fields are posted as expected.
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
	errCh := make(chan string, 1)
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		msg := ""
		if r.Method != http.MethodPost {
			msg = "method mismatch"
		} else {
			msg = firstOpenAIOAuthFormMismatch(r,
				openAIOAuthFormField{"grant_type", "authorization_code"},
				openAIOAuthFormField{"client_id", openai.ClientID},
				openAIOAuthFormField{"code", "code"},
				openAIOAuthFormField{"redirect_uri", openai.DefaultRedirectURI},
				openAIOAuthFormField{"code_verifier", "ver"},
			)
		}
		if msg != "" {
			errCh <- msg
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
	}))

	resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
	s.requireNoHandlerFailure(errCh)
	require.NoError(s.T(), err, "ExchangeCode")
	require.Equal(s.T(), "at", resp.AccessToken)
	require.Equal(s.T(), "rt", resp.RefreshToken)
}

// TestRefreshToken_FormFields checks the form fields posted by RefreshToken
// and that the token response is decoded.
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
	errCh := make(chan string, 1)
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		msg := firstOpenAIOAuthFormMismatch(r,
			openAIOAuthFormField{"grant_type", "refresh_token"},
			openAIOAuthFormField{"refresh_token", "rt"},
			openAIOAuthFormField{"client_id", openai.ClientID},
			openAIOAuthFormField{"scope", openai.RefreshScopes},
		)
		if msg != "" {
			errCh <- msg
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
	}))

	resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
	s.requireNoHandlerFailure(errCh)
	require.NoError(s.T(), err, "RefreshToken")
	require.Equal(s.T(), "at2", resp.AccessToken)
	require.Equal(s.T(), "rt2", resp.RefreshToken)
}

// TestNonSuccessStatus_IncludesBody checks that a 400 response yields an
// error mentioning both the status code and the response body.
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = io.WriteString(w, "bad")
	}))

	_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
	require.Error(s.T(), err)
	require.ErrorContains(s.T(), err, "status 400")
	require.ErrorContains(s.T(), err, "bad")
}

// TestRequestError_ClosedServer checks that a transport-level failure is
// reported as a "request failed" error.
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	// Close the server before the call so the request cannot connect.
	s.srv.Close()

	_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
	require.Error(s.T(), err)
	require.ErrorContains(s.T(), err, "request failed")
}

// TestContextCancel checks that cancelling the context while the request is
// in flight makes ExchangeCode return an error.
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
	started := make(chan struct{})
	block := make(chan struct{})
	// Keep the handler blocked until the client call has returned, so that
	// cancellation is the only thing that can end the request. Releasing it
	// right after cancel() would let an empty 200 response race the
	// cancellation. The deferred close runs before TearDownTest, which needs
	// the handler to finish before the server can shut down.
	defer close(block)

	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		close(started)
		<-block
	}))

	ctx, cancel := context.WithCancel(s.ctx)
	defer cancel()

	done := make(chan error, 1)
	go func() {
		_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
		done <- err
	}()

	// Wait for the request to reach the handler. If the call returns first it
	// never got that far, so fail instead of blocking forever on started.
	select {
	case <-started:
	case err := <-done:
		require.Failf(s.T(), "ExchangeCode returned before the request reached the server", "err: %v", err)
	}
	cancel()

	err := <-done
	require.Error(s.T(), err)
}

// TestExchangeCode_UsesProvidedRedirectURI checks that an explicit redirect
// URI is posted unchanged.
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
	want := "http://localhost:9999/cb"
	errCh := make(chan string, 1)
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The helper also reports a form parse failure, which was
		// previously discarded.
		if msg := firstOpenAIOAuthFormMismatch(r, openAIOAuthFormField{"redirect_uri", want}); msg != "" {
			errCh <- msg
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
	}))

	_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
	s.requireNoHandlerFailure(errCh)
	require.NoError(s.T(), err, "ExchangeCode")
}

// TestTokenURL_CanBeOverriddenWithQuery checks that the request still reaches
// the server when tokenURL carries a query string.
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			// Answer 400 so the client call fails and the NoError
			// assertion below reports the problem.
			http.Error(w, "ParseForm failed", http.StatusBadRequest)
			return
		}
		s.received <- r.PostForm
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
	}))
	s.svc.tokenURL = s.srv.URL + "?x=1"

	_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
	require.NoError(s.T(), err, "ExchangeCode")
	select {
	case <-s.received:
	default:
		require.Fail(s.T(), "expected server to receive request")
	}
}

// TestExchangeCode_SuccessButInvalidJSON checks that a 200 response whose
// body is not valid JSON produces an error.
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = io.WriteString(w, "not-valid-json")
	}))

	_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
	require.Error(s.T(), err, "expected error for invalid JSON response")
}

// TestRefreshToken_NonSuccessStatus checks that a 401 response from the token
// endpoint yields an error mentioning the status code.
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		_, _ = io.WriteString(w, "unauthorized")
	}))

	_, err := s.svc.RefreshToken(s.ctx, "rt", "")
	require.Error(s.T(), err, "expected error for non-2xx status")
	require.ErrorContains(s.T(), err, "status 401")
}

// TestNewOpenAIOAuthClient_DefaultTokenURL checks that the constructor
// returns an *openaiOAuthService pointed at openai.TokenURL.
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
	client := NewOpenAIOAuthClient()
	svc, ok := client.(*openaiOAuthService)
	require.True(t, ok)
	require.Equal(t, openai.TokenURL, svc.tokenURL)
}

// TestOpenAIOAuthServiceSuite runs the suite under go test.
func TestOpenAIOAuthServiceSuite(t *testing.T) {
	suite.Run(t, new(OpenAIOAuthServiceSuite))
}