Merge pull request #119186 from seans3/stream-translator-proxy

Stream Translator Proxy and FallbackExecutor for WebSockets
This commit is contained in:
Kubernetes Prow Robot
2023-10-24 17:10:34 +02:00
committed by GitHub
42 changed files with 2379 additions and 118 deletions

View File

@@ -179,6 +179,14 @@ const (
// Enables kubelet to detect CSI volume condition and send the event of the abnormal volume to the corresponding pod that is using it.
CSIVolumeHealth featuregate.Feature = "CSIVolumeHealth"
// owner: @seans3
// kep: http://kep.k8s.io/4006
// alpha: v1.29
//
// Enables StreamTranslator proxy to handle WebSockets upgrade requests for the
// version of the RemoteCommand subprotocol that supports the "close" signal.
TranslateStreamCloseWebsocketRequests featuregate.Feature = "TranslateStreamCloseWebsocketRequests"
// owner: @nckturner
// kep: http://kep.k8s.io/2699
// alpha: v1.27
@@ -925,6 +933,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS
SkipReadOnlyValidationGCE: {Default: true, PreRelease: featuregate.Deprecated}, // remove in 1.31
TranslateStreamCloseWebsocketRequests: {Default: false, PreRelease: featuregate.Alpha},
CloudControllerManagerWebhook: {Default: false, PreRelease: featuregate.Alpha},
ContainerCheckpoint: {Default: false, PreRelease: featuregate.Alpha},

View File

@@ -18,6 +18,7 @@ package server
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@@ -959,7 +960,10 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1"
upgradeRoundTripper := spdy.NewRoundTripper(nil)
upgradeRoundTripper, err := spdy.NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("Error creating SpdyRoundTripper: %v", err)
}
c := &http.Client{Transport: upgradeRoundTripper}
resp, err := c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io"))
@@ -1115,7 +1119,10 @@ func testExecAttach(t *testing.T, verb string) {
upgradeRoundTripper httpstream.UpgradeRoundTripper
c *http.Client
)
upgradeRoundTripper = spdy.NewRoundTripper(nil)
upgradeRoundTripper, err = spdy.NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("Error creating SpdyRoundTripper: %v", err)
}
c = &http.Client{Transport: upgradeRoundTripper}
resp, err = c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io"))
@@ -1211,7 +1218,10 @@ func TestServePortForwardIdleTimeout(t *testing.T) {
url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName
upgradeRoundTripper := spdy.NewRoundTripper(nil)
upgradeRoundTripper, err := spdy.NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("Error creating SpdyRoundTripper: %v", err)
}
c := &http.Client{Transport: upgradeRoundTripper}
req := makeReq(t, "POST", url, "portforward.k8s.io")
@@ -1310,7 +1320,10 @@ func TestServePortForward(t *testing.T) {
c *http.Client
)
upgradeRoundTripper = spdy.NewRoundTripper(nil)
upgradeRoundTripper, err = spdy.NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("Error creating SpdyRoundTripper: %v", err)
}
c = &http.Client{Transport: upgradeRoundTripper}
req := makeReq(t, "POST", url, "portforward.k8s.io")

View File

