为共享 req 客户端增加 HTTP/2 选项与缓存隔离 OpenAI OAuth 超时提升到 120s,并按协议控制强制 新增客户端池与 OAuth 客户端单测覆盖 修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式 测试: make test
257 lines
7.4 KiB
Go
257 lines
7.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
type OpenAIOAuthServiceSuite struct {
|
|
suite.Suite
|
|
ctx context.Context
|
|
srv *httptest.Server
|
|
svc *openaiOAuthService
|
|
received chan url.Values
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) SetupTest() {
|
|
s.ctx = context.Background()
|
|
s.received = make(chan url.Values, 1)
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
|
if s.srv != nil {
|
|
s.srv.Close()
|
|
s.srv = nil
|
|
}
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
|
s.srv = newLocalTestServer(s.T(), handler)
|
|
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
|
errCh := make(chan string, 1)
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
errCh <- "method mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := r.ParseForm(); err != nil {
|
|
errCh <- "ParseForm failed"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
|
errCh <- "grant_type mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
|
errCh <- "client_id mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("code"); got != "code" {
|
|
errCh <- "code mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
|
errCh <- "redirect_uri mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
|
errCh <- "code_verifier mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
|
}))
|
|
|
|
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
|
require.NoError(s.T(), err, "ExchangeCode")
|
|
select {
|
|
case msg := <-errCh:
|
|
require.Fail(s.T(), msg)
|
|
default:
|
|
}
|
|
require.Equal(s.T(), "at", resp.AccessToken)
|
|
require.Equal(s.T(), "rt", resp.RefreshToken)
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
|
errCh := make(chan string, 1)
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
errCh <- "ParseForm failed"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
|
errCh <- "grant_type mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
|
errCh <- "refresh_token mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
|
errCh <- "client_id mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
|
errCh <- "scope mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
|
}))
|
|
|
|
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
|
require.NoError(s.T(), err, "RefreshToken")
|
|
select {
|
|
case msg := <-errCh:
|
|
require.Fail(s.T(), msg)
|
|
default:
|
|
}
|
|
require.Equal(s.T(), "at2", resp.AccessToken)
|
|
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = io.WriteString(w, "bad")
|
|
}))
|
|
|
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
|
require.Error(s.T(), err)
|
|
require.ErrorContains(s.T(), err, "status 400")
|
|
require.ErrorContains(s.T(), err, "bad")
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
s.srv.Close()
|
|
|
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
|
require.Error(s.T(), err)
|
|
require.ErrorContains(s.T(), err, "request failed")
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
|
started := make(chan struct{})
|
|
block := make(chan struct{})
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
close(started)
|
|
<-block
|
|
}))
|
|
|
|
ctx, cancel := context.WithCancel(s.ctx)
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
|
done <- err
|
|
}()
|
|
|
|
<-started
|
|
cancel()
|
|
close(block)
|
|
|
|
err := <-done
|
|
require.Error(s.T(), err)
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
|
want := "http://localhost:9999/cb"
|
|
errCh := make(chan string, 1)
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = r.ParseForm()
|
|
if got := r.PostForm.Get("redirect_uri"); got != want {
|
|
errCh <- "redirect_uri mismatch"
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
|
}))
|
|
|
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
|
require.NoError(s.T(), err, "ExchangeCode")
|
|
select {
|
|
case msg := <-errCh:
|
|
require.Fail(s.T(), msg)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = r.ParseForm()
|
|
s.received <- r.PostForm
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
|
}))
|
|
s.svc.tokenURL = s.srv.URL + "?x=1"
|
|
|
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
|
require.NoError(s.T(), err, "ExchangeCode")
|
|
select {
|
|
case <-s.received:
|
|
default:
|
|
require.Fail(s.T(), "expected server to receive request")
|
|
}
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "not-valid-json")
|
|
}))
|
|
|
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
|
require.Error(s.T(), err, "expected error for invalid JSON response")
|
|
}
|
|
|
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = io.WriteString(w, "unauthorized")
|
|
}))
|
|
|
|
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
|
require.Error(s.T(), err, "expected error for non-2xx status")
|
|
require.ErrorContains(s.T(), err, "status 401")
|
|
}
|
|
|
|
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
|
client := NewOpenAIOAuthClient()
|
|
svc, ok := client.(*openaiOAuthService)
|
|
require.True(t, ok)
|
|
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
|
}
|
|
|
|
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
|
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
|
}
|