Merge pull request #123413 from seans3/tunneling-spdy-websockets
PortForward: Tunnel SPDY through WebSockets
This commit is contained in:
		@@ -619,6 +619,13 @@ const (
 | 
				
			|||||||
	// Enable users to specify when a Pod is ready for scheduling.
 | 
						// Enable users to specify when a Pod is ready for scheduling.
 | 
				
			||||||
	PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness"
 | 
						PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// owner: @seans3
 | 
				
			||||||
 | 
						// kep: http://kep.k8s.io/4006
 | 
				
			||||||
 | 
						// alpha: v1.30
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// Enables PortForward to be proxied with a websocket client
 | 
				
			||||||
 | 
						PortForwardWebsockets featuregate.Feature = "PortForwardWebsockets"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// owner: @jessfraz
 | 
						// owner: @jessfraz
 | 
				
			||||||
	// alpha: v1.12
 | 
						// alpha: v1.12
 | 
				
			||||||
	//
 | 
						//
 | 
				
			||||||
@@ -1101,6 +1108,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32
 | 
						PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						PortForwardWebsockets: {Default: false, PreRelease: featuregate.Alpha},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ProcMountType: {Default: false, PreRelease: featuregate.Alpha},
 | 
						ProcMountType: {Default: false, PreRelease: featuregate.Alpha},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	QOSReserved: {Default: false, PreRelease: featuregate.Alpha},
 | 
						QOSReserved: {Default: false, PreRelease: featuregate.Alpha},
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -242,7 +242,12 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
 | 
						handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder)
 | 
				
			||||||
 | 
						if utilfeature.DefaultFeatureGate.Enabled(features.PortForwardWebsockets) {
 | 
				
			||||||
 | 
							tunnelingHandler := translator.NewTunnelingHandler(handler)
 | 
				
			||||||
 | 
							handler = translator.NewTranslatingHandler(handler, tunnelingHandler, wsstream.IsWebSocketRequestWithTunnelingProtocol)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return handler, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {
 | 
					func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,7 @@ import (
 | 
				
			|||||||
	"golang.org/x/net/websocket"
 | 
						"golang.org/x/net/websocket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"k8s.io/apimachinery/pkg/util/httpstream"
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/portforward"
 | 
				
			||||||
	"k8s.io/apimachinery/pkg/util/remotecommand"
 | 
						"k8s.io/apimachinery/pkg/util/remotecommand"
 | 
				
			||||||
	"k8s.io/apimachinery/pkg/util/runtime"
 | 
						"k8s.io/apimachinery/pkg/util/runtime"
 | 
				
			||||||
	"k8s.io/klog/v2"
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
@@ -106,6 +107,23 @@ func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers
 | 
				
			||||||
 | 
					// identifying that it is requesting a websocket upgrade with a tunneling protocol;
 | 
				
			||||||
 | 
					// false otherwise.
 | 
				
			||||||
 | 
					func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool {
 | 
				
			||||||
 | 
						if !IsWebSocketRequest(req) {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
 | 
				
			||||||
 | 
						for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
 | 
				
			||||||
 | 
							if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
 | 
					// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
 | 
				
			||||||
// read and write deadlines are pushed every time a new message is received.
 | 
					// read and write deadlines are pushed every time a new message is received.
 | 
				
			||||||
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
 | 
					func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
 | 
				
			||||||
@@ -301,6 +319,12 @@ func protocolSupportsStreamClose(protocol string) bool {
 | 
				
			|||||||
	return protocol == remotecommand.StreamProtocolV5Name
 | 
						return protocol == remotecommand.StreamProtocolV5Name
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// protocolSupportsWebsocketTunneling returns true if the passed protocol
 | 
				
			||||||
 | 
					// is a tunneled Kubernetes spdy protocol; false otherwise.
 | 
				
			||||||
 | 
					func protocolSupportsWebsocketTunneling(protocol string) bool {
 | 
				
			||||||
 | 
						return strings.HasPrefix(protocol, portforward.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, portforward.KubernetesSuffix)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// handle implements a websocket handler.
 | 
					// handle implements a websocket handler.
 | 
				
			||||||
func (conn *Conn) handle(ws *websocket.Conn) {
 | 
					func (conn *Conn) handle(ws *websocket.Conn) {
 | 
				
			||||||
	conn.initialize(ws)
 | 
						conn.initialize(ws)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2016 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						PortForwardV1Name                    = "portforward.k8s.io"
 | 
				
			||||||
 | 
						WebsocketsSPDYTunnelingPrefix        = "SPDY/3.1+"
 | 
				
			||||||
 | 
						KubernetesSuffix                     = ".k8s.io"
 | 
				
			||||||
 | 
						WebsocketsSPDYTunnelingPortForwardV1 = WebsocketsSPDYTunnelingPrefix + PortForwardV1Name
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
@@ -36,6 +36,7 @@ import (
 | 
				
			|||||||
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 | 
						utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/mxk/go-flowrate/flowrate"
 | 
						"github.com/mxk/go-flowrate/flowrate"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"k8s.io/klog/v2"
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -336,6 +337,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
 | 
				
			|||||||
		clone.Host = h.Location.Host
 | 
							clone.Host = h.Location.Host
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	clone.URL = &location
 | 
						clone.URL = &location
 | 
				
			||||||
 | 
						klog.V(6).Infof("UpgradeAwareProxy: dialing for SPDY upgrade with headers: %v", clone.Header)
 | 
				
			||||||
	backendConn, err = h.DialForUpgrade(clone)
 | 
						backendConn, err = h.DialForUpgrade(clone)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		klog.V(6).Infof("Proxy connection error: %v", err)
 | 
							klog.V(6).Infof("Proxy connection error: %v", err)
 | 
				
			||||||
@@ -370,13 +372,13 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
 | 
				
			|||||||
	// hijacking should be the last step in the upgrade.
 | 
						// hijacking should be the last step in the upgrade.
 | 
				
			||||||
	requestHijacker, ok := w.(http.Hijacker)
 | 
						requestHijacker, ok := w.(http.Hijacker)
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		klog.V(6).Infof("Unable to hijack response writer: %T", w)
 | 
							klog.Errorf("Unable to hijack response writer: %T", w)
 | 
				
			||||||
		h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
 | 
							h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	requestHijackedConn, _, err := requestHijacker.Hijack()
 | 
						requestHijackedConn, _, err := requestHijacker.Hijack()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		klog.V(6).Infof("Unable to hijack response: %v", err)
 | 
							klog.Errorf("Unable to hijack response: %v", err)
 | 
				
			||||||
		h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
 | 
							h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,7 @@ require (
 | 
				
			|||||||
	github.com/google/go-cmp v0.6.0
 | 
						github.com/google/go-cmp v0.6.0
 | 
				
			||||||
	github.com/google/gofuzz v1.2.0
 | 
						github.com/google/gofuzz v1.2.0
 | 
				
			||||||
	github.com/google/uuid v1.3.0
 | 
						github.com/google/uuid v1.3.0
 | 
				
			||||||
 | 
						github.com/gorilla/websocket v1.5.0
 | 
				
			||||||
	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
 | 
						github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
 | 
				
			||||||
	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
 | 
						github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
 | 
				
			||||||
	github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
 | 
						github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
 | 
				
			||||||
@@ -77,7 +78,6 @@ require (
 | 
				
			|||||||
	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 | 
						github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 | 
				
			||||||
	github.com/golang/protobuf v1.5.3 // indirect
 | 
						github.com/golang/protobuf v1.5.3 // indirect
 | 
				
			||||||
	github.com/google/btree v1.0.1 // indirect
 | 
						github.com/google/btree v1.0.1 // indirect
 | 
				
			||||||
	github.com/gorilla/websocket v1.5.0 // indirect
 | 
					 | 
				
			||||||
	github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
 | 
						github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
 | 
				
			||||||
	github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
 | 
						github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
 | 
				
			||||||
	github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
 | 
						github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										461
									
								
								staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										461
									
								
								staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,461 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						gwebsocket "github.com/gorilla/websocket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream/spdy"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
 | 
				
			||||||
 | 
						utilnet "k8s.io/apimachinery/pkg/util/net"
 | 
				
			||||||
 | 
						constants "k8s.io/apimachinery/pkg/util/portforward"
 | 
				
			||||||
 | 
						"k8s.io/client-go/tools/portforward"
 | 
				
			||||||
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TunnelingHandler is a handler which tunnels SPDY through WebSockets.
 | 
				
			||||||
 | 
					type TunnelingHandler struct {
 | 
				
			||||||
 | 
						// Used to communicate between upstream SPDY and downstream tunnel.
 | 
				
			||||||
 | 
						upgradeHandler http.Handler
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewTunnelingHandler is used to create the tunnel between an upstream
 | 
				
			||||||
 | 
					// SPDY connection and a downstream tunneling connection through the stored
 | 
				
			||||||
 | 
					// UpgradeAwareProxy.
 | 
				
			||||||
 | 
					func NewTunnelingHandler(upgradeHandler http.Handler) *TunnelingHandler {
 | 
				
			||||||
 | 
						return &TunnelingHandler{upgradeHandler: upgradeHandler}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ServeHTTP uses the upgradeHandler to tunnel between a downstream tunneling
 | 
				
			||||||
 | 
					// connection and an upstream SPDY connection. The tunneling connection is
 | 
				
			||||||
 | 
					// a wrapped WebSockets connection which communicates SPDY framed data. In the
 | 
				
			||||||
 | 
					// case the upstream upgrade fails, we delegate communication to the passed
 | 
				
			||||||
 | 
					// in "w" ResponseWriter.
 | 
				
			||||||
 | 
					func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						klog.V(4).Infoln("TunnelingHandler ServeHTTP")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						spdyProtocols := spdyProtocolsFromWebsocketProtocols(req)
 | 
				
			||||||
 | 
						if len(spdyProtocols) == 0 {
 | 
				
			||||||
 | 
							http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						spdyRequest := createSPDYRequest(req, spdyProtocols...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The fields "w" and "conn" are mutually exclusive. Either a successful upgrade occurs
 | 
				
			||||||
 | 
						// and the "conn" is hijacked and used in the subsequent upgradeHandler, or
 | 
				
			||||||
 | 
						// the upgrade failed, and "w" is the delegate used for the non-upgrade response.
 | 
				
			||||||
 | 
						writer := &tunnelingResponseWriter{
 | 
				
			||||||
 | 
							// "w" is used in the non-upgrade error cases called in the upgradeHandler.
 | 
				
			||||||
 | 
							w: w,
 | 
				
			||||||
 | 
							// "conn" is returned in the successful upgrade case when hijacked in the upgradeHandler.
 | 
				
			||||||
 | 
							conn: &headerInterceptingConn{
 | 
				
			||||||
 | 
								initializableConn: &tunnelingWebsocketUpgraderConn{
 | 
				
			||||||
 | 
									w:   w,
 | 
				
			||||||
 | 
									req: req,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						klog.V(4).Infoln("Tunnel spdy through websockets using the UpgradeAwareProxy")
 | 
				
			||||||
 | 
						h.upgradeHandler.ServeHTTP(writer, spdyRequest)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// createSPDYRequest modifies the passed request to remove
 | 
				
			||||||
 | 
					// WebSockets headers and add SPDY upgrade information, including
 | 
				
			||||||
 | 
					// spdy protocols acceptable to the client.
 | 
				
			||||||
 | 
					func createSPDYRequest(req *http.Request, spdyProtocols ...string) *http.Request {
 | 
				
			||||||
 | 
						clone := utilnet.CloneRequest(req)
 | 
				
			||||||
 | 
						// Clean up the websocket headers from the http request.
 | 
				
			||||||
 | 
						clone.Header.Del(wsstream.WebSocketProtocolHeader)
 | 
				
			||||||
 | 
						clone.Header.Del("Sec-Websocket-Key")
 | 
				
			||||||
 | 
						clone.Header.Del("Sec-Websocket-Version")
 | 
				
			||||||
 | 
						clone.Header.Del(httpstream.HeaderUpgrade)
 | 
				
			||||||
 | 
						// Update the http request for an upstream SPDY upgrade.
 | 
				
			||||||
 | 
						clone.Method = "POST"
 | 
				
			||||||
 | 
						clone.Body = nil // Remove the request body which is unused.
 | 
				
			||||||
 | 
						clone.Header.Set(httpstream.HeaderUpgrade, spdy.HeaderSpdy31)
 | 
				
			||||||
 | 
						clone.Header.Del(httpstream.HeaderProtocolVersion)
 | 
				
			||||||
 | 
						for i := range spdyProtocols {
 | 
				
			||||||
 | 
							clone.Header.Add(httpstream.HeaderProtocolVersion, spdyProtocols[i])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return clone
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// spdyProtocolsFromWebsocketProtocols returns a list of spdy protocols by filtering
 | 
				
			||||||
 | 
					// to Kubernetes websocket subprotocols prefixed with "SPDY/3.1+", then removing the prefix
 | 
				
			||||||
 | 
					func spdyProtocolsFromWebsocketProtocols(req *http.Request) []string {
 | 
				
			||||||
 | 
						var spdyProtocols []string
 | 
				
			||||||
 | 
						for _, protocol := range gwebsocket.Subprotocols(req) {
 | 
				
			||||||
 | 
							if strings.HasPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, constants.KubernetesSuffix) {
 | 
				
			||||||
 | 
								spdyProtocols = append(spdyProtocols, strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return spdyProtocols
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ http.ResponseWriter = &tunnelingResponseWriter{}
 | 
				
			||||||
 | 
					var _ http.Hijacker = &tunnelingResponseWriter{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// tunnelingResponseWriter implements the http.ResponseWriter and http.Hijacker interfaces.
 | 
				
			||||||
 | 
					// Only non-upgrade responses can be written using WriteHeader() and Write().
 | 
				
			||||||
 | 
					// Once Write or WriteHeader is called, Hijack returns an error.
 | 
				
			||||||
 | 
					// Once Hijack is called, Write, WriteHeader, and Hijack return errors.
 | 
				
			||||||
 | 
					type tunnelingResponseWriter struct {
 | 
				
			||||||
 | 
						// w is used to delegate Header(), WriteHeader(), and Write() calls
 | 
				
			||||||
 | 
						w http.ResponseWriter
 | 
				
			||||||
 | 
						// conn is returned from Hijack()
 | 
				
			||||||
 | 
						conn net.Conn
 | 
				
			||||||
 | 
						// mu guards writes
 | 
				
			||||||
 | 
						mu sync.Mutex
 | 
				
			||||||
 | 
						// wrote tracks whether WriteHeader or Write has been called
 | 
				
			||||||
 | 
						written bool
 | 
				
			||||||
 | 
						// hijacked tracks whether Hijack has been called
 | 
				
			||||||
 | 
						hijacked bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Hijack returns a delegate "net.Conn".
 | 
				
			||||||
 | 
					// An error is returned if Write(), WriteHeader(), or Hijack() was previously called.
 | 
				
			||||||
 | 
					// The returned bufio.ReadWriter is always nil.
 | 
				
			||||||
 | 
					func (w *tunnelingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 | 
				
			||||||
 | 
						w.mu.Lock()
 | 
				
			||||||
 | 
						defer w.mu.Unlock()
 | 
				
			||||||
 | 
						if w.written {
 | 
				
			||||||
 | 
							klog.Errorf("Hijack called after write")
 | 
				
			||||||
 | 
							return nil, nil, errors.New("connection has already been written to")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if w.hijacked {
 | 
				
			||||||
 | 
							klog.Errorf("Hijack called after hijack")
 | 
				
			||||||
 | 
							return nil, nil, errors.New("connection has already been hijacked")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w.hijacked = true
 | 
				
			||||||
 | 
						klog.V(6).Infof("Hijack returning websocket tunneling net.Conn")
 | 
				
			||||||
 | 
						return w.conn, nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Header is delegated to the stored "http.ResponseWriter".
 | 
				
			||||||
 | 
					func (w *tunnelingResponseWriter) Header() http.Header {
 | 
				
			||||||
 | 
						return w.w.Header()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Write is delegated to the stored "http.ResponseWriter".
 | 
				
			||||||
 | 
					func (w *tunnelingResponseWriter) Write(p []byte) (int, error) {
 | 
				
			||||||
 | 
						w.mu.Lock()
 | 
				
			||||||
 | 
						defer w.mu.Unlock()
 | 
				
			||||||
 | 
						if w.hijacked {
 | 
				
			||||||
 | 
							klog.Errorf("Write called after hijack")
 | 
				
			||||||
 | 
							return 0, http.ErrHijacked
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w.written = true
 | 
				
			||||||
 | 
						return w.w.Write(p)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WriteHeader is delegated to the stored "http.ResponseWriter".
 | 
				
			||||||
 | 
					func (w *tunnelingResponseWriter) WriteHeader(statusCode int) {
 | 
				
			||||||
 | 
						w.mu.Lock()
 | 
				
			||||||
 | 
						defer w.mu.Unlock()
 | 
				
			||||||
 | 
						if w.written {
 | 
				
			||||||
 | 
							klog.Errorf("WriteHeader called after write")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if w.hijacked {
 | 
				
			||||||
 | 
							klog.Errorf("WriteHeader called after hijack")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w.written = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if statusCode == http.StatusSwitchingProtocols {
 | 
				
			||||||
 | 
							// 101 upgrade responses must come via the hijacked connection, not WriteHeader
 | 
				
			||||||
 | 
							klog.Errorf("WriteHeader called with 101 upgrade")
 | 
				
			||||||
 | 
							http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// pass through non-upgrade responses we don't need to translate
 | 
				
			||||||
 | 
						w.w.WriteHeader(statusCode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// headerInterceptingConn wraps the tunneling "net.Conn" to drain the
 | 
				
			||||||
 | 
					// HTTP response status/headers from the upstream SPDY connection, then use
 | 
				
			||||||
 | 
					// that to decide how to initialize the delegate connection for writes.
 | 
				
			||||||
 | 
					type headerInterceptingConn struct {
 | 
				
			||||||
 | 
						// initializableConn is delegated to for all net.Conn methods.
 | 
				
			||||||
 | 
						// initializableConn.Write() is not called until response headers have been read
 | 
				
			||||||
 | 
						// and initializableConn#InitializeWrite() has been called with the result.
 | 
				
			||||||
 | 
						initializableConn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lock          sync.Mutex
 | 
				
			||||||
 | 
						headerBuffer  []byte
 | 
				
			||||||
 | 
						initialized   bool
 | 
				
			||||||
 | 
						initializeErr error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// initializableConn is a connection that will be initialized before any calls to Write are made
 | 
				
			||||||
 | 
					type initializableConn interface {
 | 
				
			||||||
 | 
						net.Conn
 | 
				
			||||||
 | 
						// InitializeWrite is called when the backend response headers have been read.
 | 
				
			||||||
 | 
						// backendResponse contains the parsed headers.
 | 
				
			||||||
 | 
						// backendResponseBytes are the raw bytes the headers were parsed from.
 | 
				
			||||||
 | 
						InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const maxHeaderBytes = 1 << 20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
 | 
				
			||||||
 | 
					var lfCRLF = []byte("\n\r\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// token for header / body separation without \r (which go tolerates)
 | 
				
			||||||
 | 
					var lfLF = []byte("\n\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Write intercepts to initially swallow the HTTP response, then
 | 
				
			||||||
 | 
					// delegate to the tunneling "net.Conn" once the response has been
 | 
				
			||||||
 | 
					// seen and processed.
 | 
				
			||||||
 | 
					func (h *headerInterceptingConn) Write(b []byte) (int, error) {
 | 
				
			||||||
 | 
						h.lock.Lock()
 | 
				
			||||||
 | 
						defer h.lock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if h.initializeErr != nil {
 | 
				
			||||||
 | 
							return 0, h.initializeErr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if h.initialized {
 | 
				
			||||||
 | 
							return h.initializableConn.Write(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Guard against excessive buffering
 | 
				
			||||||
 | 
						if len(h.headerBuffer)+len(b) > maxHeaderBytes {
 | 
				
			||||||
 | 
							return 0, fmt.Errorf("header size limit exceeded")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Accumulate into headerBuffer
 | 
				
			||||||
 | 
						h.headerBuffer = append(h.headerBuffer, b...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Attempt to parse http response headers
 | 
				
			||||||
 | 
						var headerBytes, bodyBytes []byte
 | 
				
			||||||
 | 
						if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
 | 
				
			||||||
 | 
							// headers terminated with \n\r\n
 | 
				
			||||||
 | 
							headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
 | 
				
			||||||
 | 
							bodyBytes = h.headerBuffer[i+len(lfCRLF):]
 | 
				
			||||||
 | 
						} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
 | 
				
			||||||
 | 
							// headers terminated with \n\n (which go tolerates)
 | 
				
			||||||
 | 
							headerBytes = h.headerBuffer[0 : i+len(lfLF)]
 | 
				
			||||||
 | 
							bodyBytes = h.headerBuffer[i+len(lfLF):]
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// don't yet have a complete set of headers yet
 | 
				
			||||||
 | 
							return len(b), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							klog.Errorf("invalid headers: %v", err)
 | 
				
			||||||
 | 
							h.initializeErr = err
 | 
				
			||||||
 | 
							return len(b), err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Body.Close() //nolint:errcheck
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						h.headerBuffer = nil
 | 
				
			||||||
 | 
						h.initialized = true
 | 
				
			||||||
 | 
						h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
 | 
				
			||||||
 | 
						if h.initializeErr != nil {
 | 
				
			||||||
 | 
							return len(b), h.initializeErr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(bodyBytes) > 0 {
 | 
				
			||||||
 | 
							_, err = h.initializableConn.Write(bodyBytes)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return len(b), err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type tunnelingWebsocketUpgraderConn struct {
 | 
				
			||||||
 | 
						// req is the websocket request, used for upgrading
 | 
				
			||||||
 | 
						req *http.Request
 | 
				
			||||||
 | 
						// w is the websocket writer, used for upgrading and writing error responses
 | 
				
			||||||
 | 
						w http.ResponseWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// lock guards conn and err
 | 
				
			||||||
 | 
						lock sync.RWMutex
 | 
				
			||||||
 | 
						// if conn is non-nil, InitializeWrite succeeded
 | 
				
			||||||
 | 
						conn net.Conn
 | 
				
			||||||
 | 
						// if err is non-nil, InitializeWrite failed or Close was called before InitializeWrite
 | 
				
			||||||
 | 
						err error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
 | 
				
			||||||
 | 
						// make sure we close a connection we open in error cases
 | 
				
			||||||
 | 
						var conn net.Conn
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							if err != nil && conn != nil {
 | 
				
			||||||
 | 
								conn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						u.lock.Lock()
 | 
				
			||||||
 | 
						defer u.lock.Unlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("InitializeWrite already called")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if u.err != nil {
 | 
				
			||||||
 | 
							return u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if backendResponse.StatusCode == http.StatusSwitchingProtocols {
 | 
				
			||||||
 | 
							connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection))
 | 
				
			||||||
 | 
							upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade))
 | 
				
			||||||
 | 
							if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) {
 | 
				
			||||||
 | 
								klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header)
 | 
				
			||||||
 | 
								u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response")
 | 
				
			||||||
 | 
								http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
 | 
				
			||||||
 | 
								return u.err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Translate the server's chosen SPDY protocol into the tunneled websocket protocol for the handshake
 | 
				
			||||||
 | 
							var serverWebsocketProtocols []string
 | 
				
			||||||
 | 
							if backendSPDYProtocol := strings.TrimSpace(backendResponse.Header.Get(httpstream.HeaderProtocolVersion)); backendSPDYProtocol != "" {
 | 
				
			||||||
 | 
								serverWebsocketProtocols = []string{constants.WebsocketsSPDYTunnelingPrefix + backendSPDYProtocol}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								serverWebsocketProtocols = []string{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Try to upgrade the websocket connection.
 | 
				
			||||||
 | 
							// Beyond this point, we don't need to write errors to the response.
 | 
				
			||||||
 | 
							var upgrader = gwebsocket.Upgrader{
 | 
				
			||||||
 | 
								CheckOrigin:  func(r *http.Request) bool { return true },
 | 
				
			||||||
 | 
								Subprotocols: serverWebsocketProtocols,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							conn, err := upgrader.Upgrade(u.w, u.req, nil)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								klog.Errorf("error upgrading websocket connection: %v", err)
 | 
				
			||||||
 | 
								u.err = err
 | 
				
			||||||
 | 
								return u.err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol())
 | 
				
			||||||
 | 
							u.conn = portforward.NewTunnelingConnection("server", conn)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// anything other than an upgrade should pass through the backend response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// try to hijack
 | 
				
			||||||
 | 
						conn, _, err = u.w.(http.Hijacker).Hijack()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							klog.Errorf("Unable to hijack response: %v", err)
 | 
				
			||||||
 | 
							u.err = err
 | 
				
			||||||
 | 
							return u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// replay the backend response bytes to the hijacked conn
 | 
				
			||||||
 | 
						conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
 | 
				
			||||||
 | 
						_, err = conn.Write(backendResponseBytes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							u.err = err
 | 
				
			||||||
 | 
							return u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						u.conn = conn
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) Read(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.Read(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if u.err != nil {
 | 
				
			||||||
 | 
							return 0, u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// return empty read without blocking until we are initialized
 | 
				
			||||||
 | 
						return 0, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) Write(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.Write(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if u.err != nil {
 | 
				
			||||||
 | 
							return 0, u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return 0, fmt.Errorf("Write called before Initialize")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) Close() error {
 | 
				
			||||||
 | 
						u.lock.Lock()
 | 
				
			||||||
 | 
						defer u.lock.Unlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.Close()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if u.err != nil {
 | 
				
			||||||
 | 
							return u.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// record that we closed so we don't write again or try to initialize
 | 
				
			||||||
 | 
						u.err = fmt.Errorf("connection closed")
 | 
				
			||||||
 | 
						// write a response
 | 
				
			||||||
 | 
						http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) LocalAddr() net.Addr {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.LocalAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return noopAddr{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) RemoteAddr() net.Addr {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.RemoteAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return noopAddr{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) SetDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.SetDeadline(t)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) SetReadDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.SetReadDeadline(t)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (u *tunnelingWebsocketUpgraderConn) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						u.lock.RLock()
 | 
				
			||||||
 | 
						defer u.lock.RUnlock()
 | 
				
			||||||
 | 
						if u.conn != nil {
 | 
				
			||||||
 | 
							return u.conn.SetWriteDeadline(t)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type noopAddr struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (n noopAddr) Network() string { return "" }
 | 
				
			||||||
 | 
					func (n noopAddr) String() string  { return "" }
 | 
				
			||||||
							
								
								
									
										364
									
								
								staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										364
									
								
								staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,364 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/runtime"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream/spdy"
 | 
				
			||||||
 | 
						constants "k8s.io/apimachinery/pkg/util/portforward"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/proxy"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/wait"
 | 
				
			||||||
 | 
						"k8s.io/apiserver/pkg/registry/rest"
 | 
				
			||||||
 | 
						restconfig "k8s.io/client-go/rest"
 | 
				
			||||||
 | 
						"k8s.io/client-go/tools/portforward"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
 | 
				
			||||||
 | 
						// Create fake upstream SPDY server, with channel receiving SPDY streams.
 | 
				
			||||||
 | 
						streamChan := make(chan httpstream.Stream)
 | 
				
			||||||
 | 
						defer close(streamChan)
 | 
				
			||||||
 | 
						stopServerChan := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(stopServerChan)
 | 
				
			||||||
 | 
						// Create fake upstream SPDY server.
 | 
				
			||||||
 | 
						spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							upgrader := spdy.NewResponseUpgrader()
 | 
				
			||||||
 | 
							conn := upgrader.UpgradeResponse(w, req, justQueueStream(streamChan))
 | 
				
			||||||
 | 
							require.NotNil(t, conn)
 | 
				
			||||||
 | 
							defer conn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							<-stopServerChan
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer spdyServer.Close()
 | 
				
			||||||
 | 
						// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
 | 
				
			||||||
 | 
						// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
 | 
				
			||||||
 | 
						url, err := url.Parse(spdyServer.URL)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						transport, err := fakeTransport()
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
 | 
				
			||||||
 | 
						tunnelingHandler := NewTunnelingHandler(upgradeHandler)
 | 
				
			||||||
 | 
						tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							tunnelingHandler.ServeHTTP(w, req)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer tunnelingServer.Close()
 | 
				
			||||||
 | 
						// Create SPDY client connection containing a TunnelingConnection by upgrading
 | 
				
			||||||
 | 
						// a request to TunnelingHandler using new portforward version 2.
 | 
				
			||||||
 | 
						tunnelingURL, err := url.Parse(tunnelingServer.URL)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						assert.Equal(t, constants.PortForwardV1Name, protocol)
 | 
				
			||||||
 | 
						defer spdyClient.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						// Create a SPDY client stream, which will queue a SPDY server stream
 | 
				
			||||||
 | 
						// on the stream creation channel. Send random data on the client stream
 | 
				
			||||||
 | 
						// reading off the SPDY server stream, and validating it was tunneled.
 | 
				
			||||||
 | 
						randomSize := 1024 * 1024
 | 
				
			||||||
 | 
						randomData := make([]byte, randomSize)
 | 
				
			||||||
 | 
						_, err = rand.Read(randomData)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						var actual []byte
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							clientStream, err := spdyClient.CreateStream(http.Header{})
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							_, err = io.Copy(clientStream, bytes.NewReader(randomData))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							clientStream.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case serverStream := <-streamChan:
 | 
				
			||||||
 | 
							actual, err = io.ReadAll(serverStream)
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer serverStream.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						case <-time.After(wait.ForeverTestTimeout):
 | 
				
			||||||
 | 
							t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert.Equal(t, randomData, actual, "error validating tunneled random data")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var expectedContentLengthHeaders = http.Header{
 | 
				
			||||||
 | 
						"Content-Length": []string{"25"},
 | 
				
			||||||
 | 
						"Date":           []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
 | 
				
			||||||
 | 
						"Split-Point":    []string{"split"},
 | 
				
			||||||
 | 
						"X-App-Protocol": []string{"portforward.k8s.io"},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
 | 
				
			||||||
 | 
						"Content-Length: 25\r\n" +
 | 
				
			||||||
 | 
						"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
 | 
				
			||||||
 | 
						"Split-Point: split\r\n" +
 | 
				
			||||||
 | 
						"X-App-Protocol: portforward.k8s.io\r\n" +
 | 
				
			||||||
 | 
						"\r\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const contentLengthBody = "0123456789split0123456789"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var expectedResponseHeaders = http.Header{
 | 
				
			||||||
 | 
						"Date":           []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
 | 
				
			||||||
 | 
						"Split-Point":    []string{"split"},
 | 
				
			||||||
 | 
						"X-App-Protocol": []string{"portforward.k8s.io"},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
 | 
				
			||||||
 | 
						"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
 | 
				
			||||||
 | 
						"Split-Point: split\r\n" +
 | 
				
			||||||
 | 
						"X-App-Protocol: portforward.k8s.io\r\n" +
 | 
				
			||||||
 | 
						"\r\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const responseBody = "This is extra split data.\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var responseHeadersAndBody = responseHeaders + responseBody
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
 | 
				
			||||||
 | 
						"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
 | 
				
			||||||
 | 
						"Split-Point: split\r\n" +
 | 
				
			||||||
 | 
						"X-App-Protocol: portforward.k8s.io\r\n" +
 | 
				
			||||||
 | 
						"\r\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
 | 
				
			||||||
 | 
						// Basic http response is intercepted correctly; no extra data sent to net.Conn.
 | 
				
			||||||
 | 
						t.Run("simple-no-body", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(responseHeaders))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized, "successfully parsed http response headers")
 | 
				
			||||||
 | 
							assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
 | 
				
			||||||
 | 
							assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, "", string(testConnConstructor.mockConn.written))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Extra data after response headers should be sent to net.Conn.
 | 
				
			||||||
 | 
						t.Run("simple-single-write", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(responseHeadersAndBody))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized)
 | 
				
			||||||
 | 
							assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Partially written headers are buffered and decoded
 | 
				
			||||||
 | 
						t.Run("simple-byte-by-byte", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							// write one byte at a time
 | 
				
			||||||
 | 
							for _, b := range []byte(responseHeadersAndBody) {
 | 
				
			||||||
 | 
								_, err := hic.Write([]byte{b})
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized)
 | 
				
			||||||
 | 
							assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Writes spanning the header/body breakpoint are buffered and decoded
 | 
				
			||||||
 | 
						t.Run("simple-span-headerbody", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							// write one chunk at a time
 | 
				
			||||||
 | 
							for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
 | 
				
			||||||
 | 
								if i > 0 {
 | 
				
			||||||
 | 
									n, err := hic.Write([]byte("split"))
 | 
				
			||||||
 | 
									require.Equal(t, n, len("split"))
 | 
				
			||||||
 | 
									require.NoError(t, err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								n, err := hic.Write([]byte(chunk))
 | 
				
			||||||
 | 
								require.Equal(t, n, len(chunk))
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized)
 | 
				
			||||||
 | 
							assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
 | 
				
			||||||
 | 
						t.Run("simple-tolerate-lf", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized)
 | 
				
			||||||
 | 
							assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Content-Length handling
 | 
				
			||||||
 | 
						t.Run("content-length-body", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(contentLengthHeadersAndBody))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized, "successfully parsed http response headers")
 | 
				
			||||||
 | 
							assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Content-Length separately written headers and body
 | 
				
			||||||
 | 
						t.Run("content-length-headers-body", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(contentLengthHeaders))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							_, err = hic.Write([]byte(contentLengthBody))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized, "successfully parsed http response headers")
 | 
				
			||||||
 | 
							assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Content-Length accumulating byte-by-byte
 | 
				
			||||||
 | 
						t.Run("content-length-byte-by-byte", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							for _, b := range []byte(contentLengthHeadersAndBody) {
 | 
				
			||||||
 | 
								_, err := hic.Write([]byte{b})
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized, "successfully parsed http response headers")
 | 
				
			||||||
 | 
							assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Content-Length writes spanning headers / body
 | 
				
			||||||
 | 
						t.Run("content-length-span-headerbody", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							// write one chunk at a time
 | 
				
			||||||
 | 
							for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
 | 
				
			||||||
 | 
								if i > 0 {
 | 
				
			||||||
 | 
									n, err := hic.Write([]byte("split"))
 | 
				
			||||||
 | 
									require.Equal(t, n, len("split"))
 | 
				
			||||||
 | 
									require.NoError(t, err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								n, err := hic.Write([]byte(chunk))
 | 
				
			||||||
 | 
								require.Equal(t, n, len(chunk))
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							assert.True(t, hic.initialized, "successfully parsed http response headers")
 | 
				
			||||||
 | 
							assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
 | 
				
			||||||
 | 
							assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
 | 
				
			||||||
 | 
							assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Invalid response returns error.
 | 
				
			||||||
 | 
						t.Run("invalid-single-write", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							_, err := hic.Write([]byte(invalidResponseData))
 | 
				
			||||||
 | 
							assert.Error(t, err, "expected error from invalid http response")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Invalid response written byte by byte returns error.
 | 
				
			||||||
 | 
						t.Run("invalid-byte-by-byte", func(t *testing.T) {
 | 
				
			||||||
 | 
							testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
 | 
				
			||||||
 | 
							hic := &headerInterceptingConn{initializableConn: testConnConstructor}
 | 
				
			||||||
 | 
							var err error
 | 
				
			||||||
 | 
							for _, b := range []byte(invalidResponseData) {
 | 
				
			||||||
 | 
								_, err = hic.Write([]byte{b})
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							assert.Error(t, err, "expected error from invalid http response")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type mockConnInitializer struct {
 | 
				
			||||||
 | 
						resp                *http.Response
 | 
				
			||||||
 | 
						initializeWriteConn *mockConn
 | 
				
			||||||
 | 
						*mockConn
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
 | 
				
			||||||
 | 
						m.resp = backendResponse
 | 
				
			||||||
 | 
						_, err := m.initializeWriteConn.Write(backendResponseBytes)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// mockConn implements "net.Conn" interface.
 | 
				
			||||||
 | 
					var _ net.Conn = &mockConn{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type mockConn struct {
 | 
				
			||||||
 | 
						written []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mc *mockConn) Write(p []byte) (int, error) {
 | 
				
			||||||
 | 
						mc.written = append(mc.written, p...)
 | 
				
			||||||
 | 
						return len(p), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (mc *mockConn) Read(p []byte) (int, error)         { return 0, nil }
 | 
				
			||||||
 | 
					func (mc *mockConn) Close() error                       { return nil }
 | 
				
			||||||
 | 
					func (mc *mockConn) LocalAddr() net.Addr                { return &net.TCPAddr{} }
 | 
				
			||||||
 | 
					func (mc *mockConn) RemoteAddr() net.Addr               { return &net.TCPAddr{} }
 | 
				
			||||||
 | 
					func (mc *mockConn) SetDeadline(t time.Time) error      { return nil }
 | 
				
			||||||
 | 
					func (mc *mockConn) SetReadDeadline(t time.Time) error  { return nil }
 | 
				
			||||||
 | 
					func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// fakeResponder implements "rest.Responder" interface.
 | 
				
			||||||
 | 
					var _ rest.Responder = &fakeResponder{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type fakeResponder struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (fr *fakeResponder) Object(statusCode int, obj runtime.Object) {}
 | 
				
			||||||
 | 
					func (fr *fakeResponder) Error(err error)                           {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// justQueueStream skips the usual stream validation before
 | 
				
			||||||
 | 
					// queueing the stream on the stream channel.
 | 
				
			||||||
 | 
					func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
 | 
				
			||||||
 | 
						return func(stream httpstream.Stream, replySent <-chan struct{}) error {
 | 
				
			||||||
 | 
							streams <- stream
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,57 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ httpstream.Dialer = &fallbackDialer{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// fallbackDialer encapsulates a primary and secondary dialer, including
 | 
				
			||||||
 | 
					// the boolean function to determine if the primary dialer failed. Implements
 | 
				
			||||||
 | 
					// the httpstream.Dialer interface.
 | 
				
			||||||
 | 
					type fallbackDialer struct {
 | 
				
			||||||
 | 
						primary        httpstream.Dialer
 | 
				
			||||||
 | 
						secondary      httpstream.Dialer
 | 
				
			||||||
 | 
						shouldFallback func(error) bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
 | 
				
			||||||
 | 
					// as well as the boolean function to determine if the primary dialer failed.
 | 
				
			||||||
 | 
					func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
 | 
				
			||||||
 | 
						return &fallbackDialer{
 | 
				
			||||||
 | 
							primary:        primary,
 | 
				
			||||||
 | 
							secondary:      secondary,
 | 
				
			||||||
 | 
							shouldFallback: shouldFallback,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
 | 
				
			||||||
 | 
					// It takes the protocol version strings to request, returning an the upgraded
 | 
				
			||||||
 | 
					// httstream.Connection and the negotiated protocol version accepted. If the initial
 | 
				
			||||||
 | 
					// primary dialer fails, this function attempts the secondary dialer. Returns an error
 | 
				
			||||||
 | 
					// if one occurs.
 | 
				
			||||||
 | 
					func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
 | 
				
			||||||
 | 
						conn, version, err := f.primary.Dial(protocols...)
 | 
				
			||||||
 | 
						if err != nil && f.shouldFallback(err) {
 | 
				
			||||||
 | 
							klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
 | 
				
			||||||
 | 
							return f.secondary.Dial(protocols...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return conn, version, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,60 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFallbackDialer(t *testing.T) {
 | 
				
			||||||
 | 
						primaryProtocol := "primary.fake.protocol"
 | 
				
			||||||
 | 
						secondaryProtocol := "secondary.fake.protocol"
 | 
				
			||||||
 | 
						protocols := []string{primaryProtocol, secondaryProtocol}
 | 
				
			||||||
 | 
						// If primary dialer error is nil, then no fallback and primary negotiated protocol returned.
 | 
				
			||||||
 | 
						primary := &fakeDialer{dialed: false, negotiatedProtocol: primaryProtocol}
 | 
				
			||||||
 | 
						secondary := &fakeDialer{dialed: false, negotiatedProtocol: secondaryProtocol}
 | 
				
			||||||
 | 
						fallbackDialer := NewFallbackDialer(primary, secondary, notCalled)
 | 
				
			||||||
 | 
						_, negotiated, err := fallbackDialer.Dial(protocols...)
 | 
				
			||||||
 | 
						assert.True(t, primary.dialed, "no fallback; primary should have dialed")
 | 
				
			||||||
 | 
						assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
 | 
				
			||||||
 | 
						assert.Equal(t, primaryProtocol, negotiated, "primary negotiated protocol returned")
 | 
				
			||||||
 | 
						assert.Nil(t, err, "error from primary dialer should be nil")
 | 
				
			||||||
 | 
						// If primary dialer error is upgrade error, then fallback returning secondary dial response.
 | 
				
			||||||
 | 
						primary = &fakeDialer{dialed: false, negotiatedProtocol: primaryProtocol, err: &httpstream.UpgradeFailureError{}}
 | 
				
			||||||
 | 
						secondary = &fakeDialer{dialed: false, negotiatedProtocol: secondaryProtocol}
 | 
				
			||||||
 | 
						fallbackDialer = NewFallbackDialer(primary, secondary, httpstream.IsUpgradeFailure)
 | 
				
			||||||
 | 
						_, negotiated, err = fallbackDialer.Dial(protocols...)
 | 
				
			||||||
 | 
						assert.True(t, primary.dialed, "fallback; primary should have dialed")
 | 
				
			||||||
 | 
						assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
 | 
				
			||||||
 | 
						assert.Equal(t, secondaryProtocol, negotiated, "negotiated protocol is from secondary dialer")
 | 
				
			||||||
 | 
						assert.Nil(t, err, "error from secondary dialer should be nil")
 | 
				
			||||||
 | 
						// If primary dialer returns non-upgrade error, then primary error is returned.
 | 
				
			||||||
 | 
						nonUpgradeErr := fmt.Errorf("This is a non-upgrade error")
 | 
				
			||||||
 | 
						primary = &fakeDialer{dialed: false, err: nonUpgradeErr}
 | 
				
			||||||
 | 
						secondary = &fakeDialer{dialed: false}
 | 
				
			||||||
 | 
						fallbackDialer = NewFallbackDialer(primary, secondary, httpstream.IsUpgradeFailure)
 | 
				
			||||||
 | 
						_, _, err = fallbackDialer.Dial(protocols...)
 | 
				
			||||||
 | 
						assert.True(t, primary.dialed, "no fallback; primary should have dialed")
 | 
				
			||||||
 | 
						assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
 | 
				
			||||||
 | 
						assert.Equal(t, nonUpgradeErr, err, "error is from primary dialer")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func notCalled(err error) bool { return false }
 | 
				
			||||||
@@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
 | 
				
			|||||||
	defer pf.Close()
 | 
						defer pf.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
 | 
						var protocol string
 | 
				
			||||||
 | 
						pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("error upgrading connection: %s", err)
 | 
							return fmt.Errorf("error upgrading connection: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer pf.streamConn.Close()
 | 
						defer pf.streamConn.Close()
 | 
				
			||||||
 | 
						if protocol != PortForwardProtocolV1Name {
 | 
				
			||||||
 | 
							return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return pf.forward()
 | 
						return pf.forward()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
 | 
					func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
 | 
				
			||||||
	dialer := &fakeDialer{
 | 
						dialer := &fakeDialer{
 | 
				
			||||||
		conn: newFakeConnection(),
 | 
							conn:               newFakeConnection(),
 | 
				
			||||||
 | 
							negotiatedProtocol: PortForwardProtocolV1Name,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stopChan := make(chan struct{})
 | 
						stopChan := make(chan struct{})
 | 
				
			||||||
@@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
 | 
					func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
 | 
				
			||||||
	dialer := &fakeDialer{
 | 
						dialer := &fakeDialer{
 | 
				
			||||||
		conn: newFakeConnection(),
 | 
							conn:               newFakeConnection(),
 | 
				
			||||||
 | 
							negotiatedProtocol: PortForwardProtocolV1Name,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stopChan := make(chan struct{})
 | 
						stopChan := make(chan struct{})
 | 
				
			||||||
@@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
 | 
					func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
 | 
				
			||||||
	dialer := &fakeDialer{
 | 
						dialer := &fakeDialer{
 | 
				
			||||||
		conn: newFakeConnection(),
 | 
							conn:               newFakeConnection(),
 | 
				
			||||||
 | 
							negotiatedProtocol: PortForwardProtocolV1Name,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stopChan := make(chan struct{})
 | 
						stopChan := make(chan struct{})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,158 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2023 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						gwebsocket "github.com/gorilla/websocket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ net.Conn = &TunnelingConnection{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
 | 
				
			||||||
 | 
					// a websocket connection that tunnels SPDY.
 | 
				
			||||||
 | 
					type TunnelingConnection struct {
 | 
				
			||||||
 | 
						name              string
 | 
				
			||||||
 | 
						conn              *gwebsocket.Conn
 | 
				
			||||||
 | 
						inProgressMessage io.Reader
 | 
				
			||||||
 | 
						closeOnce         sync.Once
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewTunnelingConnection wraps the passed gorilla/websockets connection
 | 
				
			||||||
 | 
					// with the TunnelingConnection struct (implementing net.Conn).
 | 
				
			||||||
 | 
					func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
 | 
				
			||||||
 | 
						return &TunnelingConnection{
 | 
				
			||||||
 | 
							name: name,
 | 
				
			||||||
 | 
							conn: conn,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Read implements "io.Reader" interface, reading from the stored connection
 | 
				
			||||||
 | 
					// into the passed buffer "p". Returns the number of bytes read and an error.
 | 
				
			||||||
 | 
					// Can keep track of the "inProgress" messsage from the tunneled connection.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) Read(p []byte) (int, error) {
 | 
				
			||||||
 | 
						klog.V(7).Infof("%s: tunneling connection read...", c.name)
 | 
				
			||||||
 | 
						defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							if c.inProgressMessage == nil {
 | 
				
			||||||
 | 
								klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
 | 
				
			||||||
 | 
								messageType, nextReader, err := c.conn.NextReader()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									closeError := &gwebsocket.CloseError{}
 | 
				
			||||||
 | 
									if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
 | 
				
			||||||
 | 
										return 0, io.EOF
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
 | 
				
			||||||
 | 
									return 0, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if messageType != gwebsocket.BinaryMessage {
 | 
				
			||||||
 | 
									return 0, fmt.Errorf("invalid message type received")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								c.inProgressMessage = nextReader
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
 | 
				
			||||||
 | 
							i, err := c.inProgressMessage.Read(p)
 | 
				
			||||||
 | 
							if i == 0 && err == io.EOF {
 | 
				
			||||||
 | 
								c.inProgressMessage = nil
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
 | 
				
			||||||
 | 
								return i, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Write implements "io.Writer" interface, copying the data in the passed
 | 
				
			||||||
 | 
					// byte array "p" into the stored tunneled connection. Returns the number
 | 
				
			||||||
 | 
					// of bytes written and an error.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
 | 
				
			||||||
 | 
						klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
 | 
				
			||||||
 | 
						defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
 | 
				
			||||||
 | 
						w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							// close, which flushes the message
 | 
				
			||||||
 | 
							closeErr := w.Close()
 | 
				
			||||||
 | 
							if closeErr != nil && err == nil {
 | 
				
			||||||
 | 
								// if closing/flushing errored and we weren't already returning an error, return the close error
 | 
				
			||||||
 | 
								err = closeErr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n, err = w.Write(p)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close implements "io.Closer" interface, signaling the other tunneled connection
 | 
				
			||||||
 | 
					// endpoint, and closing the tunneled connection only once.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) Close() error {
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						c.closeOnce.Do(func() {
 | 
				
			||||||
 | 
							klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
 | 
				
			||||||
 | 
							// Signal other endpoint that websocket connection is closing; ignore error.
 | 
				
			||||||
 | 
							normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
 | 
				
			||||||
 | 
							writeControlErr := c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second))
 | 
				
			||||||
 | 
							closeErr := c.conn.Close()
 | 
				
			||||||
 | 
							if closeErr != nil {
 | 
				
			||||||
 | 
								err = closeErr
 | 
				
			||||||
 | 
							} else if writeControlErr != nil {
 | 
				
			||||||
 | 
								err = writeControlErr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LocalAddr implements part of the "net.Conn" interface, returning the local
 | 
				
			||||||
 | 
					// endpoint network address of the tunneled connection.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) LocalAddr() net.Addr {
 | 
				
			||||||
 | 
						return c.conn.LocalAddr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LocalAddr implements part of the "net.Conn" interface, returning the remote
 | 
				
			||||||
 | 
					// endpoint network address of the tunneled connection.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) RemoteAddr() net.Addr {
 | 
				
			||||||
 | 
						return c.conn.RemoteAddr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetDeadline sets the *absolute* time in the future for both
 | 
				
			||||||
 | 
					// read and write deadlines. Returns an error if one occurs.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) SetDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						rerr := c.SetReadDeadline(t)
 | 
				
			||||||
 | 
						werr := c.SetWriteDeadline(t)
 | 
				
			||||||
 | 
						return errors.Join(rerr, werr)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetDeadline sets the *absolute* time in the future for the
 | 
				
			||||||
 | 
					// read deadlines. Returns an error if one occurs.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						return c.conn.SetReadDeadline(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetDeadline sets the *absolute* time in the future for the
 | 
				
			||||||
 | 
					// write deadlines. Returns an error if one occurs.
 | 
				
			||||||
 | 
					func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						return c.conn.SetWriteDeadline(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,190 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						gwebsocket "github.com/gorilla/websocket"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream/spdy"
 | 
				
			||||||
 | 
						constants "k8s.io/apimachinery/pkg/util/portforward"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/wait"
 | 
				
			||||||
 | 
						"k8s.io/client-go/rest"
 | 
				
			||||||
 | 
						"k8s.io/client-go/transport/websocket"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
 | 
				
			||||||
 | 
						// Stream channel that will receive streams created on upstream SPDY server.
 | 
				
			||||||
 | 
						streamChan := make(chan httpstream.Stream)
 | 
				
			||||||
 | 
						defer close(streamChan)
 | 
				
			||||||
 | 
						stopServerChan := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(stopServerChan)
 | 
				
			||||||
 | 
						// Create tunneling connection server endpoint with fake upstream SPDY server.
 | 
				
			||||||
 | 
						tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							var upgrader = gwebsocket.Upgrader{
 | 
				
			||||||
 | 
								CheckOrigin:  func(r *http.Request) bool { return true },
 | 
				
			||||||
 | 
								Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							conn, err := upgrader.Upgrade(w, req, nil)
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer conn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
 | 
				
			||||||
 | 
							tunnelingConn := NewTunnelingConnection("server", conn)
 | 
				
			||||||
 | 
							spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer spdyConn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							<-stopServerChan
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer tunnelingServer.Close()
 | 
				
			||||||
 | 
						// Dial the client tunneling connection to the tunneling server.
 | 
				
			||||||
 | 
						url, err := url.Parse(tunnelingServer.URL)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						assert.Equal(t, constants.PortForwardV1Name, protocol)
 | 
				
			||||||
 | 
						defer spdyClient.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						// Create a SPDY client stream, which will queue a SPDY server stream
 | 
				
			||||||
 | 
						// on the stream creation channel. Send data on the client stream
 | 
				
			||||||
 | 
						// reading off the SPDY server stream, and validating it was tunneled.
 | 
				
			||||||
 | 
						expected := "This is a test tunneling SPDY data through websockets."
 | 
				
			||||||
 | 
						var actual []byte
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							clientStream, err := spdyClient.CreateStream(http.Header{})
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							_, err = io.Copy(clientStream, strings.NewReader(expected))
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							clientStream.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case serverStream := <-streamChan:
 | 
				
			||||||
 | 
							actual, err = io.ReadAll(serverStream)
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer serverStream.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						case <-time.After(wait.ForeverTestTimeout):
 | 
				
			||||||
 | 
							t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert.Equal(t, expected, string(actual), "error validating tunneled string")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
 | 
				
			||||||
 | 
						stopServerChan := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(stopServerChan)
 | 
				
			||||||
 | 
						tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							var upgrader = gwebsocket.Upgrader{
 | 
				
			||||||
 | 
								CheckOrigin:  func(r *http.Request) bool { return true },
 | 
				
			||||||
 | 
								Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							conn, err := upgrader.Upgrade(w, req, nil)
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer conn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
 | 
				
			||||||
 | 
							<-stopServerChan
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer tunnelingServer.Close()
 | 
				
			||||||
 | 
						// Create the client side tunneling connection.
 | 
				
			||||||
 | 
						url, err := url.Parse(tunnelingServer.URL)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						tConn, err := dialForTunnelingConnection(url)
 | 
				
			||||||
 | 
						require.NoError(t, err, "error creating client tunneling connection")
 | 
				
			||||||
 | 
						defer tConn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						// Validate "LocalAddr()" and "RemoteAddr()"
 | 
				
			||||||
 | 
						localAddr := tConn.LocalAddr()
 | 
				
			||||||
 | 
						remoteAddr := tConn.RemoteAddr()
 | 
				
			||||||
 | 
						assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
 | 
				
			||||||
 | 
						assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
 | 
				
			||||||
 | 
						_, err = net.ResolveTCPAddr("tcp", localAddr.String())
 | 
				
			||||||
 | 
						assert.NoError(t, err, "tunneling connection local addr should parse")
 | 
				
			||||||
 | 
						_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
 | 
				
			||||||
 | 
						assert.NoError(t, err, "tunneling connection remote addr should parse")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
 | 
				
			||||||
 | 
						stopServerChan := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(stopServerChan)
 | 
				
			||||||
 | 
						tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							var upgrader = gwebsocket.Upgrader{
 | 
				
			||||||
 | 
								CheckOrigin:  func(r *http.Request) bool { return true },
 | 
				
			||||||
 | 
								Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							conn, err := upgrader.Upgrade(w, req, nil)
 | 
				
			||||||
 | 
							require.NoError(t, err)
 | 
				
			||||||
 | 
							defer conn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
							require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
 | 
				
			||||||
 | 
							<-stopServerChan
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer tunnelingServer.Close()
 | 
				
			||||||
 | 
						// Create the client side tunneling connection.
 | 
				
			||||||
 | 
						url, err := url.Parse(tunnelingServer.URL)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						tConn, err := dialForTunnelingConnection(url)
 | 
				
			||||||
 | 
						require.NoError(t, err, "error creating client tunneling connection")
 | 
				
			||||||
 | 
						defer tConn.Close() //nolint:errcheck
 | 
				
			||||||
 | 
						// Validate the read and write deadlines.
 | 
				
			||||||
 | 
						err = tConn.SetReadDeadline(time.Time{})
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
 | 
				
			||||||
 | 
						err = tConn.SetWriteDeadline(time.Time{})
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
 | 
				
			||||||
 | 
						err = tConn.SetDeadline(time.Time{})
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
 | 
				
			||||||
 | 
						err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting deadline 10 year from now succeeds")
 | 
				
			||||||
 | 
						err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting deadline 10 year from now succeeds")
 | 
				
			||||||
 | 
						err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
 | 
				
			||||||
 | 
						assert.NoError(t, err, "setting deadline 10 year from now succeeds")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// dialForTunnelingConnection upgrades a request at the passed "url", creating
 | 
				
			||||||
 | 
					// a websocket connection. Returns the TunnelingConnection injected with the
 | 
				
			||||||
 | 
					// websocket connection or an error if one occurs.
 | 
				
			||||||
 | 
					func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest("GET", url.String(), nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
 | 
				
			||||||
 | 
						tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
 | 
				
			||||||
 | 
						transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return NewTunnelingConnection("client", conn), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
 | 
				
			||||||
 | 
						return func(stream httpstream.Stream, replySent <-chan struct{}) error {
 | 
				
			||||||
 | 
							streams <- stream
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,93 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2023 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream/spdy"
 | 
				
			||||||
 | 
						constants "k8s.io/apimachinery/pkg/util/portforward"
 | 
				
			||||||
 | 
						restclient "k8s.io/client-go/rest"
 | 
				
			||||||
 | 
						"k8s.io/client-go/transport/websocket"
 | 
				
			||||||
 | 
						"k8s.io/klog/v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const PingPeriod = 10 * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// tunnelingDialer implements "httpstream.Dial" interface
 | 
				
			||||||
 | 
					type tunnelingDialer struct {
 | 
				
			||||||
 | 
						url       *url.URL
 | 
				
			||||||
 | 
						transport http.RoundTripper
 | 
				
			||||||
 | 
						holder    websocket.ConnectionHolder
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
 | 
				
			||||||
 | 
					// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
 | 
				
			||||||
 | 
					// returns an error if one occurs.
 | 
				
			||||||
 | 
					func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
 | 
				
			||||||
 | 
						transport, holder, err := websocket.RoundTripperFor(config)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &tunnelingDialer{
 | 
				
			||||||
 | 
							url:       url,
 | 
				
			||||||
 | 
							transport: transport,
 | 
				
			||||||
 | 
							holder:    holder,
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
 | 
				
			||||||
 | 
					// containing a WebSockets connection (which implements "net.Conn"). Also
 | 
				
			||||||
 | 
					// returns the protocol negotiated, or an error.
 | 
				
			||||||
 | 
					func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
 | 
				
			||||||
 | 
						// There is no passed context, so skip the context when creating request for now.
 | 
				
			||||||
 | 
						// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
 | 
				
			||||||
 | 
						req, err := http.NewRequest("GET", d.url.String(), nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Add the spdy tunneling prefix to the requested protocols. The tunneling
 | 
				
			||||||
 | 
						// handler will know how to negotiate these protocols.
 | 
				
			||||||
 | 
						tunnelingProtocols := []string{}
 | 
				
			||||||
 | 
						for _, protocol := range protocols {
 | 
				
			||||||
 | 
							tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
 | 
				
			||||||
 | 
							tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
 | 
				
			||||||
 | 
						conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if conn == nil {
 | 
				
			||||||
 | 
							return nil, "", fmt.Errorf("negotiated websocket connection is nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						protocol := conn.Subprotocol()
 | 
				
			||||||
 | 
						protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
 | 
				
			||||||
 | 
						klog.V(4).Infof("negotiated protocol: %s", protocol)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Wrap the websocket connection which implements "net.Conn".
 | 
				
			||||||
 | 
						tConn := NewTunnelingConnection("client", conn)
 | 
				
			||||||
 | 
						// Create SPDY connection injecting the previously created tunneling connection.
 | 
				
			||||||
 | 
						spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return spdyConn, protocol, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -31,6 +31,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	corev1 "k8s.io/api/core/v1"
 | 
						corev1 "k8s.io/api/core/v1"
 | 
				
			||||||
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 | 
						metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/httpstream"
 | 
				
			||||||
	"k8s.io/apimachinery/pkg/util/sets"
 | 
						"k8s.io/apimachinery/pkg/util/sets"
 | 
				
			||||||
	"k8s.io/cli-runtime/pkg/genericiooptions"
 | 
						"k8s.io/cli-runtime/pkg/genericiooptions"
 | 
				
			||||||
	"k8s.io/client-go/kubernetes/scheme"
 | 
						"k8s.io/client-go/kubernetes/scheme"
 | 
				
			||||||
@@ -50,7 +51,7 @@ import (
 | 
				
			|||||||
type PortForwardOptions struct {
 | 
					type PortForwardOptions struct {
 | 
				
			||||||
	Namespace     string
 | 
						Namespace     string
 | 
				
			||||||
	PodName       string
 | 
						PodName       string
 | 
				
			||||||
	RESTClient    *restclient.RESTClient
 | 
						RESTClient    restclient.Interface
 | 
				
			||||||
	Config        *restclient.Config
 | 
						Config        *restclient.Config
 | 
				
			||||||
	PodClient     corev1client.PodsGetter
 | 
						PodClient     corev1client.PodsGetter
 | 
				
			||||||
	Address       []string
 | 
						Address       []string
 | 
				
			||||||
@@ -99,11 +100,7 @@ const (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command {
 | 
					func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command {
 | 
				
			||||||
	opts := &PortForwardOptions{
 | 
						opts := NewDefaultPortForwardOptions(streams)
 | 
				
			||||||
		PortForwarder: &defaultPortForwarder{
 | 
					 | 
				
			||||||
			IOStreams: streams,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	cmd := &cobra.Command{
 | 
						cmd := &cobra.Command{
 | 
				
			||||||
		Use:                   "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]",
 | 
							Use:                   "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]",
 | 
				
			||||||
		DisableFlagsInUseLine: true,
 | 
							DisableFlagsInUseLine: true,
 | 
				
			||||||
@@ -123,6 +120,14 @@ func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *c
 | 
				
			|||||||
	return cmd
 | 
						return cmd
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewDefaultPortForwardOptions(streams genericiooptions.IOStreams) *PortForwardOptions {
 | 
				
			||||||
 | 
						return &PortForwardOptions{
 | 
				
			||||||
 | 
							PortForwarder: &defaultPortForwarder{
 | 
				
			||||||
 | 
								IOStreams: streams,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type portForwarder interface {
 | 
					type portForwarder interface {
 | 
				
			||||||
	ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error
 | 
						ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -137,6 +142,14 @@ func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, opts Po
 | 
				
			|||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url)
 | 
						dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url)
 | 
				
			||||||
 | 
						if cmdutil.PortForwardWebsockets.IsEnabled() {
 | 
				
			||||||
 | 
							tunnelingDialer, err := portforward.NewSPDYOverWebsocketDialer(url, opts.Config)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// First attempt tunneling (websocket) dialer, then fallback to spdy dialer.
 | 
				
			||||||
 | 
							dialer = portforward.NewFallbackDialer(tunnelingDialer, dialer, httpstream.IsUpgradeFailure)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut)
 | 
						fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
@@ -385,9 +398,17 @@ func (o PortForwardOptions) Validate() error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Deprecated: Use RunPortForwardContext instead, which allows canceling.
 | 
				
			||||||
// RunPortForward implements all the necessary functionality for port-forward cmd.
 | 
					// RunPortForward implements all the necessary functionality for port-forward cmd.
 | 
				
			||||||
func (o PortForwardOptions) RunPortForward() error {
 | 
					func (o PortForwardOptions) RunPortForward() error {
 | 
				
			||||||
	pod, err := o.PodClient.Pods(o.Namespace).Get(context.TODO(), o.PodName, metav1.GetOptions{})
 | 
						return o.RunPortForwardContext(context.Background())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RunPortForwardContext implements all the necessary functionality for port-forward cmd.
 | 
				
			||||||
 | 
					// It ends portforwarding when an error is received from the backend, or an os.Interrupt
 | 
				
			||||||
 | 
					// signal is received, or the provided context is done.
 | 
				
			||||||
 | 
					func (o PortForwardOptions) RunPortForwardContext(ctx context.Context) error {
 | 
				
			||||||
 | 
						pod, err := o.PodClient.Pods(o.Namespace).Get(ctx, o.PodName, metav1.GetOptions{})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -400,8 +421,14 @@ func (o PortForwardOptions) RunPortForward() error {
 | 
				
			|||||||
	signal.Notify(signals, os.Interrupt)
 | 
						signal.Notify(signals, os.Interrupt)
 | 
				
			||||||
	defer signal.Stop(signals)
 | 
						defer signal.Stop(signals)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						returnCtx, returnCtxCancel := context.WithCancel(ctx)
 | 
				
			||||||
 | 
						defer returnCtxCancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		<-signals
 | 
							select {
 | 
				
			||||||
 | 
							case <-signals:
 | 
				
			||||||
 | 
							case <-returnCtx.Done():
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if o.StopChannel != nil {
 | 
							if o.StopChannel != nil {
 | 
				
			||||||
			close(o.StopChannel)
 | 
								close(o.StopChannel)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,6 +17,7 @@ limitations under the License.
 | 
				
			|||||||
package portforward
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
@@ -101,6 +102,8 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			opts := &PortForwardOptions{}
 | 
								opts := &PortForwardOptions{}
 | 
				
			||||||
 | 
								ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
								defer cancel()
 | 
				
			||||||
			cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard())
 | 
								cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard())
 | 
				
			||||||
			cmd.Run = func(cmd *cobra.Command, args []string) {
 | 
								cmd.Run = func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
				if err = opts.Complete(tf, cmd, args); err != nil {
 | 
									if err = opts.Complete(tf, cmd, args); err != nil {
 | 
				
			||||||
@@ -110,7 +113,7 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
 | 
				
			|||||||
				if err = opts.Validate(); err != nil {
 | 
									if err = opts.Validate(); err != nil {
 | 
				
			||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				err = opts.RunPortForward()
 | 
									err = opts.RunPortForwardContext(ctx)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for name, value := range flags {
 | 
								for name, value := range flags {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -430,6 +430,7 @@ const (
 | 
				
			|||||||
	InteractiveDelete       FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
 | 
						InteractiveDelete       FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
 | 
				
			||||||
	OpenAPIV3Patch          FeatureGate = "KUBECTL_OPENAPIV3_PATCH"
 | 
						OpenAPIV3Patch          FeatureGate = "KUBECTL_OPENAPIV3_PATCH"
 | 
				
			||||||
	RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
 | 
						RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
 | 
				
			||||||
 | 
						PortForwardWebsockets   FeatureGate = "KUBECTL_PORT_FORWARD_WEBSOCKETS"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IsEnabled returns true iff environment variable is set to true.
 | 
					// IsEnabled returns true iff environment variable is set to true.
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										27
									
								
								test/integration/apiserver/portforward/main_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								test/integration/apiserver/portforward/main_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/kubernetes/test/integration/framework"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMain(m *testing.M) {
 | 
				
			||||||
 | 
						framework.EtcdMain(m.Run)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										228
									
								
								test/integration/apiserver/portforward/portforward_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								test/integration/apiserver/portforward/portforward_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,228 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2024 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package portforward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						corev1 "k8s.io/api/core/v1"
 | 
				
			||||||
 | 
						metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/types"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/remotecommand"
 | 
				
			||||||
 | 
						"k8s.io/apimachinery/pkg/util/wait"
 | 
				
			||||||
 | 
						utilfeature "k8s.io/apiserver/pkg/util/feature"
 | 
				
			||||||
 | 
						"k8s.io/cli-runtime/pkg/genericiooptions"
 | 
				
			||||||
 | 
						"k8s.io/client-go/kubernetes"
 | 
				
			||||||
 | 
						featuregatetesting "k8s.io/component-base/featuregate/testing"
 | 
				
			||||||
 | 
						"k8s.io/kubectl/pkg/cmd/portforward"
 | 
				
			||||||
 | 
						kubeletportforward "k8s.io/kubelet/pkg/cri/streaming/portforward"
 | 
				
			||||||
 | 
						kastesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing"
 | 
				
			||||||
 | 
						kubefeatures "k8s.io/kubernetes/pkg/features"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"k8s.io/kubernetes/test/integration/framework"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const remotePort = "8765"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPortforward(t *testing.T) {
 | 
				
			||||||
 | 
						defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, kubefeatures.PortForwardWebsockets, true)()
 | 
				
			||||||
 | 
						t.Setenv("KUBECTL_PORT_FORWARD_WEBSOCKETS", "true")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var podName string
 | 
				
			||||||
 | 
						var podUID types.UID
 | 
				
			||||||
 | 
						backendServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
							t.Logf("backend saw request: %v", req.URL.String())
 | 
				
			||||||
 | 
							kubeletportforward.ServePortForward(
 | 
				
			||||||
 | 
								w,
 | 
				
			||||||
 | 
								req,
 | 
				
			||||||
 | 
								&dummyPortForwarder{t: t},
 | 
				
			||||||
 | 
								podName,
 | 
				
			||||||
 | 
								podUID,
 | 
				
			||||||
 | 
								&kubeletportforward.V4Options{},
 | 
				
			||||||
 | 
								wait.ForeverTestTimeout, // idle timeout
 | 
				
			||||||
 | 
								remotecommand.DefaultStreamCreationTimeout, // stream creation timeout
 | 
				
			||||||
 | 
								[]string{kubeletportforward.ProtocolV1Name},
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer backendServer.Close()
 | 
				
			||||||
 | 
						backendURL, _ := url.Parse(backendServer.URL)
 | 
				
			||||||
 | 
						backendHost := backendURL.Hostname()
 | 
				
			||||||
 | 
						backendPort, _ := strconv.Atoi(backendURL.Port())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						etcd := framework.SharedEtcd()
 | 
				
			||||||
 | 
						server := kastesting.StartTestServerOrDie(t, nil, []string{"--disable-admission-plugins=ServiceAccount"}, etcd)
 | 
				
			||||||
 | 
						defer server.TearDownFn()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						adminClient, err := kubernetes.NewForConfig(server.ClientConfig)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						node := &corev1.Node{
 | 
				
			||||||
 | 
							ObjectMeta: metav1.ObjectMeta{Name: "mynode"},
 | 
				
			||||||
 | 
							Status: corev1.NodeStatus{
 | 
				
			||||||
 | 
								DaemonEndpoints: corev1.NodeDaemonEndpoints{KubeletEndpoint: corev1.DaemonEndpoint{Port: int32(backendPort)}},
 | 
				
			||||||
 | 
								Addresses:       []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: backendHost}},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, err := adminClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pod := &corev1.Pod{
 | 
				
			||||||
 | 
							ObjectMeta: metav1.ObjectMeta{Namespace: "default", Name: "mypod"},
 | 
				
			||||||
 | 
							Spec: corev1.PodSpec{
 | 
				
			||||||
 | 
								NodeName:   "mynode",
 | 
				
			||||||
 | 
								Containers: []corev1.Container{{Name: "test", Image: "test"}},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, err := adminClient.CoreV1().Pods("default").Create(context.Background(), pod, metav1.CreateOptions{}); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, err := adminClient.CoreV1().Pods("default").Patch(context.Background(), "mypod", types.MergePatchType, []byte(`{"status":{"phase":"Running"}}`), metav1.PatchOptions{}, "status"); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// local port missing asks os to find random open port.
 | 
				
			||||||
 | 
						// Example: ":8000" (local = random, remote = 8000)
 | 
				
			||||||
 | 
						localRemotePort := fmt.Sprintf(":%s", remotePort)
 | 
				
			||||||
 | 
						streams, _, out, errOut := genericiooptions.NewTestIOStreams()
 | 
				
			||||||
 | 
						portForwardOptions := portforward.NewDefaultPortForwardOptions(streams)
 | 
				
			||||||
 | 
						portForwardOptions.Namespace = "default"
 | 
				
			||||||
 | 
						portForwardOptions.PodName = "mypod"
 | 
				
			||||||
 | 
						portForwardOptions.RESTClient = adminClient.CoreV1().RESTClient()
 | 
				
			||||||
 | 
						portForwardOptions.Config = server.ClientConfig
 | 
				
			||||||
 | 
						portForwardOptions.PodClient = adminClient.CoreV1()
 | 
				
			||||||
 | 
						portForwardOptions.Address = []string{"127.0.0.1"}
 | 
				
			||||||
 | 
						portForwardOptions.Ports = []string{localRemotePort}
 | 
				
			||||||
 | 
						portForwardOptions.StopChannel = make(chan struct{}, 1)
 | 
				
			||||||
 | 
						portForwardOptions.ReadyChannel = make(chan struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := portForwardOptions.Validate(); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						wg.Add(1)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							defer wg.Done()
 | 
				
			||||||
 | 
							if err := portForwardOptions.RunPortForwardContext(ctx); err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("waiting for port forward to be ready")
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-portForwardOptions.ReadyChannel:
 | 
				
			||||||
 | 
							t.Log("port forward was ready")
 | 
				
			||||||
 | 
						case <-time.After(wait.ForeverTestTimeout):
 | 
				
			||||||
 | 
							t.Error("port forward was never ready")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse out the randomly selected local port from "out" stream.
 | 
				
			||||||
 | 
						localPort, err := parsePort(out.String())
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						t.Logf("Local Port: %s", localPort)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						timeoutContext, cleanupTimeoutContext := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
 | 
				
			||||||
 | 
						defer cleanupTimeoutContext()
 | 
				
			||||||
 | 
						testReq, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%s/test", localPort), nil)
 | 
				
			||||||
 | 
						testReq = testReq.WithContext(timeoutContext)
 | 
				
			||||||
 | 
						testResp, err := http.DefaultClient.Do(testReq)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Error(err)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							t.Log(testResp.StatusCode)
 | 
				
			||||||
 | 
							data, err := io.ReadAll(testResp.Body)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								t.Log("client saw response:", string(data))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if string(data) != fmt.Sprintf("request to %s was ok", remotePort) {
 | 
				
			||||||
 | 
								t.Errorf("unexpected data")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if testResp.StatusCode != 200 {
 | 
				
			||||||
 | 
								t.Error("expected success")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
						t.Logf("stdout: %s", out.String())
 | 
				
			||||||
 | 
						t.Logf("stderr: %s", errOut.String())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// parsePort parses out the local port from the port-forward output string.
 | 
				
			||||||
 | 
					// This should work for both IP4 and IP6 addresses.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//	Example: "Forwarding from 127.0.0.1:8000 -> 4000", returns "8000".
 | 
				
			||||||
 | 
					func parsePort(forwardAddr string) (string, error) {
 | 
				
			||||||
 | 
						parts := strings.Split(forwardAddr, " ")
 | 
				
			||||||
 | 
						if len(parts) != 5 {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("unable to parse local port from stdout: %s", forwardAddr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// parts[2] = "127.0.0.1:<LOCAL_PORT>"
 | 
				
			||||||
 | 
						_, localPort, err := net.SplitHostPort(parts[2])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("unable to parse local port: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return localPort, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type dummyPortForwarder struct {
 | 
				
			||||||
 | 
						t *testing.T
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (d *dummyPortForwarder) PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
 | 
				
			||||||
 | 
						d.t.Logf("handling port forward request for %d", port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req, err := http.ReadRequest(bufio.NewReader(stream))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							d.t.Logf("error reading request: %v", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						d.t.Log(req.URL.String())
 | 
				
			||||||
 | 
						defer req.Body.Close() //nolint:errcheck
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp := &http.Response{
 | 
				
			||||||
 | 
							StatusCode: 200,
 | 
				
			||||||
 | 
							Proto:      "HTTP/1.1",
 | 
				
			||||||
 | 
							ProtoMajor: 1,
 | 
				
			||||||
 | 
							ProtoMinor: 1,
 | 
				
			||||||
 | 
							Body:       io.NopCloser(bytes.NewBufferString(fmt.Sprintf("request to %d was ok", port))),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Write(stream) //nolint:errcheck
 | 
				
			||||||
 | 
						return stream.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user