@@ -23,12 +23,16 @@ import (
"net/url"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
"k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/proxy"
genericregistry "k8s.io/apiserver/pkg/registry/generic/registry"
"k8s.io/apiserver/pkg/registry/rest"
utilfeature "k8s.io/apiserver/pkg/util/feature"
translator "k8s.io/apiserver/pkg/util/proxy"
api "k8s.io/kubernetes/pkg/apis/core"
"k8s.io/kubernetes/pkg/capabilities"
"k8s.io/kubernetes/pkg/features"
"k8s.io/kubernetes/pkg/kubelet/client"
"k8s.io/kubernetes/pkg/registry/core/pod"
)
@@ -113,7 +117,21 @@ func (r *AttachREST) Connect(ctx context.Context, name string, opts runtime.Obje
if err != nil {
return nil, err
}
return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder)
if utilfeature.DefaultFeatureGate.Enabled(features.TranslateStreamCloseWebsocketRequests) {
// Wrap the upgrade aware handler to implement stream translation
// for WebSocket/V5 upgrade requests.
streamOptions := translator.Options{
Stdin: attachOpts.Stdin,
Stdout: attachOpts.Stdout,
Stderr: attachOpts.Stderr,
Tty: attachOpts.TTY,
}
maxBytesPerSec := capabilities.Get().PerConnectionBandwidthLimitBytesPerSec
streamtranslator := translator.NewStreamTranslatorHandler(location, transport, maxBytesPerSec, streamOptions)
handler = translator.NewTranslatingHandler(handler, streamtranslator, wsstream.IsWebSocketRequestWithStreamCloseProtocol)
}
return handler, nil
}
// NewConnectOptions returns the versioned object that represents exec parameters
@@ -156,7 +174,21 @@ func (r *ExecREST) Connect(ctx context.Context, name string, opts runtime.Object
if err != nil {
return nil, err
}
return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder)
if utilfeature.DefaultFeatureGate.Enabled(features.TranslateStreamCloseWebsocketRequests) {
// Wrap the upgrade aware handler to implement stream translation
// for WebSocket/V5 upgrade requests.
streamOptions := translator.Options{
Stdin: execOpts.Stdin,
Stdout: execOpts.Stdout,
Stderr: execOpts.Stderr,
Tty: execOpts.TTY,
}
maxBytesPerSec := capabilities.Get().PerConnectionBandwidthLimitBytesPerSec
streamtranslator := translator.NewStreamTranslatorHandler(location, transport, maxBytesPerSec, streamOptions)
handler = translator.NewTranslatingHandler(handler, streamtranslator, wsstream.IsWebSocketRequestWithStreamCloseProtocol)
}
return handler, nil
}
// NewConnectOptions returns the versioned object that represents exec parameters
@@ -213,7 +245,7 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime
return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
}
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) *proxy.UpgradeAwareHandler {
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {
handler := proxy.NewUpgradeAwareHandler(location, transport, wrapTransport, upgradeRequired, proxy.NewErrorResponder(responder))
handler.MaxBytesPerSec = capabilities.Get().PerConnectionBandwidthLimitBytesPerSec
return handler

View File

@@ -73,9 +73,11 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/moby/spdystream v0.2.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.16.0 // indirect

View File

@@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
@@ -378,6 +379,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8=
github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c=
github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -388,6 +390,7 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=

View File

@@ -17,6 +17,7 @@ limitations under the License.
package httpstream
import (
"errors"
"fmt"
"io"
"net/http"
@@ -95,6 +96,26 @@ type Stream interface {
Identifier() uint32
}
// UpgradeFailureError encapsulates the cause for why the streaming
// upgrade request failed. Implements error interface.
type UpgradeFailureError struct {
Cause error
}
func (u *UpgradeFailureError) Error() string {
return fmt.Sprintf("unable to upgrade streaming request: %s", u.Cause)
}
// IsUpgradeFailure returns true if the passed error is (or wrapped error contains)
// the UpgradeFailureError.
func IsUpgradeFailure(err error) bool {
if err == nil {
return false
}
var upgradeErr *UpgradeFailureError
return errors.As(err, &upgradeErr)
}
// IsUpgradeRequest returns true if the given request is a connection upgrade request
func IsUpgradeRequest(req *http.Request) bool {
for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] {

View File

@@ -17,6 +17,8 @@ limitations under the License.
package httpstream
import (
"errors"
"fmt"
"net/http"
"reflect"
"testing"
@@ -129,3 +131,40 @@ func TestHandshake(t *testing.T) {
}
}
}
func TestIsUpgradeFailureError(t *testing.T) {
testCases := map[string]struct {
err error
expected bool
}{
"nil error should return false": {
err: nil,
expected: false,
},
"Non-upgrade error should return false": {
err: fmt.Errorf("this is not an upgrade error"),
expected: false,
},
"UpgradeFailure error should return true": {
err: &UpgradeFailureError{},
expected: true,
},
"Wrapped Non-UpgradeFailure error should return false": {
err: fmt.Errorf("%s: %w", "first error", errors.New("Non-upgrade error")),
expected: false,
},
"Wrapped UpgradeFailure error should return true": {
err: fmt.Errorf("%s: %w", "first error", &UpgradeFailureError{}),
expected: true,
},
}
for name, test := range testCases {
t.Run(name, func(t *testing.T) {
actual := IsUpgradeFailure(test.err)
if test.expected != actual {
t.Errorf("expected upgrade failure %t, got %t", test.expected, actual)
}
})
}
}

View File

@@ -38,6 +38,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
apiproxy "k8s.io/apimachinery/pkg/util/proxy"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
@@ -68,6 +69,10 @@ type SpdyRoundTripper struct {
// pingPeriod is a period for sending Ping frames over established
// connections.
pingPeriod time.Duration
// upgradeTransport is an optional substitute for dialing if present. This field is
// mutually exclusive with the "tlsConfig", "Dialer", and "proxier".
upgradeTransport http.RoundTripper
}
var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{}
@@ -76,43 +81,61 @@ var _ utilnet.Dialer = &SpdyRoundTripper{}
// NewRoundTripper creates a new SpdyRoundTripper that will use the specified
// tlsConfig.
func NewRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper {
func NewRoundTripper(tlsConfig *tls.Config) (*SpdyRoundTripper, error) {
return NewRoundTripperWithConfig(RoundTripperConfig{
TLS: tlsConfig,
TLS: tlsConfig,
UpgradeTransport: nil,
})
}
// NewRoundTripperWithProxy creates a new SpdyRoundTripper that will use the
// specified tlsConfig and proxy func.
func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) *SpdyRoundTripper {
func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) (*SpdyRoundTripper, error) {
return NewRoundTripperWithConfig(RoundTripperConfig{
TLS: tlsConfig,
Proxier: proxier,
TLS: tlsConfig,
Proxier: proxier,
UpgradeTransport: nil,
})
}
// NewRoundTripperWithConfig creates a new SpdyRoundTripper with the specified
// configuration.
func NewRoundTripperWithConfig(cfg RoundTripperConfig) *SpdyRoundTripper {
// configuration. Returns an error if the SpdyRoundTripper is misconfigured.
func NewRoundTripperWithConfig(cfg RoundTripperConfig) (*SpdyRoundTripper, error) {
// Process UpgradeTransport, which is mutually exclusive to TLSConfig and Proxier.
if cfg.UpgradeTransport != nil {
if cfg.TLS != nil || cfg.Proxier != nil {
return nil, fmt.Errorf("SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier")
}
tlsConfig, err := utilnet.TLSClientConfig(cfg.UpgradeTransport)
if err != nil {
return nil, fmt.Errorf("SpdyRoundTripper: Unable to retrieve TLSConfig from UpgradeTransport: %v", err)
}
cfg.TLS = tlsConfig
}
if cfg.Proxier == nil {
cfg.Proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
return &SpdyRoundTripper{
tlsConfig: cfg.TLS,
proxier: cfg.Proxier,
pingPeriod: cfg.PingPeriod,
}
tlsConfig: cfg.TLS,
proxier: cfg.Proxier,
pingPeriod: cfg.PingPeriod,
upgradeTransport: cfg.UpgradeTransport,
}, nil
}
// RoundTripperConfig is a set of options for an SpdyRoundTripper.
type RoundTripperConfig struct {
// TLS configuration used by the round tripper.
// TLS configuration used by the round tripper if UpgradeTransport not present.
TLS *tls.Config
// Proxier is a proxy function invoked on each request. Optional.
Proxier func(*http.Request) (*url.URL, error)
// PingPeriod is a period for sending SPDY Pings on the connection.
// Optional.
PingPeriod time.Duration
// UpgradeTransport is a subtitute transport used for dialing. If set,
// this field will be used instead of "TLS" and "Proxier" for connection creation.
// Optional.
UpgradeTransport http.RoundTripper
}
// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during
@@ -123,7 +146,13 @@ func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config {
// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer.
func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
conn, err := s.dial(req)
var conn net.Conn
var err error
if s.upgradeTransport != nil {
conn, err = apiproxy.DialURL(req.Context(), req.URL, s.upgradeTransport)
} else {
conn, err = s.dial(req)
}
if err != nil {
return nil, err
}

View File

@@ -25,7 +25,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"github.com/armon/go-socks5"
@@ -324,7 +326,10 @@ func TestRoundTripAndNewConnection(t *testing.T) {
t.Fatalf("error creating request: %s", err)
}
spdyTransport := NewRoundTripper(testCase.clientTLS)
spdyTransport, err := NewRoundTripper(testCase.clientTLS)
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}
var proxierCalled bool
var proxyCalledWithHost string
@@ -428,6 +433,74 @@ func TestRoundTripAndNewConnection(t *testing.T) {
}
}
// Tests SpdyRoundTripper constructors
func TestRoundTripConstuctor(t *testing.T) {
testCases := map[string]struct {
tlsConfig *tls.Config
proxier func(req *http.Request) (*url.URL, error)
upgradeTransport http.RoundTripper
expectedTLSConfig *tls.Config
errMsg string
}{
"Basic TLSConfig; no error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: nil,
},
"Basic TLSConfig and Proxier: no error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
proxier: func(req *http.Request) (*url.URL, error) { return nil, nil },
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: nil,
},
"TLSConfig with UpgradeTransport: error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier",
},
"Proxier with UpgradeTransport: error": {
proxier: func(req *http.Request) (*url.URL, error) { return nil, nil },
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier",
},
"Only UpgradeTransport: no error": {
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
spdyRoundTripper, err := NewRoundTripperWithConfig(
RoundTripperConfig{
TLS: testCase.tlsConfig,
Proxier: testCase.proxier,
UpgradeTransport: testCase.upgradeTransport,
},
)
if testCase.errMsg != "" {
if err == nil {
t.Fatalf("expected error but received none")
}
if !strings.Contains(err.Error(), testCase.errMsg) {
t.Fatalf("expected error message (%s), got (%s)", err.Error(), testCase.errMsg)
}
}
if testCase.errMsg == "" {
if err != nil {
t.Fatalf("unexpected error received: %v", err)
}
actualTLSConfig := spdyRoundTripper.TLSClientConfig()
if !reflect.DeepEqual(testCase.expectedTLSConfig, actualTLSConfig) {
t.Errorf("expected TLSConfig (%v), got (%v)",
testCase.expectedTLSConfig, actualTLSConfig)
}
}
})
}
}
type Interceptor struct {
Authorization socks5.AuthContext
proxyCalledWithHost *string
@@ -544,7 +617,10 @@ func TestRoundTripSocks5AndNewConnection(t *testing.T) {
t.Fatalf("error creating request: %s", err)
}
spdyTransport := NewRoundTripper(testCase.clientTLS)
spdyTransport, err := NewRoundTripper(testCase.clientTLS)
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}
var proxierCalled bool
var proxyCalledWithHost string
@@ -704,7 +780,10 @@ func TestRoundTripPassesContextToDialer(t *testing.T) {
cancel()
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
require.NoError(t, err)
spdyTransport := NewRoundTripper(&tls.Config{})
spdyTransport, err := NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}
_, err = spdyTransport.Dial(req)
assert.EqualError(t, err, "dial tcp 127.0.0.1:1233: operation was canceled")
})

View File

