diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index f2422b4cf70..a90fab1fe45 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -78,7 +78,7 @@ func NewExecutor(config *restclient.Config, method string, url *url.URL) (Stream return nil, err } - upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig) + upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true) wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) if err != nil { return nil, err diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 4039f8f239c..da2be51d08a 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -1119,7 +1119,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1" - upgradeRoundTripper := spdy.NewSpdyRoundTripper(nil) + upgradeRoundTripper := spdy.NewSpdyRoundTripper(nil, true) c := &http.Client{Transport: upgradeRoundTripper} resp, err := c.Post(url, "", nil) @@ -1304,7 +1304,7 @@ func testExecAttach(t *testing.T, verb string) { return http.ErrUseLastResponse } } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil) + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) c = &http.Client{Transport: upgradeRoundTripper} } @@ -1442,7 +1442,7 @@ func TestServePortForwardIdleTimeout(t *testing.T) { url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName - upgradeRoundTripper := spdy.NewRoundTripper(nil) + upgradeRoundTripper := spdy.NewRoundTripper(nil, true) c := &http.Client{Transport: upgradeRoundTripper} resp, err := c.Post(url, "", nil) @@ -1552,11 +1552,20 @@ func TestServePortForward(t *testing.T) { url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName) } - upgradeRoundTripper := spdy.NewRoundTripper(nil) - c := &http.Client{Transport: upgradeRoundTripper} - // Don't follow redirects, since we want to inspect the redirect response. - c.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse + var ( + upgradeRoundTripper httpstream.UpgradeRoundTripper + c *http.Client + ) + + if len(test.responseLocation) > 0 { + c = &http.Client{} + // Don't follow redirects, since we want to inspect the redirect response. + c.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) + c = &http.Client{Transport: upgradeRoundTripper} } resp, err := c.Post(url, "", nil) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go index cf5fbe9be3b..ad300e28b5b 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go @@ -18,9 +18,11 @@ package spdy import ( "bufio" + "bytes" "crypto/tls" "encoding/base64" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -33,6 +35,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/util/httpstream" + utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/third_party/forked/golang/netutil" ) @@ -59,25 +62,49 @@ type SpdyRoundTripper struct { // proxier knows which proxy to use given a request, defaults to http.ProxyFromEnvironment // Used primarily for mocking the proxy discovery in tests. proxier func(req *http.Request) (*url.URL, error) + + // followRedirects indicates if the round tripper should examine responses for redirects and + // follow them. + followRedirects bool } +var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{} +var _ httpstream.UpgradeRoundTripper = &SpdyRoundTripper{} +var _ utilnet.Dialer = &SpdyRoundTripper{} + // NewRoundTripper creates a new SpdyRoundTripper that will use // the specified tlsConfig. -func NewRoundTripper(tlsConfig *tls.Config) httpstream.UpgradeRoundTripper { - return NewSpdyRoundTripper(tlsConfig) +func NewRoundTripper(tlsConfig *tls.Config, followRedirects bool) httpstream.UpgradeRoundTripper { + return NewSpdyRoundTripper(tlsConfig, followRedirects) } // NewSpdyRoundTripper creates a new SpdyRoundTripper that will use // the specified tlsConfig. This function is mostly meant for unit tests. -func NewSpdyRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper { - return &SpdyRoundTripper{tlsConfig: tlsConfig} +func NewSpdyRoundTripper(tlsConfig *tls.Config, followRedirects bool) *SpdyRoundTripper { + return &SpdyRoundTripper{tlsConfig: tlsConfig, followRedirects: followRedirects} } -// implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during proxying with a spdy roundtripper +// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during +// proxying with a spdy roundtripper. func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config { return s.tlsConfig } +// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer. +func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { + conn, err := s.dial(req) + if err != nil { + return nil, err + } + + if err := req.Write(conn); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + // dial dials the host specified by req, using TLS if appropriate, optionally // using a proxy server if one is configured via environment variables. func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) { @@ -213,24 +240,39 @@ func (s *SpdyRoundTripper) proxyAuth(proxyURL *url.URL) string { // clients may call SpdyRoundTripper.Connection() to retrieve the upgraded // connection. func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - // TODO what's the best way to clone the request? - r := *req - req = &r - req.Header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) - req.Header.Add(httpstream.HeaderUpgrade, HeaderSpdy31) + header := utilnet.CloneHeader(req.Header) + header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) + header.Add(httpstream.HeaderUpgrade, HeaderSpdy31) - conn, err := s.dial(req) + var ( + conn net.Conn + rawResponse []byte + err error + ) + + if s.followRedirects { + conn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, req.URL, header, req.Body, s) + } else { + clone := utilnet.CloneRequest(req) + clone.Header = header + conn, err = s.Dial(clone) + } if err != nil { return nil, err } - err = req.Write(conn) - if err != nil { - return nil, err - } + responseReader := bufio.NewReader( + io.MultiReader( + bytes.NewBuffer(rawResponse), + conn, + ), + ) - resp, err := http.ReadResponse(bufio.NewReader(conn), req) + resp, err := http.ReadResponse(responseReader, nil) if err != nil { + if conn != nil { + conn.Close() + } return nil, err } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go index 7cbc7c8808e..887adbe8f64 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go @@ -20,359 +20,463 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "fmt" "io" "net/http" "net/http/httptest" "net/url" + "strings" + "sync/atomic" "testing" "github.com/elazarl/goproxy" + "k8s.io/apimachinery/pkg/util/httpstream" ) // be sure to unset environment variable https_proxy (if exported) before testing, otherwise the testing will fail unexpectedly. func TestRoundTripAndNewConnection(t *testing.T) { - localhostPool := x509.NewCertPool() - if !localhostPool.AppendCertsFromPEM(localhostCert) { - t.Errorf("error setting up localhostCert pool") - } + for _, redirect := range []bool{false, true} { + t.Run(fmt.Sprintf("redirect = %t", redirect), func(t *testing.T) { + localhostPool := x509.NewCertPool() + if !localhostPool.AppendCertsFromPEM(localhostCert) { + t.Errorf("error setting up localhostCert pool") + } - httpsServerInvalidHostname := func(h http.Handler) *httptest.Server { - cert, err := tls.X509KeyPair(exampleCert, exampleKey) - if err != nil { - t.Errorf("https (invalid hostname): proxy_test: %v", err) - } - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - ts.StartTLS() - return ts - } + httpsServerInvalidHostname := func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(exampleCert, exampleKey) + if err != nil { + t.Errorf("https (invalid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + } - httpsServerValidHostname := func(h http.Handler) *httptest.Server { - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - t.Errorf("https (valid hostname): proxy_test: %v", err) - } - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - ts.StartTLS() - return ts - } + httpsServerValidHostname := func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Errorf("https (valid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + } - testCases := map[string]struct { - serverFunc func(http.Handler) *httptest.Server - proxyServerFunc func(http.Handler) *httptest.Server - proxyAuth *url.Userinfo - clientTLS *tls.Config - serverConnectionHeader string - serverUpgradeHeader string - serverStatusCode int - shouldError bool - }{ - "no headers": { - serverFunc: httptest.NewServer, - serverConnectionHeader: "", - serverUpgradeHeader: "", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, - }, - "no upgrade header": { - serverFunc: httptest.NewServer, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, - }, - "no connection header": { - serverFunc: httptest.NewServer, - serverConnectionHeader: "", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, - }, - "no switching protocol status code": { - serverFunc: httptest.NewServer, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusForbidden, - shouldError: true, - }, - "http": { - serverFunc: httptest.NewServer, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "https (invalid hostname + InsecureSkipVerify)": { - serverFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: true}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "https (invalid hostname + hostname verification)": { - serverFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: false}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, - }, - "https (valid hostname + RootCAs)": { - serverFunc: httpsServerValidHostname, - clientTLS: &tls.Config{RootCAs: localhostPool}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied http->http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httptest.NewServer, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https (invalid hostname + InsecureSkipVerify) -> http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: true}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https with auth (invalid hostname + InsecureSkipVerify) -> http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httpsServerInvalidHostname, - proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), - clientTLS: &tls.Config{InsecureSkipVerify: true}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https (invalid hostname + hostname verification) -> http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: false}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, // fails because the client doesn't trust the proxy - }, - "proxied https (valid hostname + RootCAs) -> http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httpsServerValidHostname, - clientTLS: &tls.Config{RootCAs: localhostPool}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https with auth (valid hostname + RootCAs) -> http": { - serverFunc: httptest.NewServer, - proxyServerFunc: httpsServerValidHostname, - proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), - clientTLS: &tls.Config{RootCAs: localhostPool}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": { - serverFunc: httpsServerInvalidHostname, - proxyServerFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: true}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, // works because the test proxy ignores TLS errors - }, - "proxied https with auth (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": { - serverFunc: httpsServerInvalidHostname, - proxyServerFunc: httpsServerInvalidHostname, - proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), - clientTLS: &tls.Config{InsecureSkipVerify: true}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, // works because the test proxy ignores TLS errors - }, - "proxied https (invalid hostname + hostname verification) -> https (invalid hostname)": { - serverFunc: httpsServerInvalidHostname, - proxyServerFunc: httpsServerInvalidHostname, - clientTLS: &tls.Config{InsecureSkipVerify: false}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: true, // fails because the client doesn't trust the proxy - }, - "proxied https (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": { - serverFunc: httpsServerValidHostname, - proxyServerFunc: httpsServerValidHostname, - clientTLS: &tls.Config{RootCAs: localhostPool}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - "proxied https with auth (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": { - serverFunc: httpsServerValidHostname, - proxyServerFunc: httpsServerValidHostname, - proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), - clientTLS: &tls.Config{RootCAs: localhostPool}, - serverConnectionHeader: "Upgrade", - serverUpgradeHeader: "SPDY/3.1", - serverStatusCode: http.StatusSwitchingProtocols, - shouldError: false, - }, - } + testCases := map[string]struct { + serverFunc func(http.Handler) *httptest.Server + proxyServerFunc func(http.Handler) *httptest.Server + proxyAuth *url.Userinfo + clientTLS *tls.Config + serverConnectionHeader string + serverUpgradeHeader string + serverStatusCode int + shouldError bool + }{ + "no headers": { + serverFunc: httptest.NewServer, + serverConnectionHeader: "", + serverUpgradeHeader: "", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, + }, + "no upgrade header": { + serverFunc: httptest.NewServer, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, + }, + "no connection header": { + serverFunc: httptest.NewServer, + serverConnectionHeader: "", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, + }, + "no switching protocol status code": { + serverFunc: httptest.NewServer, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusForbidden, + shouldError: true, + }, + "http": { + serverFunc: httptest.NewServer, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "https (invalid hostname + InsecureSkipVerify)": { + serverFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: true}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "https (invalid hostname + hostname verification)": { + serverFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: false}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, + }, + "https (valid hostname + RootCAs)": { + serverFunc: httpsServerValidHostname, + clientTLS: &tls.Config{RootCAs: localhostPool}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied http->http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httptest.NewServer, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https (invalid hostname + InsecureSkipVerify) -> http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: true}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https with auth (invalid hostname + InsecureSkipVerify) -> http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httpsServerInvalidHostname, + proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), + clientTLS: &tls.Config{InsecureSkipVerify: true}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https (invalid hostname + hostname verification) -> http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: false}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, // fails because the client doesn't trust the proxy + }, + "proxied https (valid hostname + RootCAs) -> http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httpsServerValidHostname, + clientTLS: &tls.Config{RootCAs: localhostPool}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https with auth (valid hostname + RootCAs) -> http": { + serverFunc: httptest.NewServer, + proxyServerFunc: httpsServerValidHostname, + proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), + clientTLS: &tls.Config{RootCAs: localhostPool}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": { + serverFunc: httpsServerInvalidHostname, + proxyServerFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: true}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, // works because the test proxy ignores TLS errors + }, + "proxied https with auth (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": { + serverFunc: httpsServerInvalidHostname, + proxyServerFunc: httpsServerInvalidHostname, + proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), + clientTLS: &tls.Config{InsecureSkipVerify: true}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, // works because the test proxy ignores TLS errors + }, + "proxied https (invalid hostname + hostname verification) -> https (invalid hostname)": { + serverFunc: httpsServerInvalidHostname, + proxyServerFunc: httpsServerInvalidHostname, + clientTLS: &tls.Config{InsecureSkipVerify: false}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: true, // fails because the client doesn't trust the proxy + }, + "proxied https (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": { + serverFunc: httpsServerValidHostname, + proxyServerFunc: httpsServerValidHostname, + clientTLS: &tls.Config{RootCAs: localhostPool}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + "proxied https with auth (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": { + serverFunc: httpsServerValidHostname, + proxyServerFunc: httpsServerValidHostname, + proxyAuth: url.UserPassword("proxyuser", "proxypasswd"), + clientTLS: &tls.Config{RootCAs: localhostPool}, + serverConnectionHeader: "Upgrade", + serverUpgradeHeader: "SPDY/3.1", + serverStatusCode: http.StatusSwitchingProtocols, + shouldError: false, + }, + } - for k, testCase := range testCases { - server := testCase.serverFunc(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if testCase.shouldError { - if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a { - t.Fatalf("%s: Expected connection=upgrade header, got '%s", k, a) + for k, testCase := range testCases { + server := testCase.serverFunc(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if testCase.shouldError { + if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a { + t.Fatalf("%s: Expected connection=upgrade header, got '%s", k, a) + } + + w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader) + w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader) + w.WriteHeader(testCase.serverStatusCode) + + return + } + + streamCh := make(chan httpstream.Stream) + + responseUpgrader := NewResponseUpgrader() + spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- s + return nil + }) + if spdyConn == nil { + t.Fatalf("%s: unexpected nil spdyConn", k) + } + defer spdyConn.Close() + + stream := <-streamCh + io.Copy(stream, stream) + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("%s: Error creating request: %s", k, err) + } + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("%s: Error creating request: %s", k, err) } - w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader) - w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader) - w.WriteHeader(testCase.serverStatusCode) + spdyTransport := NewSpdyRoundTripper(testCase.clientTLS, redirect) + var proxierCalled bool + var proxyCalledWithHost string + var proxyCalledWithAuth bool + var proxyCalledWithAuthHeader string + if testCase.proxyServerFunc != nil { + proxyHandler := goproxy.NewProxyHttpServer() + + proxyHandler.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + proxyCalledWithHost = host + + proxyAuthHeaderName := "Proxy-Authorization" + _, proxyCalledWithAuth = ctx.Req.Header[proxyAuthHeaderName] + proxyCalledWithAuthHeader = ctx.Req.Header.Get(proxyAuthHeaderName) + return goproxy.OkConnect, host + }) + + proxy := testCase.proxyServerFunc(proxyHandler) + + spdyTransport.proxier = func(proxierReq *http.Request) (*url.URL, error) { + proxierCalled = true + proxyURL, err := url.Parse(proxy.URL) + if err != nil { + return nil, err + } + proxyURL.User = testCase.proxyAuth + return proxyURL, nil + } + defer proxy.Close() + } + + client := &http.Client{Transport: spdyTransport} + + resp, err := client.Do(req) + var conn httpstream.Connection + if err == nil { + conn, err = spdyTransport.NewConnection(resp) + } + haveErr := err != nil + if e, a := testCase.shouldError, haveErr; e != a { + t.Fatalf("%s: shouldError=%t, got %t: %v", k, e, a, err) + } + if testCase.shouldError { + continue + } + defer conn.Close() + + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("%s: expected http 101 switching protocols, got %d", k, resp.StatusCode) + } + + stream, err := conn.CreateStream(http.Header{}) + if err != nil { + t.Fatalf("%s: error creating client stream: %s", k, err) + } + + n, err := stream.Write([]byte("hello")) + if err != nil { + t.Fatalf("%s: error writing to stream: %s", k, err) + } + if n != 5 { + t.Fatalf("%s: Expected to write 5 bytes, but actually wrote %d", k, n) + } + + b := make([]byte, 5) + n, err = stream.Read(b) + if err != nil { + t.Fatalf("%s: error reading from stream: %s", k, err) + } + if n != 5 { + t.Fatalf("%s: Expected to read 5 bytes, but actually read %d", k, n) + } + if e, a := "hello", string(b[0:n]); e != a { + t.Fatalf("%s: expected '%s', got '%s'", k, e, a) + } + + if testCase.proxyServerFunc != nil { + if !proxierCalled { + t.Fatalf("%s: Expected to use a proxy but proxier in SpdyRoundTripper wasn't called", k) + } + if proxyCalledWithHost != serverURL.Host { + t.Fatalf("%s: Expected to see a call to the proxy for backend %q, got %q", k, serverURL.Host, proxyCalledWithHost) + } + } + + var expectedProxyAuth string + if testCase.proxyAuth != nil { + encodedCredentials := base64.StdEncoding.EncodeToString([]byte(testCase.proxyAuth.String())) + expectedProxyAuth = "Basic " + encodedCredentials + } + if len(expectedProxyAuth) == 0 && proxyCalledWithAuth { + t.Fatalf("%s: Proxy authorization unexpected, got %q", k, proxyCalledWithAuthHeader) + } + if proxyCalledWithAuthHeader != expectedProxyAuth { + t.Fatalf("%s: Expected to see a call to the proxy with credentials %q, got %q", k, testCase.proxyAuth, proxyCalledWithAuthHeader) + } + } + }) + } +} + +func TestRoundTripRedirects(t *testing.T) { + tests := []struct { + redirects int32 + expectSuccess bool + }{ + {0, true}, + {1, true}, + {10, true}, + {11, false}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("with %d redirects", test.redirects), func(t *testing.T) { + var redirects int32 = 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if redirects < test.redirects { + redirects = atomic.AddInt32(&redirects, 1) + http.Redirect(w, req, "redirect", http.StatusFound) + return + } + streamCh := make(chan httpstream.Stream) + + responseUpgrader := NewResponseUpgrader() + spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- s + return nil + }) + if spdyConn == nil { + t.Fatalf("unexpected nil spdyConn") + } + defer spdyConn.Close() + + stream := <-streamCh + io.Copy(stream, stream) + })) + defer server.Close() + + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("Error creating request: %s", err) + } + + spdyTransport := NewSpdyRoundTripper(nil, true) + client := &http.Client{Transport: spdyTransport} + + resp, err := client.Do(req) + if test.expectSuccess { + if err != nil { + t.Fatalf("error calling Do: %v", err) + } + } else { + if err == nil { + t.Fatalf("expecting an error") + } else if !strings.Contains(err.Error(), "too many redirects") { + t.Fatalf("expecting too many redirects, got %v", err) + } return } - streamCh := make(chan httpstream.Stream) - - responseUpgrader := NewResponseUpgrader() - spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error { - streamCh <- s - return nil - }) - if spdyConn == nil { - t.Fatalf("%s: unexpected nil spdyConn", k) + conn, err := spdyTransport.NewConnection(resp) + if err != nil { + t.Fatalf("error calling NewConnection: %v", err) } - defer spdyConn.Close() + defer conn.Close() - stream := <-streamCh - io.Copy(stream, stream) - })) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - if err != nil { - t.Fatalf("%s: Error creating request: %s", k, err) - } - req, err := http.NewRequest("GET", server.URL, nil) - if err != nil { - t.Fatalf("%s: Error creating request: %s", k, err) - } - - spdyTransport := NewSpdyRoundTripper(testCase.clientTLS) - - var proxierCalled bool - var proxyCalledWithHost string - var proxyCalledWithAuth bool - var proxyCalledWithAuthHeader string - if testCase.proxyServerFunc != nil { - proxyHandler := goproxy.NewProxyHttpServer() - - proxyHandler.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { - proxyCalledWithHost = host - - proxyAuthHeaderName := "Proxy-Authorization" - _, proxyCalledWithAuth = ctx.Req.Header[proxyAuthHeaderName] - proxyCalledWithAuthHeader = ctx.Req.Header.Get(proxyAuthHeaderName) - return goproxy.OkConnect, host - }) - - proxy := testCase.proxyServerFunc(proxyHandler) - - spdyTransport.proxier = func(proxierReq *http.Request) (*url.URL, error) { - proxierCalled = true - proxyURL, err := url.Parse(proxy.URL) - if err != nil { - return nil, err - } - proxyURL.User = testCase.proxyAuth - return proxyURL, nil + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("expected http 101 switching protocols, got %d", resp.StatusCode) } - defer proxy.Close() - } - client := &http.Client{Transport: spdyTransport} - - resp, err := client.Do(req) - var conn httpstream.Connection - if err == nil { - conn, err = spdyTransport.NewConnection(resp) - } - haveErr := err != nil - if e, a := testCase.shouldError, haveErr; e != a { - t.Fatalf("%s: shouldError=%t, got %t: %v", k, e, a, err) - } - if testCase.shouldError { - continue - } - defer conn.Close() - - if resp.StatusCode != http.StatusSwitchingProtocols { - t.Fatalf("%s: expected http 101 switching protocols, got %d", k, resp.StatusCode) - } - - stream, err := conn.CreateStream(http.Header{}) - if err != nil { - t.Fatalf("%s: error creating client stream: %s", k, err) - } - - n, err := stream.Write([]byte("hello")) - if err != nil { - t.Fatalf("%s: error writing to stream: %s", k, err) - } - if n != 5 { - t.Fatalf("%s: Expected to write 5 bytes, but actually wrote %d", k, n) - } - - b := make([]byte, 5) - n, err = stream.Read(b) - if err != nil { - t.Fatalf("%s: error reading from stream: %s", k, err) - } - if n != 5 { - t.Fatalf("%s: Expected to read 5 bytes, but actually read %d", k, n) - } - if e, a := "hello", string(b[0:n]); e != a { - t.Fatalf("%s: expected '%s', got '%s'", k, e, a) - } - - if testCase.proxyServerFunc != nil { - if !proxierCalled { - t.Fatalf("%s: Expected to use a proxy but proxier in SpdyRoundTripper wasn't called", k) + stream, err := conn.CreateStream(http.Header{}) + if err != nil { + t.Fatalf("error creating client stream: %s", err) } - if proxyCalledWithHost != serverURL.Host { - t.Fatalf("%s: Expected to see a call to the proxy for backend %q, got %q", k, serverURL.Host, proxyCalledWithHost) - } - } - var expectedProxyAuth string - if testCase.proxyAuth != nil { - encodedCredentials := base64.StdEncoding.EncodeToString([]byte(testCase.proxyAuth.String())) - expectedProxyAuth = "Basic " + encodedCredentials - } - if len(expectedProxyAuth) == 0 && proxyCalledWithAuth { - t.Fatalf("%s: Proxy authorization unexpected, got %q", k, proxyCalledWithAuthHeader) - } - if proxyCalledWithAuthHeader != expectedProxyAuth { - t.Fatalf("%s: Expected to see a call to the proxy with credentials %q, got %q", k, testCase.proxyAuth, proxyCalledWithAuthHeader) - } + n, err := stream.Write([]byte("hello")) + if err != nil { + t.Fatalf("error writing to stream: %s", err) + } + if n != 5 { + t.Fatalf("Expected to write 5 bytes, but actually wrote %d", n) + } + + b := make([]byte, 5) + n, err = stream.Read(b) + if err != nil { + t.Fatalf("error reading from stream: %s", err) + } + if n != 5 { + t.Fatalf("Expected to read 5 bytes, but actually read %d", n) + } + if e, a := "hello", string(b[0:n]); e != a { + t.Fatalf("expected '%s', got '%s'", e, a) + } + }) } } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go index c32082e9315..52e22ca721b 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go @@ -17,6 +17,8 @@ limitations under the License. package net import ( + "bufio" + "bytes" "crypto/tls" "fmt" "io" @@ -95,7 +97,7 @@ type RoundTripperWrapper interface { type DialFunc func(net, addr string) (net.Conn, error) -func Dialer(transport http.RoundTripper) (DialFunc, error) { +func DialerFor(transport http.RoundTripper) (DialFunc, error) { if transport == nil { return nil, nil } @@ -104,7 +106,7 @@ func Dialer(transport http.RoundTripper) (DialFunc, error) { case *http.Transport: return transport.Dial, nil case RoundTripperWrapper: - return Dialer(transport.WrappedRoundTripper()) + return DialerFor(transport.WrappedRoundTripper()) default: return nil, fmt.Errorf("unknown transport type: %v", transport) } @@ -267,3 +269,126 @@ func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error return delegate(req) } } + +// Dialer dials a host and writes a request to it. +type Dialer interface { + // Dial connects to the host specified by req's URL, writes the request to the connection, and + // returns the opened net.Conn. + Dial(req *http.Request) (net.Conn, error) +} + +// ConnectWithRedirects uses dialer to send req, following up to 10 redirects (relative to +// originalLocation). It returns the opened net.Conn and the raw response bytes. +func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer) (net.Conn, []byte, error) { + const ( + maxRedirects = 10 + maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers + ) + + var ( + location = originalLocation + method = originalMethod + intermediateConn net.Conn + rawResponse = bytes.NewBuffer(make([]byte, 0, 256)) + body = originalBody + ) + + defer func() { + if intermediateConn != nil { + intermediateConn.Close() + } + }() + +redirectLoop: + for redirects := 0; ; redirects++ { + if redirects > maxRedirects { + return nil, nil, fmt.Errorf("too many redirects (%d)", redirects) + } + + req, err := http.NewRequest(method, location.String(), body) + if err != nil { + return nil, nil, err + } + + req.Header = header + + intermediateConn, err = dialer.Dial(req) + if err != nil { + return nil, nil, err + } + + // Peek at the backend response. + rawResponse.Reset() + respReader := bufio.NewReader(io.TeeReader( + io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes. + rawResponse)) // Save the raw response. + resp, err := http.ReadResponse(respReader, nil) + if err != nil { + // Unable to read the backend response; let the client handle it. + glog.Warningf("Error reading backend response: %v", err) + break redirectLoop + } + + switch resp.StatusCode { + case http.StatusFound: + // Redirect, continue. + default: + // Don't redirect. + break redirectLoop + } + + // Redirected requests switch to "GET" according to the HTTP spec: + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3 + method = "GET" + // don't send a body when following redirects + body = nil + + resp.Body.Close() // not used + + // Reset the connection. + intermediateConn.Close() + intermediateConn = nil + + // Prepare to follow the redirect. + redirectStr := resp.Header.Get("Location") + if redirectStr == "" { + return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode) + } + // We have to parse relative to the current location, NOT originalLocation. For example, + // if we request http://foo.com/a and get back "http://bar.com/b", the result should be + // http://bar.com/b. If we then make that request and get back a redirect to "/c", the result + // should be http://bar.com/c, not http://foo.com/c. + location, err = location.Parse(redirectStr) + if err != nil { + return nil, nil, fmt.Errorf("malformed Location header: %v", err) + } + } + + connToReturn := intermediateConn + intermediateConn = nil // Don't close the connection when we return it. + return connToReturn, rawResponse.Bytes(), nil +} + +// CloneRequest creates a shallow copy of the request along with a deep copy of the Headers. +func CloneRequest(req *http.Request) *http.Request { + r := new(http.Request) + + // shallow clone + *r = *req + + // deep copy headers + r.Header = CloneHeader(req.Header) + + return r +} + +// CloneHeader creates a deep copy of an http.Header. +func CloneHeader(in http.Header) http.Header { + out := make(http.Header, len(in)) + for key, values := range in { + newValues := make([]string, len(values)) + copy(newValues, values) + out[key] = newValues + } + return out +} diff --git a/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy.go b/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy.go index b38e1468efe..8286597c62a 100644 --- a/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy.go +++ b/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy.go @@ -17,8 +17,6 @@ limitations under the License. package rest import ( - "bufio" - "bytes" "fmt" "io" "net" @@ -146,10 +144,13 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R rawResponse []byte err error ) + if h.InterceptRedirects && utilfeature.DefaultFeatureGate.Enabled(genericfeatures.StreamingProxyRedirects) { - backendConn, rawResponse, err = h.connectBackendWithRedirects(req) + backendConn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, h.Location, req.Header, req.Body, h) } else { - backendConn, err = h.connectBackend(req.Method, h.Location, req.Header, req.Body) + clone := utilnet.CloneRequest(req) + clone.URL = h.Location + backendConn, err = h.Dial(clone) } if err != nil { h.Responder.Error(err) @@ -214,112 +215,22 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } -// connectBackend dials the backend at location and forwards a copy of the client request. -func (h *UpgradeAwareProxyHandler) connectBackend(method string, location *url.URL, header http.Header, body io.Reader) (conn net.Conn, err error) { - defer func() { - if err != nil && conn != nil { - conn.Close() - conn = nil - } - }() - - beReq, err := http.NewRequest(method, location.String(), body) +// Dial dials the backend at req.URL and writes req to it. +func (h *UpgradeAwareProxyHandler) Dial(req *http.Request) (net.Conn, error) { + conn, err := proxy.DialURL(req.URL, h.Transport) if err != nil { - return nil, err - } - beReq.Header = header - - conn, err = proxy.DialURL(location, h.Transport) - if err != nil { - return conn, fmt.Errorf("error dialing backend: %v", err) + return nil, fmt.Errorf("error dialing backend: %v", err) } - if err = beReq.Write(conn); err != nil { - return conn, fmt.Errorf("error sending request: %v", err) + if err = req.Write(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("error sending request: %v", err) } return conn, err } -// connectBackendWithRedirects dials the backend and forwards a copy of the client request. If the -// client responds with a redirect, it is followed. The raw response bytes are returned, and should -// be forwarded back to the client. -func (h *UpgradeAwareProxyHandler) connectBackendWithRedirects(req *http.Request) (net.Conn, []byte, error) { - const ( - maxRedirects = 10 - maxResponseSize = 4096 - ) - var ( - initialReq = req - rawResponse = bytes.NewBuffer(make([]byte, 0, 256)) - location = h.Location - intermediateConn net.Conn - err error - ) - defer func() { - if intermediateConn != nil { - intermediateConn.Close() - } - }() - -redirectLoop: - for redirects := 0; ; redirects++ { - if redirects > maxRedirects { - return nil, nil, fmt.Errorf("too many redirects (%d)", redirects) - } - - if redirects == 0 { - intermediateConn, err = h.connectBackend(req.Method, location, req.Header, req.Body) - } else { - // Redirected requests switch to "GET" according to the HTTP spec: - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3 - intermediateConn, err = h.connectBackend("GET", location, initialReq.Header, nil) - } - - if err != nil { - return nil, nil, err - } - - // Peek at the backend response. - rawResponse.Reset() - respReader := bufio.NewReader(io.TeeReader( - io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes. - rawResponse)) // Save the raw response. - resp, err := http.ReadResponse(respReader, req) - if err != nil { - // Unable to read the backend response; let the client handle it. - glog.Warningf("Error reading backend response: %v", err) - break redirectLoop - } - resp.Body.Close() // Unused. - - switch resp.StatusCode { - case http.StatusFound: - // Redirect, continue. - default: - // Don't redirect. - break redirectLoop - } - - // Reset the connection. - intermediateConn.Close() - intermediateConn = nil - - // Prepare to follow the redirect. - redirectStr := resp.Header.Get("Location") - if redirectStr == "" { - return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode) - } - location, err = h.Location.Parse(redirectStr) - if err != nil { - return nil, nil, fmt.Errorf("malformed Location header: %v", err) - } - } - - backendConn := intermediateConn - intermediateConn = nil // Don't close the connection when we return it. - return backendConn, rawResponse.Bytes(), nil -} +var _ utilnet.Dialer = &UpgradeAwareProxyHandler{} func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme diff --git a/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy_test.go b/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy_test.go index a43279fb951..96ebed4d0b7 100644 --- a/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy_test.go +++ b/staging/src/k8s.io/apiserver/pkg/registry/generic/rest/proxy_test.go @@ -432,6 +432,7 @@ func TestProxyUpgrade(t *testing.T) { Location: serverURL, Transport: tc.ProxyTransport, InterceptRedirects: redirect, + Responder: &noErrorsAllowed{t: t}, } proxy := httptest.NewServer(proxyHandler) defer proxy.Close() @@ -459,6 +460,14 @@ func TestProxyUpgrade(t *testing.T) { } } +type noErrorsAllowed struct { + t *testing.T +} + +func (r *noErrorsAllowed) Error(err error) { + r.t.Error(err) +} + func TestProxyUpgradeErrorResponse(t *testing.T) { var ( responder *fakeResponder diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/dial.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/dial.go index 55ca0e32d53..3cb890dd03a 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/proxy/dial.go +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/dial.go @@ -32,7 +32,7 @@ import ( func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(url) - dialer, _ := utilnet.Dialer(transport) + dialer, _ := utilnet.DialerFor(transport) switch url.Scheme { case "http": diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/dial_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/dial_test.go index ee143b1e24e..f268ecd80ec 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/proxy/dial_test.go +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/dial_test.go @@ -102,7 +102,7 @@ func TestDialURL(t *testing.T) { TLSClientConfig: tlsConfigCopy, } - extractedDial, err := utilnet.Dialer(transport) + extractedDial, err := utilnet.DialerFor(transport) if err != nil { t.Fatal(err) } diff --git a/staging/src/k8s.io/client-go/transport/round_trippers.go b/staging/src/k8s.io/client-go/transport/round_trippers.go index a6f396fbb0a..c728b18775f 100644 --- a/staging/src/k8s.io/client-go/transport/round_trippers.go +++ b/staging/src/k8s.io/client-go/transport/round_trippers.go @@ -23,6 +23,8 @@ import ( "time" "github.com/golang/glog" + + utilnet "k8s.io/apimachinery/pkg/util/net" ) // HTTPWrappersForConfig wraps a round tripper with any relevant layered @@ -105,7 +107,7 @@ func NewAuthProxyRoundTripper(username string, groups []string, extra map[string } func (rt *authProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req = cloneRequest(req) + req = utilnet.CloneRequest(req) SetAuthProxyHeaders(req, rt.username, rt.groups, rt.extra) return rt.rt.RoundTrip(req) @@ -155,7 +157,7 @@ func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, e if len(req.Header.Get("User-Agent")) != 0 { return rt.rt.RoundTrip(req) } - req = cloneRequest(req) + req = utilnet.CloneRequest(req) req.Header.Set("User-Agent", rt.agent) return rt.rt.RoundTrip(req) } @@ -186,7 +188,7 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e if len(req.Header.Get("Authorization")) != 0 { return rt.rt.RoundTrip(req) } - req = cloneRequest(req) + req = utilnet.CloneRequest(req) req.SetBasicAuth(rt.username, rt.password) return rt.rt.RoundTrip(req) } @@ -236,7 +238,7 @@ func (rt *impersonatingRoundTripper) RoundTrip(req *http.Request) (*http.Respons if len(req.Header.Get(ImpersonateUserHeader)) != 0 { return rt.delegate.RoundTrip(req) } - req = cloneRequest(req) + req = utilnet.CloneRequest(req) req.Header.Set(ImpersonateUserHeader, rt.impersonate.UserName) for _, group := range rt.impersonate.Groups { @@ -277,7 +279,7 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, return rt.rt.RoundTrip(req) } - req = cloneRequest(req) + req = utilnet.CloneRequest(req) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", rt.bearer)) return rt.rt.RoundTrip(req) } @@ -292,20 +294,6 @@ func (rt *bearerAuthRoundTripper) CancelRequest(req *http.Request) { func (rt *bearerAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt } -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header) - for k, s := range r.Header { - r2.Header[k] = s - } - return r2 -} - // requestInfo keeps track of information about a request/response combination type requestInfo struct { RequestHeaders http.Header diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go index b925a4ef0a9..904f17147cd 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy.go @@ -25,7 +25,9 @@ import ( "k8s.io/apimachinery/pkg/util/httpstream/spdy" "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + genericfeatures "k8s.io/apiserver/pkg/features" genericrest "k8s.io/apiserver/pkg/registry/generic/rest" + utilfeature "k8s.io/apiserver/pkg/util/feature" restclient "k8s.io/client-go/rest" "k8s.io/client-go/transport" @@ -147,7 +149,8 @@ func maybeWrapForConnectionUpgrades(restConfig *restclient.Config, rt http.Round if err != nil { return nil, true, err } - upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig) + followRedirects := utilfeature.DefaultFeatureGate.Enabled(genericfeatures.StreamingProxyRedirects) + upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, followRedirects) wrappedRT, err := restclient.HTTPWrappersForConfig(restConfig, upgradeRoundTripper) if err != nil { return nil, true, err