diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index cc640803..1d41c945 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" relaycommon "one-api/relay/common" + "one-api/service" "strings" "fmt" @@ -45,7 +46,7 @@ func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { if err != nil { return "", fmt.Errorf("failed to create signed JWT: %w", err) } - newToken, err := exchangeJwtForAccessToken(signedJWT) + newToken, err := exchangeJwtForAccessToken(signedJWT, info) if err != nil { return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) } @@ -96,14 +97,25 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) { return signedToken, nil } -func exchangeJwtForAccessToken(signedJWT string) (string, error) { +func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) { authURL := "https://www.googleapis.com/oauth2/v4/token" data := url.Values{} data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") data.Set("assertion", signedJWT) - resp, err := http.PostForm(authURL, data) + var client *http.Client + var err error + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return "", fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + + resp, err := client.PostForm(authURL, data) if err != nil { return "", err }