@@ -32,6 +32,8 @@ import (
"k8s.io/klog/v2"
)
const WebSocketProtocolHeader = "Sec-Websocket-Protocol"
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
// the channel number (zero indexed) the message was sent on. Messages in both directions should
// prefix their messages with this channel byte. When used for remote execution, the channel numbers
@@ -87,6 +89,23 @@ func IsWebSocketRequest(req *http.Request) bool {
return httpstream.IsUpgradeRequest(req)
}
// IsWebSocketRequestWithStreamCloseProtocol returns true if the request contains headers
// identifying that it is requesting a websocket upgrade with a remotecommand protocol
// version that supports the "CLOSE" signal; false otherwise.
func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
if !IsWebSocketRequest(req) {
return false
}
requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) {
return true
}
}
return false
}
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
// read and write deadlines are pushed every time a new message is received.
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
@@ -168,15 +187,46 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
conn.timeout = duration
}
// SetWriteDeadline sets a timeout on writing to the websocket connection. The
// passed "duration" identifies how far into the future the write must complete
// by before the timeout fires.
func (conn *Conn) SetWriteDeadline(duration time.Duration) {
conn.ws.SetWriteDeadline(time.Now().Add(duration)) //nolint:errcheck
}
// Open the connection and create channels for reading and writing. It returns
// the selected subprotocol, a slice of channels and an error.
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
// serveHTTPComplete is channel that is closed/selected when "websocket#ServeHTTP" finishes.
serveHTTPComplete := make(chan struct{})
// Ensure panic in spawned goroutine is propagated into the parent goroutine.
panicChan := make(chan any, 1)
go func() {
defer runtime.HandleCrash()
defer conn.Close()
// If websocket server returns, propagate panic if necessary. Otherwise,
// signal HTTPServe finished by closing "serveHTTPComplete".
defer func() {
if p := recover(); p != nil {
panicChan <- p
} else {
close(serveHTTPComplete)
}
}()
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
}()
<-conn.ready
// In normal circumstances, "websocket.Server#ServeHTTP" calls "initialize" which closes
// "conn.ready" and then blocks until serving is complete.
select {
case <-conn.ready:
klog.V(8).Infof("websocket server initialized--serving")
case <-serveHTTPComplete:
// websocket server returned before completing initialization; cleanup and return error.
conn.closeNonThreadSafe() //nolint:errcheck
return "", nil, fmt.Errorf("websocket server finished before becoming ready")
case p := <-panicChan:
panic(p)
}
rwc := make([]io.ReadWriteCloser, len(conn.channels))
for i := range conn.channels {
rwc[i] = conn.channels[i]
@@ -225,14 +275,23 @@ func (conn *Conn) resetTimeout() {
}
}
// Close is only valid after Open has been called
func (conn *Conn) Close() error {
<-conn.ready
// closeNonThreadSafe cleans up by closing streams and the websocket
// connection *without* waiting for the "ready" channel.
func (conn *Conn) closeNonThreadSafe() error {
for _, s := range conn.channels {
s.Close()
}
conn.ws.Close()
return nil
var err error
if conn.ws != nil {
err = conn.ws.Close()
}
return err
}
// Close is only valid after Open has been called
func (conn *Conn) Close() error {
<-conn.ready
return conn.closeNonThreadSafe()
}
// protocolSupportsStreamClose returns true if the passed protocol
@@ -244,8 +303,8 @@ func protocolSupportsStreamClose(protocol string) bool {
// handle implements a websocket handler.
func (conn *Conn) handle(ws *websocket.Conn) {
defer conn.Close()
conn.initialize(ws)
defer conn.Close()
supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol)
for {

View File

@@ -25,6 +25,8 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
@@ -271,3 +273,146 @@ func TestVersionedConn(t *testing.T) {
}()
}
}
func TestIsWebSocketRequestWithStreamCloseProtocol(t *testing.T) {
tests := map[string]struct {
headers map[string]string
expected bool
}{
"No headers returns false": {
headers: map[string]string{},
expected: false,
},
"Only connection upgrade header is false": {
headers: map[string]string{
"Connection": "upgrade",
},
expected: false,
},
"Only websocket upgrade header is false": {
headers: map[string]string{
"Upgrade": "websocket",
},
expected: false,
},
"Only websocket and connection upgrade headers is false": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
},
expected: false,
},
"Missing connection/upgrade header is false": {
headers: map[string]string{
"Upgrade": "websocket",
WebSocketProtocolHeader: "v5.channel.k8s.io",
},
expected: false,
},
"Websocket connection upgrade headers with v5 protocol is true": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v5.channel.k8s.io",
},
expected: true,
},
"Websocket connection upgrade headers with wrong case v5 protocol is false": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v5.CHANNEL.k8s.io", // header value is case-sensitive
},
expected: false,
},
"Websocket connection upgrade headers with v4 protocol is false": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v4.channel.k8s.io",
},
expected: false,
},
"Websocket connection upgrade headers with multiple protocols but missing v5 is false": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v4.channel.k8s.io,v3.channel.k8s.io,v2.channel.k8s.io",
},
expected: false,
},
"Websocket connection upgrade headers with multiple protocols including v5 and spaces is true": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v5.channel.k8s.io, v4.channel.k8s.io",
},
expected: true,
},
"Websocket connection upgrade headers with multiple protocols out of order including v5 and spaces is true": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
WebSocketProtocolHeader: "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
},
expected: true,
},
"Websocket connection upgrade headers key is case-insensitive": {
headers: map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
"sec-websocket-protocol": "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
},
expected: true,
},
}
for name, test := range tests {
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
require.NoError(t, err)
for key, value := range test.headers {
req.Header.Add(key, value)
}
actual := IsWebSocketRequestWithStreamCloseProtocol(req)
assert.Equal(t, test.expected, actual, "%s: expected (%t), got (%t)", name, test.expected, actual)
}
}
func TestProtocolSupportsStreamClose(t *testing.T) {
tests := map[string]struct {
protocol string
expected bool
}{
"empty protocol returns false": {
protocol: "",
expected: false,
},
"not binary protocol returns false": {
protocol: "base64.channel.k8s.io",
expected: false,
},
"V1 protocol returns false": {
protocol: "channel.k8s.io",
expected: false,
},
"V4 protocol returns false": {
protocol: "v4.channel.k8s.io",
expected: false,
},
"V5 protocol returns true": {
protocol: "v5.channel.k8s.io",
expected: true,
},
"V5 protocol wrong case returns false": {
protocol: "V5.channel.K8S.io",
expected: false,
},
}
for name, test := range tests {
actual := protocolSupportsStreamClose(test.protocol)
assert.Equal(t, test.expected, actual,
"%s: expected (%t), got (%t)", name, test.expected, actual)
}
}

View File

@@ -29,12 +29,12 @@ import (
"k8s.io/klog/v2"
)
// dialURL will dial the specified URL using the underlying dialer held by the passed
// DialURL will dial the specified URL using the underlying dialer held by the passed
// RoundTripper. The primary use of this method is to support proxying upgradable connections.
// For this reason this method will prefer to negotiate http/1.1 if the URL scheme is https.
// If you wish to ensure ALPN negotiates http2 then set NextProto=[]string{"http2"} in the
// TLSConfig of the http.Transport
func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, err := utilnet.DialerFor(transport)

View File

@@ -143,7 +143,7 @@ func TestDialURL(t *testing.T) {
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := dialURL(context.Background(), u, transport)
conn, err := DialURL(context.Background(), u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {

View File

@@ -492,7 +492,7 @@ func getResponse(r io.Reader) (*http.Response, []byte, error) {
// dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := dialURL(req.Context(), req.URL, transport)
conn, err := DialURL(req.Context(), req.URL, transport)
if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err)
}

View File

@@ -18,6 +18,7 @@ require (
github.com/google/uuid v1.3.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.3
go.etcd.io/etcd/api/v3 v3.5.9
@@ -87,9 +88,9 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/moby/spdystream v0.2.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pquerna/cachecontrol v0.1.0 // indirect

View File

@@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
@@ -376,6 +377,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8=
github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c=
github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View File

@@ -0,0 +1,167 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"fmt"
"net/http"
"net/url"
"github.com/mxk/go-flowrate/flowrate"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/client-go/tools/remotecommand"
"k8s.io/client-go/util/exec"
)
// StreamTranslatorHandler is a handler which translates WebSocket stream data
// to SPDY to proxy to kubelet (and ContainerRuntime).
type StreamTranslatorHandler struct {
// Location is the location of the upstream proxy. It is used as the location to Dial on the upstream server
// for upgrade requests.
Location *url.URL
// Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used
Transport http.RoundTripper
// MaxBytesPerSec throttles stream Reader/Writer if necessary
MaxBytesPerSec int64
// Options define the requested streams (e.g. stdin, stdout).
Options Options
}
// NewStreamTranslatorHandler creates a new proxy handler. Responder is required for returning
// errors to the caller.
func NewStreamTranslatorHandler(location *url.URL, transport http.RoundTripper, maxBytesPerSec int64, opts Options) *StreamTranslatorHandler {
return &StreamTranslatorHandler{
Location: location,
Transport: transport,
MaxBytesPerSec: maxBytesPerSec,
Options: opts,
}
}
func (h *StreamTranslatorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Create WebSocket server, including particular streams requested. If this websocket
// endpoint is not able to be upgraded, the websocket library will return errors
// to the client.
websocketStreams, err := webSocketServerStreams(req, w, h.Options)
if err != nil {
return
}
defer websocketStreams.conn.Close()
// Creating SPDY executor, ensuring redirects are not followed.
spdyRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{UpgradeTransport: h.Transport})
if err != nil {
websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck
return
}
spdyExecutor, err := remotecommand.NewSPDYExecutorRejectRedirects(spdyRoundTripper, spdyRoundTripper, "POST", h.Location)
if err != nil {
websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck
return
}
// Wire the WebSocket server streams output to the SPDY client input. The stdin/stdout/stderr streams
// can be throttled if the transfer rate exceeds the "MaxBytesPerSec" (zero means unset). Throttling
// the streams instead of the underlying connection *may* not perform the same if two streams
// traveling the same direction (e.g. stdout, stderr) are being maxed out.
opts := remotecommand.StreamOptions{}
if h.Options.Stdin {
stdin := websocketStreams.stdinStream
if h.MaxBytesPerSec > 0 {
stdin = flowrate.NewReader(stdin, h.MaxBytesPerSec)
}
opts.Stdin = stdin
}
if h.Options.Stdout {
stdout := websocketStreams.stdoutStream
if h.MaxBytesPerSec > 0 {
stdout = flowrate.NewWriter(stdout, h.MaxBytesPerSec)
}
opts.Stdout = stdout
}
if h.Options.Stderr {
stderr := websocketStreams.stderrStream
if h.MaxBytesPerSec > 0 {
stderr = flowrate.NewWriter(stderr, h.MaxBytesPerSec)
}
opts.Stderr = stderr
}
if h.Options.Tty {
opts.Tty = true
opts.TerminalSizeQueue = &translatorSizeQueue{resizeChan: websocketStreams.resizeChan}
}
// Start the SPDY client with connected streams. Output from the WebSocket server
// streams will be forwarded into the SPDY client. Report SPDY execution errors
// through the websocket error stream.
err = spdyExecutor.StreamWithContext(req.Context(), opts)
if err != nil {
//nolint:errcheck // Ignore writeStatus returned error
if statusErr, ok := err.(*apierrors.StatusError); ok {
websocketStreams.writeStatus(statusErr)
} else if exitErr, ok := err.(exec.CodeExitError); ok && exitErr.Exited() {
websocketStreams.writeStatus(codeExitToStatusError(exitErr))
} else {
websocketStreams.writeStatus(apierrors.NewInternalError(err))
}
return
}
// Write the success status back to the WebSocket client.
//nolint:errcheck
websocketStreams.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
Status: metav1.StatusSuccess,
}})
}
// translatorSizeQueue feeds the size events from the WebSocket
// resizeChan into the SPDY client input. Implements TerminalSizeQueue
// interface.
type translatorSizeQueue struct {
resizeChan chan remotecommand.TerminalSize
}
func (t *translatorSizeQueue) Next() *remotecommand.TerminalSize {
size, ok := <-t.resizeChan
if !ok {
return nil
}
return &size
}
// codeExitToStatusError converts a passed CodeExitError to the type necessary
// to send through an error stream using "writeStatus".
func codeExitToStatusError(exitErr exec.CodeExitError) *apierrors.StatusError {
rc := exitErr.ExitStatus()
return &apierrors.StatusError{
ErrStatus: metav1.Status{
Status: metav1.StatusFailure,
Reason: constants.NonZeroExitCodeReason,
Details: &metav1.StatusDetails{
Causes: []metav1.StatusCause{
{
Type: constants.ExitCodeCauseType,
Message: fmt.Sprintf("%d", rc),
},
},
},
Message: fmt.Sprintf("command terminated with non-zero exit code: %v", exitErr),
},
}
}

