fix: respect proxy settings for outbound clients (#43)
This commit is contained in:
@@ -34,6 +34,8 @@ func buildAuthTransport(proxyURL string) *http.Transport {
|
||||
t.Proxy = http.ProxyURL(u)
|
||||
t.ForceAttemptHTTP2 = false
|
||||
}
|
||||
} else {
|
||||
t.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
52
auth/http_client_test.go
Normal file
52
auth/http_client_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAuthTransportUsesExplicitProxyURL(t *testing.T) {
|
||||
transport := buildAuthTransport("http://proxy.local:8080")
|
||||
req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")}
|
||||
|
||||
got, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected proxy error: %v", err)
|
||||
}
|
||||
assertProxyURL(t, got, "http://proxy.local:8080")
|
||||
}
|
||||
|
||||
func TestBuildAuthTransportFallsBackToEnvironmentProxy(t *testing.T) {
|
||||
t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
transport := buildAuthTransport("")
|
||||
req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")}
|
||||
|
||||
got, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected proxy error: %v", err)
|
||||
}
|
||||
assertProxyURL(t, got, "http://env-proxy.local:2323")
|
||||
}
|
||||
|
||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||
t.Helper()
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid test URL: %v", err)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func assertProxyURL(t *testing.T, got *url.URL, want string) {
|
||||
t.Helper()
|
||||
if got == nil {
|
||||
t.Fatalf("expected proxy URL %q, got nil", want)
|
||||
}
|
||||
if got.String() != want {
|
||||
t.Fatalf("expected proxy URL %q, got %q", want, got.String())
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ var kiroEndpoints = []kiroEndpoint{
|
||||
|
||||
// 全局 HTTP 客户端,支持运行时更换(代理重配置)
|
||||
var kiroHttpStore atomic.Pointer[http.Client]
|
||||
var kiroRestHttpStore atomic.Pointer[http.Client]
|
||||
|
||||
func init() {
|
||||
InitKiroHttpClient("")
|
||||
@@ -63,6 +64,8 @@ func buildKiroTransport(proxyURL string) *http.Transport {
|
||||
// 代理不支持 HTTP/2 协议升级
|
||||
t.ForceAttemptHTTP2 = false
|
||||
}
|
||||
} else {
|
||||
t.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -74,6 +77,12 @@ func InitKiroHttpClient(proxyURL string) {
|
||||
Transport: buildKiroTransport(proxyURL),
|
||||
}
|
||||
kiroHttpStore.Store(client)
|
||||
|
||||
restClient := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: buildKiroTransport(proxyURL),
|
||||
}
|
||||
kiroRestHttpStore.Store(restClient)
|
||||
}
|
||||
|
||||
// ==================== 请求结构 ====================
|
||||
|
||||
@@ -25,8 +25,7 @@ func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) {
|
||||
|
||||
setKiroHeaders(req, account)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := kiroRestHttpStore.Load().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,8 +56,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) {
|
||||
setKiroHeaders(req, account)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := kiroRestHttpStore.Load().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,8 +85,7 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
|
||||
|
||||
setKiroHeaders(req, account)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := kiroRestHttpStore.Load().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizeChunkBasicProgression(t *testing.T) {
|
||||
prev := ""
|
||||
@@ -35,3 +40,63 @@ func TestNormalizeChunkOverlapDelta(t *testing.T) {
|
||||
t.Fatalf("expected overlap suffix delta, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKiroTransportUsesExplicitProxyURL(t *testing.T) {
|
||||
transport := buildKiroTransport("http://proxy.local:8080")
|
||||
req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")}
|
||||
|
||||
got, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected proxy error: %v", err)
|
||||
}
|
||||
assertProxyURL(t, got, "http://proxy.local:8080")
|
||||
}
|
||||
|
||||
func TestBuildKiroTransportFallsBackToEnvironmentProxy(t *testing.T) {
|
||||
t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
transport := buildKiroTransport("")
|
||||
req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")}
|
||||
|
||||
got, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected proxy error: %v", err)
|
||||
}
|
||||
assertProxyURL(t, got, "http://env-proxy.local:2323")
|
||||
}
|
||||
|
||||
func TestInitKiroHttpClientKeepsShortRestTimeout(t *testing.T) {
|
||||
InitKiroHttpClient("")
|
||||
t.Cleanup(func() { InitKiroHttpClient("") })
|
||||
|
||||
streamClient := kiroHttpStore.Load()
|
||||
restClient := kiroRestHttpStore.Load()
|
||||
|
||||
if streamClient.Timeout != 5*time.Minute {
|
||||
t.Fatalf("expected streaming timeout to be 5m, got %s", streamClient.Timeout)
|
||||
}
|
||||
if restClient.Timeout != 30*time.Second {
|
||||
t.Fatalf("expected REST timeout to stay 30s, got %s", restClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||
t.Helper()
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid test URL: %v", err)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func assertProxyURL(t *testing.T, got *url.URL, want string) {
|
||||
t.Helper()
|
||||
if got == nil {
|
||||
t.Fatalf("expected proxy URL %q, got nil", want)
|
||||
}
|
||||
if got.String() != want {
|
||||
t.Fatalf("expected proxy URL %q, got %q", want, got.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user