提取 proxyurl.Parse() 公共包,将分散在 6 处的代理 URL 验证逻辑 统一收敛,确保无效代理配置在创建时立即失败,永不静默回退直连。 主要变更: - 新增 proxyurl 包:统一 TrimSpace → url.Parse → Host 校验 → Scheme 白名单 - socks5:// 自动升级为 socks5h://,防止 DNS 泄漏(大小写不敏感) - antigravity: http.ProxyURL → proxyutil.ConfigureTransportProxy 支持 SOCKS5 - openai_oauth: 删除 newOpenAIOAuthHTTPClient,收编至 httpclient.GetClient - 移除未使用的 ProxyStrict 字段(fail-fast 已是全局默认行为) - 补充 15 个 proxyurl 测试 + pricing/usage fail-fast 测试
162 lines
4.7 KiB
Go
162 lines
4.7 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
type PricingServiceSuite struct {
|
|
suite.Suite
|
|
ctx context.Context
|
|
srv *httptest.Server
|
|
client *pricingRemoteClient
|
|
}
|
|
|
|
func (s *PricingServiceSuite) SetupTest() {
|
|
s.ctx = context.Background()
|
|
client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient)
|
|
require.True(s.T(), ok, "type assertion failed")
|
|
s.client = client
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TearDownTest() {
|
|
if s.srv != nil {
|
|
s.srv.Close()
|
|
s.srv = nil
|
|
}
|
|
}
|
|
|
|
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
|
s.srv = newLocalTestServer(s.T(), handler)
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/ok" {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
|
|
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
|
require.NoError(s.T(), err, "FetchPricingJSON")
|
|
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
|
|
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
|
require.Error(s.T(), err, "expected error for non-200 status")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/hashfile":
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
|
case "/hashonly":
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("def456\n"))
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
|
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
|
require.NoError(s.T(), err, "FetchHashText")
|
|
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
|
|
|
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
|
require.NoError(s.T(), err, "FetchHashText")
|
|
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
|
|
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
|
require.Error(s.T(), err, "expected error for non-200 status")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
|
|
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
|
|
require.Error(s.T(), err, "expected error for invalid URL")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
// empty body
|
|
}))
|
|
|
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
|
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
|
require.Equal(s.T(), "", hash, "expected empty hash")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(" \n"))
|
|
}))
|
|
|
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
|
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
|
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
|
}
|
|
|
|
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
|
started := make(chan struct{})
|
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
close(started)
|
|
<-r.Context().Done()
|
|
}))
|
|
|
|
ctx, cancel := context.WithCancel(s.ctx)
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
|
|
done <- err
|
|
}()
|
|
|
|
<-started
|
|
cancel()
|
|
|
|
err := <-done
|
|
require.Error(s.T(), err)
|
|
}
|
|
|
|
func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) {
|
|
client := NewPricingRemoteClient("://bad", false)
|
|
_, ok := client.(*pricingRemoteClientError)
|
|
require.True(t, ok, "should return error client when proxy is invalid and fallback disabled")
|
|
|
|
_, err := client.FetchPricingJSON(context.Background(), "http://example.com")
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "proxy client init failed")
|
|
}
|
|
|
|
func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) {
|
|
client := NewPricingRemoteClient("://bad", true)
|
|
_, ok := client.(*pricingRemoteClient)
|
|
require.True(t, ok, "should fallback to direct client when allowed")
|
|
}
|
|
|
|
func TestPricingServiceSuite(t *testing.T) {
|
|
suite.Run(t, new(PricingServiceSuite))
|
|
}
|