View File

@@ -0,0 +1,872 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math"
mrand "math/rand"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
rcconstants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/remotecommand"
"k8s.io/client-go/transport"
)
// TestStreamTranslator_LoopbackStdinToStdout returns random data sent on the client's
// STDIN channel back onto the client's STDOUT channel. There are two servers in this test: the
// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the
// data received from the websocket client upstream to the SPDY server (by translating the
// websocket data into spdy). The returned data read on the websocket client STDOUT is then
// compared the random data sent on STDIN to ensure they are the same.
func TestStreamTranslator_LoopbackStdinToStdout(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Loopback STDIN data onto STDOUT stream.
_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN and STDOUT. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
// TestStreamTranslator_LoopbackStdinToStderr returns random data sent on the client's
// STDIN channel back onto the client's STDERR channel. There are two servers in this test: the
// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the
// data received from the websocket client upstream to the SPDY server (by translating the
// websocket data into spdy). The returned data read on the websocket client STDERR is then
// compared the random data sent on STDIN to ensure they are the same.
func TestStreamTranslator_LoopbackStdinToStderr(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Loopback STDIN data onto STDERR stream.
_, err = io.Copy(ctx.stderrStream, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to STDERR: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stderr: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
data, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
// Returns a random exit code in the range(1-127).
func randomExitCode() int {
errorCode := mrand.Intn(127) // Range: (0 - 126)
errorCode += 1 // Range: (1 - 127)
return errorCode
}
// TestStreamTranslator_ErrorStream tests the error stream by sending an error with a random
// exit code, then validating the error arrives on the error stream.
func TestStreamTranslator_ErrorStream(t *testing.T) {
expectedExitCode := randomExitCode()
// Create upstream fake SPDY server, returning a non-zero exit code
// on error stream within the structured error.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdout: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Read/discard STDIN data before returning error on error stream.
_, err = io.Copy(io.Discard, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to DISCARD: %v", err)
}
// Force an non-zero exit code error returned on the error stream.
err = ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
Status: metav1.StatusFailure,
Reason: rcconstants.NonZeroExitCodeReason,
Details: &metav1.StatusDetails{
Causes: []metav1.StatusCause{
{
Type: rcconstants.ExitCodeCauseType,
Message: fmt.Sprintf("%d", expectedExitCode),
},
},
},
}})
if err != nil {
t.Fatalf("error writing status: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server, and
// create a test server using the StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be discarded at
// upstream SDPY server.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Expect exit code error on error stream.
if err == nil {
t.Errorf("expected error, but received none")
}
expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode)
// Compare expected error with exit code to actual error.
if expectedError != err.Error() {
t.Errorf("expected error (%s), got (%s)", expectedError, err)
}
}
}
// TestStreamTranslator_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from
// the connections at the same time.
func TestStreamTranslator_MultipleReadChannels(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// TeeReader copies data read on STDIN onto STDERR.
stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream)
// Also copy STDIN to STDOUT.
_, err = io.Copy(ctx.stdoutStream, stdinReader)
if err != nil {
t.Errorf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true, Stderr: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT and STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout, stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(stdoutBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData))
}
stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(stderrBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData))
}
}
// TestStreamTranslator_ThrottleReadChannels tests two streams (STDOUT, STDERR) using rate limited streams.
func TestStreamTranslator_ThrottleReadChannels(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// TeeReader copies data read on STDIN onto STDERR.
stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream)
// Also copy STDIN to STDOUT.
_, err = io.Copy(ctx.stdoutStream, stdinReader)
if err != nil {
t.Errorf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true, Stderr: true}
maxBytesPerSec := 900 * 1024 // slightly less than the 1MB that is being transferred to exercise throttling.
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, int64(maxBytesPerSec), streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT and STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout, stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(stdoutBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData))
}
stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(stderrBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData))
}
}
// fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of
// "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice.
type fakeTerminalSizeQueue struct {
maxSizes int
terminalSizes []remotecommand.TerminalSize
}
// newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing
// "max" number of random TerminalSizes created.
func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue {
return &fakeTerminalSizeQueue{
maxSizes: max,
terminalSizes: make([]remotecommand.TerminalSize, 0, max),
}
}
// Next returns a pointer to the next random TerminalSize, or nil if we have
// already returned "maxSizes" TerminalSizes already. Stores the randomly
// created TerminalSize in "terminalSizes" field for later validation.
func (f *fakeTerminalSizeQueue) Next() *remotecommand.TerminalSize {
if len(f.terminalSizes) >= f.maxSizes {
return nil
}
size := randomTerminalSize()
f.terminalSizes = append(f.terminalSizes, size)
return &size
}
// randomTerminalSize returns a TerminalSize with random values in the
// range (0-65535) for the fields Width and Height.
func randomTerminalSize() remotecommand.TerminalSize {
randWidth := uint16(mrand.Intn(int(math.Pow(2, 16))))
randHeight := uint16(mrand.Intn(int(math.Pow(2, 16))))
return remotecommand.TerminalSize{
Width: randWidth,
Height: randHeight,
}
}
// TestStreamTranslator_MultipleWriteChannels
func TestStreamTranslator_TTYResizeChannel(t *testing.T) {
// Create the fake terminal size queue and the actualTerminalSizes which
// will be received at the opposite websocket endpoint.
numSizeQueue := 10000
sizeQueue := newTerminalSizeQueue(numSizeQueue)
actualTerminalSizes := make([]remotecommand.TerminalSize, 0, numSizeQueue)
// Create upstream fake SPDY server which copies STDIN back onto STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Tty: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Read the terminal resize requests, storing them in actualTerminalSizes
for i := 0; i < numSizeQueue; i++ {
actualTerminalSize := <-ctx.resizeChan
actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// resize (TTY resize) stream. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Tty: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
options := &remotecommand.StreamOptions{
Tty: true,
TerminalSizeQueue: sizeQueue,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
// Validate the random TerminalSizes sent on the resize stream are the same
// as the actual TerminalSizes received at the websocket server.
if len(actualTerminalSizes) != numSizeQueue {
t.Fatalf("expected to receive num terminal resizes (%d), got (%d)",
numSizeQueue, len(actualTerminalSizes))
}
for i, actual := range actualTerminalSizes {
expected := sizeQueue.terminalSizes[i]
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected terminal resize window %v, got %v", expected, actual)
}
}
}
// TestStreamTranslator_WebSocketServerErrors validates that when there is a problem creating
// the websocket server as the first step of the StreamTranslator an error is properly returned.
func TestStreamTranslator_WebSocketServerErrors(t *testing.T) {
spdyLocation, err := url.Parse("http://127.0.0.1")
if err != nil {
t.Fatalf("Unable to parse spdy server URL")
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, Options{})
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutorForProtocols(
&rest.Config{Host: streamTranslatorLocation.Host},
"GET",
streamTranslatorServer.URL,
rcconstants.StreamProtocolV4Name, // RemoteCommand V4 protocol is unsupported
)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client. The WebSocket server within the
// StreamTranslator propagates an error here because the V4 protocol is not supported.
errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{})
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Must return "websocket unable to upgrade" (bad handshake) error.
if err == nil {
t.Fatalf("expected error, but received none")
}
if !strings.Contains(err.Error(), "unable to upgrade streaming request") {
t.Errorf("expected websocket bad handshake error, got (%s)", err)
}
}
}
// TestStreamTranslator_BlockRedirects verifies that the StreamTranslator will *not* follow
// redirects; it will thrown an error instead.
func TestStreamTranslator_BlockRedirects(t *testing.T) {
for _, statusCode := range []int{
http.StatusMovedPermanently, // 301
http.StatusFound, // 302
http.StatusSeeOther, // 303
http.StatusTemporaryRedirect, // 307
http.StatusPermanentRedirect, // 308
} {
// Create upstream fake SPDY server which returns a redirect.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Location", "/")
w.WriteHeader(statusCode)
}))
defer spdyServer.Close()
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdout: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
// Should return "redirect not allowed" error.
errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{})
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Must return "redirect now allowed" error.
if err == nil {
t.Fatalf("expected error, but received none")
}
if !strings.Contains(err.Error(), "redirect not allowed") {
t.Errorf("expected redirect not allowed error, got (%s)", err)
}
}
}
}
// streamContext encapsulates the structures necessary to communicate through
// a SPDY connection, including the Reader/Writer streams.
type streamContext struct {
conn io.Closer
stdinStream io.ReadCloser
stdoutStream io.WriteCloser
stderrStream io.WriteCloser
resizeStream io.ReadCloser
resizeChan chan remotecommand.TerminalSize
writeStatus func(status *apierrors.StatusError) error
}
type streamAndReply struct {
httpstream.Stream
replySent <-chan struct{}
}
// CreateSPDYServerStreams upgrades the passed HTTP request to a SPDY bi-directional streaming
// connection with remote command streams defined in passed options. Returns a streamContext
// structure containing the Reader/Writer streams to communicate through the SDPY connection.
// Returns an error if unable to upgrade the HTTP connection to a SPDY connection.
func createSPDYServerStreams(w http.ResponseWriter, req *http.Request, opts Options) (*streamContext, error) {
_, err := httpstream.Handshake(req, w, []string{rcconstants.StreamProtocolV4Name})
if err != nil {
return nil, err
}
upgrader := spdy.NewResponseUpgrader()
streamCh := make(chan streamAndReply)
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
return nil
})
ctx := &streamContext{
conn: conn,
}
// wait for stream
replyChan := make(chan struct{}, 5)
defer close(replyChan)
receivedStreams := 0
expectedStreams := 1 // expect at least the error stream
if opts.Stdout {
expectedStreams++
}
if opts.Stdin {
expectedStreams++
}
if opts.Stderr {
expectedStreams++
}
if opts.Tty {
expectedStreams++
}
WaitForStreams:
for {
select {
case stream := <-streamCh:
streamType := stream.Headers().Get(v1.StreamType)
switch streamType {
case v1.StreamTypeError:
replyChan <- struct{}{}
ctx.writeStatus = v4WriteStatusFunc(stream)
case v1.StreamTypeStdout:
replyChan <- struct{}{}
ctx.stdoutStream = stream
case v1.StreamTypeStdin:
replyChan <- struct{}{}
ctx.stdinStream = stream
case v1.StreamTypeStderr:
replyChan <- struct{}{}
ctx.stderrStream = stream
case v1.StreamTypeResize:
replyChan <- struct{}{}
ctx.resizeStream = stream
default:
// add other stream ...
return nil, errors.New("unimplemented stream type")
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
}
}
if ctx.resizeStream != nil {
ctx.resizeChan = make(chan remotecommand.TerminalSize)
go handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan)
}
return ctx, nil
}
func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
bs, err := json.Marshal(status.Status())
if err != nil {
return err
}
_, err = stream.Write(bs)
return err
}
}
func fakeTransport() (*http.Transport, error) {
cfg := &transport.Config{
TLS: transport.TLSConfig{
Insecure: true,
CAFile: "",
},
}
rt, err := transport.New(cfg)
if err != nil {
return nil, err
}
t, ok := rt.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unknown transport type: %T", rt)
}
return t, nil
}

