Address code review comments
This commit is contained in:
		@@ -29,6 +29,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"github.com/golang/glog"
 | 
						"github.com/golang/glog"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/api"
 | 
						"k8s.io/kubernetes/pkg/api"
 | 
				
			||||||
 | 
						"k8s.io/kubernetes/pkg/kubelet"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/util"
 | 
						"k8s.io/kubernetes/pkg/util"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/util/httpstream"
 | 
						"k8s.io/kubernetes/pkg/util/httpstream"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -122,16 +123,13 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P
 | 
				
			|||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding.
 | 
					 | 
				
			||||||
const PortForwardProtocolV1Name = "portforward.k8s.io"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ForwardPorts formats and executes a port forwarding request. The connection will remain
 | 
					// ForwardPorts formats and executes a port forwarding request. The connection will remain
 | 
				
			||||||
// open until stopChan is closed.
 | 
					// open until stopChan is closed.
 | 
				
			||||||
func (pf *PortForwarder) ForwardPorts() error {
 | 
					func (pf *PortForwarder) ForwardPorts() error {
 | 
				
			||||||
	defer pf.Close()
 | 
						defer pf.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name})
 | 
						pf.streamConn, _, err = pf.dialer.Dial(kubelet.PortForwardProtocolV1Name)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("error upgrading connection: %s", err)
 | 
							return fmt.Errorf("error upgrading connection: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -44,7 +44,7 @@ type fakeDialer struct {
 | 
				
			|||||||
	negotiatedProtocol string
 | 
						negotiatedProtocol string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) {
 | 
					func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
 | 
				
			||||||
	d.dialed = true
 | 
						d.dialed = true
 | 
				
			||||||
	return d.conn, d.negotiatedProtocol, d.err
 | 
						return d.conn, d.negotiatedProtocol, d.err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -98,7 +98,7 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
 | 
				
			|||||||
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
 | 
					// Dial opens a connection to a remote server and attempts to negotiate a SPDY
 | 
				
			||||||
// connection. Upon success, it returns the connection and the protocol
 | 
					// connection. Upon success, it returns the connection and the protocol
 | 
				
			||||||
// selected by the server.
 | 
					// selected by the server.
 | 
				
			||||||
func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) {
 | 
					func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) {
 | 
				
			||||||
	transport := e.transport
 | 
						transport := e.transport
 | 
				
			||||||
	// TODO consider removing this and reusing client.TransportFor above to get this for free
 | 
						// TODO consider removing this and reusing client.TransportFor above to get this for free
 | 
				
			||||||
	switch {
 | 
						switch {
 | 
				
			||||||
@@ -111,6 +111,9 @@ func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string
 | 
				
			|||||||
	case bool(glog.V(6)):
 | 
						case bool(glog.V(6)):
 | 
				
			||||||
		transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
 | 
							transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO the client probably shouldn't be created here, as it doesn't allow
 | 
				
			||||||
 | 
						// flexibility to allow callers to configure it.
 | 
				
			||||||
	client := &http.Client{Transport: transport}
 | 
						client := &http.Client{Transport: transport}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, err := http.NewRequest(e.method, e.url.String(), nil)
 | 
						req, err := http.NewRequest(e.method, e.url.String(), nil)
 | 
				
			||||||
@@ -158,7 +161,8 @@ type streamProtocolHandler interface {
 | 
				
			|||||||
// Stream opens a protocol streamer to the server and streams until a client closes
 | 
					// Stream opens a protocol streamer to the server and streams until a client closes
 | 
				
			||||||
// the connection or the server disconnects.
 | 
					// the connection or the server disconnects.
 | 
				
			||||||
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
 | 
					func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
 | 
				
			||||||
	conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name})
 | 
						supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
 | 
				
			||||||
 | 
						conn, protocol, err := e.Dial(supportedProtocols...)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -175,8 +179,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
 | 
				
			|||||||
			tty:    tty,
 | 
								tty:    tty,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case "":
 | 
						case "":
 | 
				
			||||||
		glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned")
 | 
							glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned")
 | 
				
			||||||
		// TODO restore v1
 | 
					 | 
				
			||||||
		streamer = &streamProtocolV1{
 | 
							streamer = &streamProtocolV1{
 | 
				
			||||||
			stdin:  stdin,
 | 
								stdin:  stdin,
 | 
				
			||||||
			stdout: stdout,
 | 
								stdout: stdout,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -42,10 +42,17 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if protocol != StreamProtocolV2Name {
 | 
				
			||||||
 | 
								t.Fatalf("unexpected protocol: %s", protocol)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		streamCh := make(chan httpstream.Stream)
 | 
							streamCh := make(chan httpstream.Stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		upgrader := spdy.NewResponseUpgrader()
 | 
							upgrader := spdy.NewResponseUpgrader()
 | 
				
			||||||
		conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error {
 | 
							conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
 | 
				
			||||||
			streamCh <- stream
 | 
								streamCh <- stream
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
@@ -57,7 +64,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		defer conn.Close()
 | 
							defer conn.Close()
 | 
				
			||||||
		_ = protocol
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
 | 
							var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
 | 
				
			||||||
		receivedStreams := 0
 | 
							receivedStreams := 0
 | 
				
			||||||
@@ -185,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) {
 | 
				
			|||||||
		url, _ := url.ParseRequestURI(server.URL)
 | 
							url, _ := url.ParseRequestURI(server.URL)
 | 
				
			||||||
		c := client.NewRESTClient(url, "x", nil, -1, -1)
 | 
							c := client.NewRESTClient(url, "x", nil, -1, -1)
 | 
				
			||||||
		req := c.Post().Resource("testing")
 | 
							req := c.Post().Resource("testing")
 | 
				
			||||||
 | 
							req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
 | 
				
			||||||
		req.Param("command", "ls")
 | 
							req.Param("command", "ls")
 | 
				
			||||||
		req.Param("command", "/")
 | 
							req.Param("command", "/")
 | 
				
			||||||
		conf := &client.Config{
 | 
							conf := &client.Config{
 | 
				
			||||||
@@ -364,7 +371,7 @@ func TestDial(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	conn, protocol, err := exec.Dial([]string{"a", "b"})
 | 
						conn, protocol, err := exec.Dial("protocol1")
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,10 @@ import (
 | 
				
			|||||||
	"k8s.io/kubernetes/pkg/util/httpstream"
 | 
						"k8s.io/kubernetes/pkg/util/httpstream"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// streamProtocolV1 implements the first version of the streaming exec & attach
 | 
				
			||||||
 | 
					// protocol. This version has some bugs, such as not being able to detecte when
 | 
				
			||||||
 | 
					// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and
 | 
				
			||||||
 | 
					// http://issues.k8s.io/13395 for more details.
 | 
				
			||||||
type streamProtocolV1 struct {
 | 
					type streamProtocolV1 struct {
 | 
				
			||||||
	stdin  io.Reader
 | 
						stdin  io.Reader
 | 
				
			||||||
	stdout io.Writer
 | 
						stdout io.Writer
 | 
				
			||||||
@@ -41,8 +45,8 @@ func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
 | 
				
			|||||||
	errorChan := make(chan error)
 | 
						errorChan := make(chan error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cp := func(s string, dst io.Writer, src io.Reader) {
 | 
						cp := func(s string, dst io.Writer, src io.Reader) {
 | 
				
			||||||
		glog.V(4).Infof("Copying %s", s)
 | 
							glog.V(6).Infof("Copying %s", s)
 | 
				
			||||||
		defer glog.V(4).Infof("Done copying %s", s)
 | 
							defer glog.V(6).Infof("Done copying %s", s)
 | 
				
			||||||
		if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
 | 
							if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
 | 
				
			||||||
			glog.Errorf("Error copying %s: %v", s, err)
 | 
								glog.Errorf("Error copying %s: %v", s, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -44,7 +44,6 @@ import (
 | 
				
			|||||||
	"k8s.io/kubernetes/pkg/api/validation"
 | 
						"k8s.io/kubernetes/pkg/api/validation"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/auth/authenticator"
 | 
						"k8s.io/kubernetes/pkg/auth/authenticator"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/auth/authorizer"
 | 
						"k8s.io/kubernetes/pkg/auth/authorizer"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/client/unversioned/portforward"
 | 
					 | 
				
			||||||
	"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
 | 
						"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/healthz"
 | 
						"k8s.io/kubernetes/pkg/healthz"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/httplog"
 | 
						"k8s.io/kubernetes/pkg/httplog"
 | 
				
			||||||
@@ -687,10 +686,17 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
 | 
				
			|||||||
		return streams[0], streams[1], streams[2], streams[3], conn, tty, true
 | 
							return streams[0], streams[1], streams[2], streams[3], conn, tty, true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
 | 
				
			||||||
 | 
						_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
 | 
				
			||||||
 | 
						// negotiated protocol isn't used server side at the moment, but could be in the future
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, nil, nil, nil, false, false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	streamCh := make(chan httpstream.Stream)
 | 
						streamCh := make(chan httpstream.Stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	upgrader := spdy.NewResponseUpgrader()
 | 
						upgrader := spdy.NewResponseUpgrader()
 | 
				
			||||||
	conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error {
 | 
						conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
 | 
				
			||||||
		streamCh <- stream
 | 
							streamCh <- stream
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@@ -701,9 +707,6 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
 | 
				
			|||||||
		// if we weren't successful in upgrading.
 | 
							// if we weren't successful in upgrading.
 | 
				
			||||||
		return nil, nil, nil, nil, nil, false, false
 | 
							return nil, nil, nil, nil, nil, false, false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(protocol) == 0 {
 | 
					 | 
				
			||||||
		protocol = remotecommand.StreamProtocolV1Name
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
 | 
						conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -778,24 +781,34 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
 | 
				
			|||||||
	ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
 | 
						ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The subprotocol "portforward.k8s.io" is used for port forwarding.
 | 
				
			||||||
 | 
					const PortForwardProtocolV1Name = "portforward.k8s.io"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ServePortForward handles a port forwarding request.  A single request is
 | 
					// ServePortForward handles a port forwarding request.  A single request is
 | 
				
			||||||
// kept alive as long as the client is still alive and the connection has not
 | 
					// kept alive as long as the client is still alive and the connection has not
 | 
				
			||||||
// been timed out due to idleness. This function handles multiple forwarded
 | 
					// been timed out due to idleness. This function handles multiple forwarded
 | 
				
			||||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
 | 
					// connections; i.e., multiple `curl http://localhost:8888/` requests will be
 | 
				
			||||||
// handled by a single invocation of ServePortForward.
 | 
					// handled by a single invocation of ServePortForward.
 | 
				
			||||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
 | 
					func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
 | 
				
			||||||
 | 
						supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
 | 
				
			||||||
 | 
						_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, PortForwardProtocolV1Name)
 | 
				
			||||||
 | 
						// negotiated protocol isn't currently used server side, but could be in the future
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							// Handshake writes the error to the client
 | 
				
			||||||
 | 
							util.HandleError(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	streamChan := make(chan httpstream.Stream, 1)
 | 
						streamChan := make(chan httpstream.Stream, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	glog.V(5).Infof("Upgrading port forward response")
 | 
						glog.V(5).Infof("Upgrading port forward response")
 | 
				
			||||||
	upgrader := spdy.NewResponseUpgrader()
 | 
						upgrader := spdy.NewResponseUpgrader()
 | 
				
			||||||
	conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan))
 | 
						conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
 | 
				
			||||||
	if conn == nil {
 | 
						if conn == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer conn.Close()
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_ = protocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
 | 
						glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
 | 
				
			||||||
	conn.SetIdleTimeout(idleTimeout)
 | 
						conn.SetIdleTimeout(idleTimeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,6 +17,7 @@ limitations under the License.
 | 
				
			|||||||
package httpstream
 | 
					package httpstream
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@@ -24,9 +25,10 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	HeaderConnection      = "Connection"
 | 
						HeaderConnection               = "Connection"
 | 
				
			||||||
	HeaderUpgrade         = "Upgrade"
 | 
						HeaderUpgrade                  = "Upgrade"
 | 
				
			||||||
	HeaderProtocolVersion = "X-Stream-Protocol-Version"
 | 
						HeaderProtocolVersion          = "X-Stream-Protocol-Version"
 | 
				
			||||||
 | 
						HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewStreamHandler defines a function that is called when a new Stream is
 | 
					// NewStreamHandler defines a function that is called when a new Stream is
 | 
				
			||||||
@@ -43,7 +45,7 @@ type Dialer interface {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Dial opens a streaming connection to a server using one of the protocols
 | 
						// Dial opens a streaming connection to a server using one of the protocols
 | 
				
			||||||
	// specified (in order of most preferred to least preferred).
 | 
						// specified (in order of most preferred to least preferred).
 | 
				
			||||||
	Dial(protocols []string) (Connection, string, error)
 | 
						Dial(protocols ...string) (Connection, string, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
 | 
					// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
 | 
				
			||||||
@@ -60,9 +62,9 @@ type UpgradeRoundTripper interface {
 | 
				
			|||||||
// add streaming support to them.
 | 
					// add streaming support to them.
 | 
				
			||||||
type ResponseUpgrader interface {
 | 
					type ResponseUpgrader interface {
 | 
				
			||||||
	// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
 | 
						// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
 | 
				
			||||||
	// streams. newStreamHandler will be called synchronously whenever the
 | 
						// streams. newStreamHandler will be called asynchronously whenever the
 | 
				
			||||||
	// other end of the upgraded connection creates a new stream.
 | 
						// other end of the upgraded connection creates a new stream.
 | 
				
			||||||
	UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string)
 | 
						UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Connection represents an upgraded HTTP connection.
 | 
					// Connection represents an upgraded HTTP connection.
 | 
				
			||||||
@@ -100,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func negotiateProtocol(clientProtocols, serverProtocols []string) string {
 | 
				
			||||||
 | 
						for i := range clientProtocols {
 | 
				
			||||||
 | 
							for j := range serverProtocols {
 | 
				
			||||||
 | 
								if clientProtocols[i] == serverProtocols[j] {
 | 
				
			||||||
 | 
									return clientProtocols[i]
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Handshake performs a subprotocol negotiation. If the client did not request
 | 
				
			||||||
 | 
					// a specific subprotocol, defaultProtocol is used. If the client did request a
 | 
				
			||||||
 | 
					// subprotocol, Handshake will select the first common value found in
 | 
				
			||||||
 | 
					// serverProtocols. If a match is found, Handshake adds a response header
 | 
				
			||||||
 | 
					// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
 | 
				
			||||||
 | 
					// returned, along with a response header containing the list of protocols the
 | 
				
			||||||
 | 
					// server can accept.
 | 
				
			||||||
 | 
					func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
 | 
				
			||||||
 | 
						clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
 | 
				
			||||||
 | 
						if len(clientProtocols) == 0 {
 | 
				
			||||||
 | 
							// Kube 1.0 client that didn't support subprotocol negotiation
 | 
				
			||||||
 | 
							// TODO remove this defaulting logic once Kube 1.0 is no longer supported
 | 
				
			||||||
 | 
							w.Header().Add(HeaderProtocolVersion, defaultProtocol)
 | 
				
			||||||
 | 
							return defaultProtocol, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
 | 
				
			||||||
 | 
						if len(negotiatedProtocol) == 0 {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusForbidden)
 | 
				
			||||||
 | 
							for i := range serverProtocols {
 | 
				
			||||||
 | 
								w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
 | 
				
			||||||
 | 
							return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
 | 
				
			||||||
 | 
						return negotiatedProtocol, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) {
 | 
				
			|||||||
			streamCh := make(chan httpstream.Stream)
 | 
								streamCh := make(chan httpstream.Stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			responseUpgrader := NewResponseUpgrader()
 | 
								responseUpgrader := NewResponseUpgrader()
 | 
				
			||||||
			spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error {
 | 
								spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
 | 
				
			||||||
				streamCh <- s
 | 
									streamCh <- s
 | 
				
			||||||
				return nil
 | 
									return nil
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,7 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/golang/glog"
 | 
						"k8s.io/kubernetes/pkg/util"
 | 
				
			||||||
	"k8s.io/kubernetes/pkg/util/httpstream"
 | 
						"k8s.io/kubernetes/pkg/util/httpstream"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -39,47 +39,23 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader {
 | 
				
			|||||||
	return responseUpgrader{}
 | 
						return responseUpgrader{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
 | 
					 | 
				
			||||||
	for i := range clientProtocols {
 | 
					 | 
				
			||||||
		for j := range serverProtocols {
 | 
					 | 
				
			||||||
			if clientProtocols[i] == serverProtocols[j] {
 | 
					 | 
				
			||||||
				return clientProtocols[i]
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
 | 
					// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
 | 
				
			||||||
// streams. newStreamHandler will be called synchronously whenever the
 | 
					// streams. newStreamHandler will be called synchronously whenever the
 | 
				
			||||||
// other end of the upgraded connection creates a new stream.
 | 
					// other end of the upgraded connection creates a new stream.
 | 
				
			||||||
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) {
 | 
					func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
 | 
				
			||||||
	connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
 | 
						connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
 | 
				
			||||||
	upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
 | 
						upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
 | 
				
			||||||
	if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
 | 
						if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
 | 
				
			||||||
		w.WriteHeader(http.StatusBadRequest)
 | 
							w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
 | 
							fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
 | 
				
			||||||
		return nil, ""
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hijacker, ok := w.(http.Hijacker)
 | 
						hijacker, ok := w.(http.Hijacker)
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		w.WriteHeader(http.StatusInternalServerError)
 | 
							w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
 | 
							fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
 | 
				
			||||||
		return nil, ""
 | 
							return nil
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var negotiatedProtocol string
 | 
					 | 
				
			||||||
	clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)]
 | 
					 | 
				
			||||||
	if len(clientProtocols) > 0 {
 | 
					 | 
				
			||||||
		negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols)
 | 
					 | 
				
			||||||
		if len(negotiatedProtocol) > 0 {
 | 
					 | 
				
			||||||
			w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol)
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			w.WriteHeader(http.StatusForbidden)
 | 
					 | 
				
			||||||
			fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols)
 | 
					 | 
				
			||||||
			return nil, ""
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
 | 
						w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
 | 
				
			||||||
@@ -88,15 +64,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	conn, _, err := hijacker.Hijack()
 | 
						conn, _, err := hijacker.Hijack()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		glog.Errorf("unable to upgrade: error hijacking response: %v", err)
 | 
							util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err))
 | 
				
			||||||
		return nil, ""
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	spdyConn, err := NewServerConnection(conn, newStreamHandler)
 | 
						spdyConn, err := NewServerConnection(conn, newStreamHandler)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)
 | 
							util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err))
 | 
				
			||||||
		return nil, ""
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return spdyConn, negotiatedProtocol
 | 
						return spdyConn
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -53,8 +53,7 @@ func TestUpgradeResponse(t *testing.T) {
 | 
				
			|||||||
	for i, testCase := range testCases {
 | 
						for i, testCase := range testCases {
 | 
				
			||||||
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
							server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
			upgrader := NewResponseUpgrader()
 | 
								upgrader := NewResponseUpgrader()
 | 
				
			||||||
			conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil)
 | 
								conn := upgrader.UpgradeResponse(w, req, nil)
 | 
				
			||||||
			_ = protocol
 | 
					 | 
				
			||||||
			haveErr := conn == nil
 | 
								haveErr := conn == nil
 | 
				
			||||||
			if e, a := testCase.shouldError, haveErr; e != a {
 | 
								if e, a := testCase.shouldError, haveErr; e != a {
 | 
				
			||||||
				t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
 | 
									t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user