View File

@@ -0,0 +1,51 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"net/http"
"k8s.io/klog/v2"
)
// translatingHandler wraps the delegate handler, implementing the
// http.Handler interface. The delegate handles all requests unless
// the request satisfies the passed "shouldTranslate" function
// (currently only for WebSocket/V5 request), in which case the translator
// handles the request.
type translatingHandler struct {
delegate http.Handler
translator http.Handler
shouldTranslate func(*http.Request) bool
}
func NewTranslatingHandler(delegate http.Handler, translator http.Handler, shouldTranslate func(*http.Request) bool) http.Handler {
return &translatingHandler{
delegate: delegate,
translator: translator,
shouldTranslate: shouldTranslate,
}
}
func (t *translatingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if t.shouldTranslate(req) {
klog.V(4).Infof("request handled by translator proxy")
t.translator.ServeHTTP(w, req)
return
}
t.delegate.ServeHTTP(w, req)
}

View File

@@ -0,0 +1,121 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
)
// fakeHandler implements http.Handler interface
type fakeHandler struct {
served bool
}
// ServeHTTP stores the fact that this fake handler was called.
func (fh *fakeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
fh.served = true
}
func TestTranslatingHandler(t *testing.T) {
tests := map[string]struct {
upgrade string
version string
expectTranslator bool
}{
"websocket/v5 upgrade, serves translator": {
upgrade: "websocket",
version: "v5.channel.k8s.io",
expectTranslator: true,
},
"websocket/v5 upgrade with multiple other versions, serves translator": {
upgrade: "websocket",
version: "v5.channel.k8s.io, v4.channel.k8s.io, v3.channel.k8s.io",
expectTranslator: true,
},
"websocket/v5 upgrade with multiple other versions out of order, serves translator": {
upgrade: "websocket",
version: "v4.channel.k8s.io, v3.channel.k8s.io, v5.channel.k8s.io",
expectTranslator: true,
},
"no upgrade, serves delegate": {
upgrade: "",
version: "",
expectTranslator: false,
},
"no upgrade with v5, serves delegate": {
upgrade: "",
version: "v5.channel.k8s.io",
expectTranslator: false,
},
"websocket/v5 wrong case upgrade, serves delegage": {
upgrade: "websocket",
version: "v5.CHANNEL.k8s.io",
expectTranslator: false,
},
"spdy/v5 upgrade, serves delegate": {
upgrade: "spdy",
version: "v5.channel.k8s.io",
expectTranslator: false,
},
"spdy/v4 upgrade, serves delegate": {
upgrade: "spdy",
version: "v4.channel.k8s.io",
expectTranslator: false,
},
"websocket/v4 upgrade, serves delegate": {
upgrade: "websocket",
version: "v4.channel.k8s.io",
expectTranslator: false,
},
"websocket without version upgrade, serves delegate": {
upgrade: "websocket",
version: "",
expectTranslator: false,
},
}
for name, test := range tests {
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
require.NoError(t, err)
if test.upgrade != "" {
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", test.upgrade)
}
if len(test.version) > 0 {
req.Header.Add(wsstream.WebSocketProtocolHeader, test.version)
}
delegate := fakeHandler{}
translator := fakeHandler{}
translatingHandler := NewTranslatingHandler(&delegate, &translator,
wsstream.IsWebSocketRequestWithStreamCloseProtocol)
translatingHandler.ServeHTTP(nil, req)
if !delegate.served && !translator.served {
t.Errorf("unexpected neither translator nor delegate served")
continue
}
if test.expectTranslator {
if !translator.served {
t.Errorf("%s: expected translator served, got delegate served", name)
}
} else if !delegate.served {
t.Errorf("%s: expected delegate served, got translator served", name)
}
}
}

View File

@@ -0,0 +1,200 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
constants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/tools/remotecommand"
)
const (
// idleTimeout is the read/write deadline set for websocket server connection. Reading
// or writing the connection will return an i/o timeout if this deadline is exceeded.
// Currently, we use the same value as the kubelet websocket server.
defaultIdleConnectionTimeout = 4 * time.Hour
// Deadline for writing errors to the websocket connection before io/timeout.
writeErrorDeadline = 10 * time.Second
)
// Options contains details about which streams are required for
// remote command execution.
type Options struct {
Stdin bool
Stdout bool
Stderr bool
Tty bool
}
// conns contains the connection and streams used when
// forwarding an attach or execute session into a container.
type conns struct {
conn io.Closer
stdinStream io.ReadCloser
stdoutStream io.WriteCloser
stderrStream io.WriteCloser
writeStatus func(status *apierrors.StatusError) error
resizeStream io.ReadCloser
resizeChan chan remotecommand.TerminalSize
tty bool
}
// Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed
// in the stream options.
func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) {
ctx, err := createWebSocketStreams(req, w, opts)
if err != nil {
return nil, err
}
if ctx.resizeStream != nil {
ctx.resizeChan = make(chan remotecommand.TerminalSize)
go func() {
// Resize channel closes in panic case, and panic does not take down caller.
defer func() {
if p := recover(); p != nil {
// Standard panic logging.
for _, fn := range runtime.PanicHandlers {
fn(p)
}
}
}()
handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan)
}()
}
return ctx, nil
}
// Read terminal resize events off of passed stream and queue into passed channel.
func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- remotecommand.TerminalSize) {
defer close(channel)
decoder := json.NewDecoder(stream)
for {
size := remotecommand.TerminalSize{}
if err := decoder.Decode(&size); err != nil {
break
}
select {
case channel <- size:
case <-ctx.Done():
// To avoid leaking this routine, exit if the http request finishes. This path
// would generally be hit if starting the process fails and nothing is started to
// ingest these resize events.
return
}
}
}
// createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
// along with the approximate duplex value. It also creates the error (3) and resize (4) channels.
func createChannels(opts Options) []wsstream.ChannelType {
// open the requested channels, and always open the error channel
channels := make([]wsstream.ChannelType, 5)
channels[constants.StreamStdIn] = readChannel(opts.Stdin)
channels[constants.StreamStdOut] = writeChannel(opts.Stdout)
channels[constants.StreamStdErr] = writeChannel(opts.Stderr)
channels[constants.StreamErr] = wsstream.WriteChannel
channels[constants.StreamResize] = wsstream.ReadChannel
return channels
}
// readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel.
func readChannel(real bool) wsstream.ChannelType {
if real {
return wsstream.ReadChannel
}
return wsstream.IgnoreChannel
}
// writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel.
func writeChannel(real bool) wsstream.ChannelType {
if real {
return wsstream.WriteChannel
}
return wsstream.IgnoreChannel
}
// createWebSocketStreams returns a "conns" struct containing the websocket connection and
// streams needed to perform an exec or an attach.
func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) {
channels := createChannels(opts)
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
// WebSocket server only supports remote command version 5.
constants.StreamProtocolV5Name: {
Binary: true,
Channels: channels,
},
})
conn.SetIdleTimeout(defaultIdleConnectionTimeout)
// Opening the connection responds to WebSocket client, negotiating
// the WebSocket upgrade connection and the subprotocol.
_, streams, err := conn.Open(w, req)
if err != nil {
return nil, err
}
// Send an empty message to the lowest writable channel to notify the client the connection is established
switch {
case opts.Stdout:
_, err = streams[constants.StreamStdOut].Write([]byte{})
case opts.Stderr:
_, err = streams[constants.StreamStdErr].Write([]byte{})
default:
_, err = streams[constants.StreamErr].Write([]byte{})
}
if err != nil {
conn.Close()
return nil, fmt.Errorf("write error during websocket server creation: %v", err)
}
ctx := &conns{
conn: conn,
stdinStream: streams[constants.StreamStdIn],
stdoutStream: streams[constants.StreamStdOut],
stderrStream: streams[constants.StreamStdErr],
tty: opts.Tty,
resizeStream: streams[constants.StreamResize],
}
// writeStatus returns a WriteStatusFunc that marshals a given api Status
// as json in the error channel.
ctx.writeStatus = func(status *apierrors.StatusError) error {
bs, err := json.Marshal(status.Status())
if err != nil {
return err
}
// Write status error to error stream with deadline.
conn.SetWriteDeadline(writeErrorDeadline)
_, err = streams[constants.StreamErr].Write(bs)
return err
}
return ctx, nil
}

View File

@@ -49,6 +49,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.13.0 // indirect

View File

@@ -75,6 +75,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=

View File

@@ -0,0 +1,57 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"context"
)
var _ Executor = &fallbackExecutor{}
type fallbackExecutor struct {
primary Executor
secondary Executor
shouldFallback func(error) bool
}
// NewFallbackExecutor creates an Executor that first attempts to use the
// WebSocketExecutor, falling back to the legacy SPDYExecutor if the initial
// websocket "StreamWithContext" call fails.
// func NewFallbackExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) {
func NewFallbackExecutor(primary, secondary Executor, shouldFallback func(error) bool) (Executor, error) {
return &fallbackExecutor{
primary: primary,
secondary: secondary,
shouldFallback: shouldFallback,
}, nil
}
// Stream is deprecated. Please use "StreamWithContext".
func (f *fallbackExecutor) Stream(options StreamOptions) error {
return f.StreamWithContext(context.Background(), options)
}
// StreamWithContext initially attempts to call "StreamWithContext" using the
// primary executor, falling back to calling the secondary executor if the
// initial primary call to upgrade to a websocket connection fails.
func (f *fallbackExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error {
err := f.primary.StreamWithContext(ctx, options)
if f.shouldFallback(err) {
return f.secondary.StreamWithContext(ctx, options)
}
return err
}

View File

@@ -0,0 +1,227 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"bytes"
"context"
"crypto/rand"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
)
func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) {
// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
if err != nil {
w.WriteHeader(http.StatusForbidden)
return
}
defer conns.conn.Close()
// Loopback the STDIN stream onto the STDOUT stream.
_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
require.NoError(t, err)
}))
defer websocketServer.Close()
// Now create the fallback client (executor), and point it to the "websocketServer".
// Must add STDIN and STDOUT query params for the client request.
websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
require.NoError(t, err)
spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
require.NoError(t, err)
// Never fallback, so always use the websocketExecutor, which succeeds against websocket server.
exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false })
require.NoError(t, err)
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout bytes.Buffer
options := &StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error")
}
}
data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) {
// Create fake SPDY server. Copy received STDIN data back onto STDOUT stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var stdin, stdout bytes.Buffer
ctx, err := createHTTPStreams(w, req, &StreamOptions{
Stdin: &stdin,
Stdout: &stdout,
})
if err != nil {
w.WriteHeader(http.StatusForbidden)
return
}
defer ctx.conn.Close()
_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
spdyLocation, err := url.Parse(spdyServer.URL)
require.NoError(t, err)
websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL)
require.NoError(t, err)
spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation)
require.NoError(t, err)
// Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server.
exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
require.NoError(t, err)
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout bytes.Buffer
options := &StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
}
errorChan := make(chan error)
go func() {
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error")
}
}
data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) {
// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
if err != nil {
w.WriteHeader(http.StatusForbidden)
return
}
defer conns.conn.Close()
// Loopback the STDIN stream onto the STDOUT stream.
_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
require.NoError(t, err)
}))
defer websocketServer.Close()
// Now create the fallback client (executor), and point it to the "websocketServer".
// Must add STDIN and STDOUT query params for the client request.
websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
require.NoError(t, err)
spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
require.NoError(t, err)
// Always fallback to spdyExecutor, but spdyExecutor fails against websocket server.
exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
require.NoError(t, err)
// Update the websocket executor to request remote command v4, which is unsupported.
fallbackExec, ok := exec.(*fallbackExecutor)
assert.True(t, ok, "error casting executor as fallbackExecutor")
websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor)
assert.True(t, ok, "error casting executor as websocket executor")
// Set the attempted subprotocol version to V4; websocket server only accepts V5.
websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout bytes.Buffer
options := &StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
}
errorChan := make(chan error)
go func() {
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Ensure secondary executor returned an error.
require.Error(t, err)
}
}

View File

@@ -34,9 +34,10 @@ type spdyStreamExecutor struct {
upgrader spdy.Upgrader
transport http.RoundTripper
method string
url *url.URL
protocols []string
method string
url *url.URL
protocols []string
rejectRedirects bool // if true, receiving redirect from upstream is an error
}
// NewSPDYExecutor connects to the provided server and upgrades the connection to
@@ -49,6 +50,20 @@ func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Ex
return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url)
}
// NewSPDYExecutorRejectRedirects returns an Executor that will upgrade the future
// connection to a SPDY bi-directional streaming connection when calling "Stream" (deprecated)
// or "StreamWithContext" (preferred). Additionally, if the upstream server returns a redirect
// during the attempted upgrade in these "Stream" calls, an error is returned.
func NewSPDYExecutorRejectRedirects(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) {
executor, err := NewSPDYExecutorForTransports(transport, upgrader, method, url)
if err != nil {
return nil, err
}
spdyExecutor := executor.(*spdyStreamExecutor)
spdyExecutor.rejectRedirects = true
return spdyExecutor, nil
}
// NewSPDYExecutorForTransports connects to the provided server using the given transport,
// upgrades the response using the given upgrader to multiplexed bidirectional streams.
func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) {
@@ -88,9 +103,15 @@ func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options
return nil, nil, fmt.Errorf("error creating request: %v", err)
}
client := http.Client{Transport: e.transport}
if e.rejectRedirects {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return fmt.Errorf("redirect not allowed")
}
}
conn, protocol, err := spdy.Negotiate(
e.upgrader,
&http.Client{Transport: e.transport},
&client,
req,
e.protocols...,
)

View File

@@ -183,6 +183,7 @@ func TestSPDYExecutorStream(t *testing.T) {
}
func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
//nolint:errcheck
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx, err := createHTTPStreams(writer, request, options)
if err != nil {
@@ -381,7 +382,7 @@ func TestStreamRandomData(t *testing.T) {
}
defer ctx.conn.Close()
io.Copy(ctx.stdoutStream, ctx.stdinStream)
io.Copy(ctx.stdoutStream, ctx.stdinStream) //nolint:errcheck
}))
defer server.Close()

View File

@@ -85,22 +85,26 @@ type wsStreamExecutor struct {
heartbeatDeadline time.Duration
}
// NewWebSocketExecutor allows to execute commands via a WebSocket connection.
func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) {
// Only supports V5 protocol for correct version skew functionality.
// Previous api servers will proxy upgrade requests to legacy websocket
// servers on container runtimes which support V1-V4. These legacy
// websocket servers will not handle the new CLOSE signal.
return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name)
}
// NewWebSocketExecutorForProtocols allows to execute commands via a WebSocket connection.
func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) {
transport, upgrader, err := websocket.RoundTripperFor(config)
if err != nil {
return nil, fmt.Errorf("error creating websocket transports: %v", err)
}
return &wsStreamExecutor{
transport: transport,
upgrader: upgrader,
method: method,
url: url,
// Only supports V5 protocol for correct version skew functionality.
// Previous api servers will proxy upgrade requests to legacy websocket
// servers on container runtimes which support V1-V4. These legacy
// websocket servers will not handle the new CLOSE signal.
protocols: []string{remotecommand.StreamProtocolV5Name},
transport: transport,
upgrader: upgrader,
method: method,
url: url,
protocols: protocols,
heartbeatPeriod: pingPeriod,
heartbeatDeadline: pingReadDeadline,
}, nil
@@ -177,10 +181,12 @@ func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options Stream
}
type wsStreamCreator struct {
conn *gwebsocket.Conn
conn *gwebsocket.Conn
// Protects writing to websocket connection; reading is lock-free
connWriteLock sync.Mutex
streams map[byte]*stream
streamsMu sync.Mutex
// map of stream id to stream; multiple streams read/write the connection
streams map[byte]*stream
streamsMu sync.Mutex
}
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
@@ -226,7 +232,7 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
return s, nil
}
// readDemuxLoop is the reading processor for this endpoint of the websocket
// readDemuxLoop is the lock-free reading processor for this endpoint of the websocket
// connection. This loop reads the connection, and demultiplexes the data
// into one of the individual stream pipes (by checking the stream id). This
// loop can *not* be run concurrently, because there can only be one websocket

View File

@@ -74,7 +74,7 @@ func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -149,7 +149,7 @@ func TestWebSocketClient_DifferentBufferSizes(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -223,7 +223,7 @@ func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -304,7 +304,7 @@ func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -377,7 +377,7 @@ func TestWebSocketClient_MultipleReadChannels(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -479,7 +479,7 @@ func TestWebSocketClient_ErrorStream(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -637,7 +637,7 @@ func TestWebSocketClient_MultipleWriteChannels(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -723,7 +723,7 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -766,11 +766,14 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) {
func TestWebSocketClient_BadHandshake(t *testing.T) {
// Create fake WebSocket server (supports V5 subprotocol).
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
if err != nil {
t.Fatalf("error on webSocketServerStreams: %v", err)
// Bad handshake means websocket server will not completely initialize.
_, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
if err == nil {
t.Fatalf("expected error, but received none.")
}
if !strings.Contains(err.Error(), "websocket server finished before becoming ready") {
t.Errorf("expected websocket server error, but got: %v", err)
}
defer conns.conn.Close()
}))
defer websocketServer.Close()
@@ -779,7 +782,7 @@ func TestWebSocketClient_BadHandshake(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -831,7 +834,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -909,7 +912,7 @@ func TestWebSocketClient_TextMessageTypeError(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -970,7 +973,7 @@ func TestWebSocketClient_EmptyMessageHandled(t *testing.T) {
if err != nil {
t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
}
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL)
exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -1009,14 +1012,14 @@ func TestWebSocketClient_ExecutorErrors(t *testing.T) {
ExecProvider: &clientcmdapi.ExecConfig{},
AuthProvider: &clientcmdapi.AuthProviderConfig{},
}
_, err := NewWebSocketExecutor(&config, "POST", "http://localhost")
_, err := NewWebSocketExecutor(&config, "GET", "http://localhost")
if err == nil {
t.Errorf("expecting executor constructor error, but received none.")
} else if !strings.Contains(err.Error(), "error creating websocket transports") {
t.Errorf("expecting error creating transports, got (%s)", err.Error())
}
// Verify that a nil context will cause an error in StreamWithContext
exec, err := NewWebSocketExecutor(&rest.Config{}, "POST", "http://localhost")
exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost")
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
@@ -1316,7 +1319,16 @@ func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *opti
resizeStream: streams[remotecommand.StreamResize],
}
wsStreams.writeStatus = v4WriteStatusFunc(streams[remotecommand.StreamErr])
wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
bs, err := json.Marshal(status.Status())
if err != nil {
return err
}
_, err = stream.Write(bs)
return err
}
}(streams[remotecommand.StreamErr])
return wsStreams, nil
}

View File

@@ -43,11 +43,15 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er
if config.Proxy != nil {
proxy = config.Proxy
}
upgradeRoundTripper := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{
TLS: tlsConfig,
Proxier: proxy,
PingPeriod: time.Second * 5,
upgradeRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{
TLS: tlsConfig,
Proxier: proxy,
PingPeriod: time.Second * 5,
UpgradeTransport: nil,
})
if err != nil {
return nil, nil, err
}
wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper)
if err != nil {
return nil, nil, err

View File

@@ -108,10 +108,7 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
if err != nil {
if err != gwebsocket.ErrBadHandshake {
return nil, err
}
return nil, fmt.Errorf("unable to upgrade connection: %v", err)
return nil, &httpstream.UpgradeFailureError{Cause: err}
}
rt.Conn = wsConn
@@ -155,7 +152,7 @@ func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.
req.Header[httpstream.HeaderProtocolVersion] = protocols
resp, err := rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("error sending request: %v", err)
return nil, err
}
err = resp.Body.Close()
if err != nil {

View File

@@ -49,7 +49,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil)
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
require.NoError(t, err)
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
@@ -67,18 +67,17 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
// Create fake WebSocket server.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conns, err := webSocketServerStreams(req, w)
if err != nil {
t.Fatalf("error on webSocketServerStreams: %v", err)
}
defer conns.conn.Close()
// Bad handshake means websocket server will not completely initialize.
_, err := webSocketServerStreams(req, w)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready"))
}))
defer websocketServer.Close()
// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil)
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
require.NoError(t, err)
rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
@@ -105,7 +104,7 @@ func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
// Create the websocket roundtripper and call "Negotiate" to create websocket connection.
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil)
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
require.NoError(t, err)
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)

View File

@@ -49,6 +49,7 @@ require (
github.com/google/cel-go v0.17.6 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
github.com/imdario/mergo v0.3.6 // indirect
@@ -57,6 +58,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/moby/spdystream v0.2.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect

View File

@@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
@@ -326,6 +327,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
@@ -369,6 +371,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8=
github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c=
github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View File

@@ -28,6 +28,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/cli-runtime/pkg/genericclioptions"
"k8s.io/cli-runtime/pkg/genericiooptions"
"k8s.io/cli-runtime/pkg/resource"
@@ -125,7 +126,7 @@ func NewCmdAttach(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.
// RemoteAttach defines the interface accepted by the Attach command - provided for test stubbing
type RemoteAttach interface {
Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error
Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error
}
// DefaultAttachFunc is the default AttachFunc used
@@ -148,7 +149,7 @@ func DefaultAttachFunc(o *AttachOptions, containerToAttach *corev1.Container, ra
TTY: raw,
}, scheme.ParameterCodec)
return o.Attach.Attach("POST", req.URL(), o.Config, o.In, o.Out, o.ErrOut, raw, sizeQueue)
return o.Attach.Attach(req.URL(), o.Config, o.In, o.Out, o.ErrOut, raw, sizeQueue)
}
}
@@ -156,11 +157,24 @@ func DefaultAttachFunc(o *AttachOptions, containerToAttach *corev1.Container, ra
type DefaultRemoteAttach struct{}
// Attach executes attach to a running container
func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
exec, err := remotecommand.NewSPDYExecutor(config, method, url)
func (*DefaultRemoteAttach) Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
// Legacy SPDY executor is default. If feature gate enabled, fallback
// executor attempts websockets first--then SPDY.
exec, err := remotecommand.NewSPDYExecutor(config, "POST", url)
if err != nil {
return err
}
if cmdutil.RemoteCommandWebsockets.IsEnabled() {
// WebSocketExecutor must be "GET" method as described in RFC 6455 Sec. 4.1 (page 17).
websocketExec, err := remotecommand.NewWebSocketExecutor(config, "GET", url.String())
if err != nil {
return err
}
exec, err = remotecommand.NewFallbackExecutor(websocketExec, exec, httpstream.IsUpgradeFailure)
if err != nil {
return err
}
}
return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{
Stdin: stdin,
Stdout: stdout,

View File

@@ -43,13 +43,11 @@ import (
)
type fakeRemoteAttach struct {
method string
url *url.URL
err error
url *url.URL
err error
}
func (f *fakeRemoteAttach) Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
f.method = method
func (f *fakeRemoteAttach) Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
f.url = url
return f.err
}
@@ -327,7 +325,7 @@ func TestAttach(t *testing.T) {
return err
}
return options.Attach.Attach("POST", u, nil, nil, nil, nil, raw, sizeQueue)
return options.Attach.Attach(u, nil, nil, nil, nil, raw, sizeQueue)
}
}
@@ -347,9 +345,6 @@ func TestAttach(t *testing.T) {
t.Errorf("%s: Did not get expected path for exec request: %q %q", test.name, test.attachPath, remoteAttach.url.Path)
return
}
if remoteAttach.method != "POST" {
t.Errorf("%s: Did not get method for attach request: %s", test.name, remoteAttach.method)
}
if remoteAttach.url.Query().Get("container") != "bar" {
t.Errorf("%s: Did not have query parameters: %s", test.name, remoteAttach.url.Query())
}
@@ -428,7 +423,7 @@ func TestAttachWarnings(t *testing.T) {
return err
}
return options.Attach.Attach("POST", u, nil, nil, nil, nil, raw, sizeQueue)
return options.Attach.Attach(u, nil, nil, nil, nil, raw, sizeQueue)
}
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/spf13/cobra"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/cli-runtime/pkg/genericclioptions"
"k8s.io/cli-runtime/pkg/genericiooptions"
"k8s.io/cli-runtime/pkg/resource"
@@ -113,17 +114,30 @@ func NewCmdExec(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Co
// RemoteExecutor defines the interface accepted by the Exec command - provided for test stubbing
type RemoteExecutor interface {
Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error
Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error
}
// DefaultRemoteExecutor is the standard implementation of remote command execution
type DefaultRemoteExecutor struct{}
func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
exec, err := remotecommand.NewSPDYExecutor(config, method, url)
func (*DefaultRemoteExecutor) Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
// Legacy SPDY executor is default. If feature gate enabled, fallback
// executor attempts websockets first--then SPDY.
exec, err := remotecommand.NewSPDYExecutor(config, "POST", url)
if err != nil {
return err
}
if cmdutil.RemoteCommandWebsockets.IsEnabled() {
// WebSocketExecutor must be "GET" method as described in RFC 6455 Sec. 4.1 (page 17).
websocketExec, err := remotecommand.NewWebSocketExecutor(config, "GET", url.String())
if err != nil {
return err
}
exec, err = remotecommand.NewFallbackExecutor(websocketExec, exec, httpstream.IsUpgradeFailure)
if err != nil {
return err
}
}
return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{
Stdin: stdin,
Stdout: stdout,
@@ -371,7 +385,7 @@ func (p *ExecOptions) Run() error {
TTY: t.Raw,
}, scheme.ParameterCodec)
return p.Executor.Execute("POST", req.URL(), p.Config, p.In, p.Out, p.ErrOut, t.Raw, sizeQueue)
return p.Executor.Execute(req.URL(), p.Config, p.In, p.Out, p.ErrOut, t.Raw, sizeQueue)
}
if err := t.Safe(fn); err != nil {

View File

@@ -40,13 +40,11 @@ import (
)
type fakeRemoteExecutor struct {
method string
url *url.URL
execErr error
}
func (f *fakeRemoteExecutor) Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
f.method = method
func (f *fakeRemoteExecutor) Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error {
f.url = url
return f.execErr
}
@@ -264,9 +262,6 @@ func TestExec(t *testing.T) {
t.Errorf("%s: Did not get expected container query param for exec request", test.name)
return
}
if ex.method != "POST" {
t.Errorf("%s: Did not get method for exec request: %s", test.name, ex.method)
}
})
}
}

View File

@@ -425,8 +425,10 @@ func GetPodRunningTimeoutFlag(cmd *cobra.Command) (time.Duration, error) {
type FeatureGate string
const (
ApplySet FeatureGate = "KUBECTL_APPLYSET"
CmdPluginAsSubcommand FeatureGate = "KUBECTL_ENABLE_CMD_SHADOW"
ApplySet FeatureGate = "KUBECTL_APPLYSET"
CmdPluginAsSubcommand FeatureGate = "KUBECTL_ENABLE_CMD_SHADOW"
InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
)
// IsEnabled returns true iff environment variable is set to true.

View File

@@ -35,6 +35,7 @@ require (
github.com/moby/spdystream v0.2.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.16.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect

View File

@@ -111,6 +111,7 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=

View File

@@ -42,6 +42,8 @@ import (
"sigs.k8s.io/yaml"
utilkubectl "k8s.io/kubectl/pkg/cmd/util"
v1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
@@ -801,6 +803,66 @@ metadata:
framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-3", metav1.DeleteOptions{}))
})
ginkgo.It("should support inline execution and attach with websockets or fallback to spdy", func(ctx context.Context) {
waitForStdinContent := func(pod, content string) string {
var logOutput string
err := wait.PollUntilContextTimeout(ctx, 10*time.Second, 5*time.Minute, false, func(ctx context.Context) (bool, error) {
logOutput = e2ekubectl.RunKubectlOrDie(ns, "logs", pod)
return strings.Contains(logOutput, content), nil
})
framework.ExpectNoError(err, "waiting for '%v' output", content)
return logOutput
}
ginkgo.By("executing a command with run and attach with stdin")
// We wait for a non-empty line so we know kubectl has attached
e2ekubectl.NewKubectlCommand(ns, "run", "run-test", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--stdin", "--", "sh", "-c", "echo -n read: && cat && echo 'stdin closed'").
WithStdinData("value\nabcd1234").
AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}).
ExecOrDie(ns)
runOutput := waitForStdinContent("run-test", "stdin closed")
gomega.Expect(runOutput).To(gomega.ContainSubstring("read:value"))
gomega.Expect(runOutput).To(gomega.ContainSubstring("abcd1234"))
gomega.Expect(runOutput).To(gomega.ContainSubstring("stdin closed"))
framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test", metav1.DeleteOptions{}))
ginkgo.By("executing a command with run and attach without stdin")
// There is a race on this scenario described in #73099
// It fails if we are not able to attach before the container prints
// "stdin closed", but hasn't exited yet.
// We wait 10 seconds before printing to give time to kubectl to attach
// to the container, this does not solve the race though.
e2ekubectl.NewKubectlCommand(ns, "run", "run-test-2", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--leave-stdin-open=true", "--", "sh", "-c", "cat && echo 'stdin closed'").
WithStdinData("abcd1234").
AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}).
ExecOrDie(ns)
runOutput = waitForStdinContent("run-test-2", "stdin closed")
gomega.Expect(runOutput).ToNot(gomega.ContainSubstring("abcd1234"))
gomega.Expect(runOutput).To(gomega.ContainSubstring("stdin closed"))
framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-2", metav1.DeleteOptions{}))
ginkgo.By("executing a command with run and attach with stdin with open stdin should remain running")
e2ekubectl.NewKubectlCommand(ns, "run", "run-test-3", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--leave-stdin-open=true", "--stdin", "--", "sh", "-c", "cat && echo 'stdin closed'").
WithStdinData("abcd1234\n").
AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}).
ExecOrDie(ns)
runOutput = waitForStdinContent("run-test-3", "abcd1234")
gomega.Expect(runOutput).To(gomega.ContainSubstring("abcd1234"))
gomega.Expect(runOutput).ToNot(gomega.ContainSubstring("stdin closed"))
g := func(pods []*v1.Pod) sort.Interface { return sort.Reverse(controller.ActivePods(pods)) }
runTestPod, _, err := polymorphichelpers.GetFirstPod(f.ClientSet.CoreV1(), ns, "run=run-test-3", 1*time.Minute, g)
framework.ExpectNoError(err)
framework.ExpectNoError(e2epod.WaitTimeoutForPodReadyInNamespace(ctx, c, runTestPod.Name, ns, time.Minute))
framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-3", metav1.DeleteOptions{}))
})
ginkgo.It("should contain last line of the log", func(ctx context.Context) {
podName := "run-log-test"