Adds metrics to PortForward Websockets
This commit is contained in:
		| @@ -33,7 +33,7 @@ var registerMetricsOnce sync.Once | |||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	// streamTranslatorRequestsTotal counts the number of requests that were handled by | 	// streamTranslatorRequestsTotal counts the number of requests that were handled by | ||||||
| 	// the StreamTranslatorProxy. | 	// the StreamTranslatorProxy (RemoteCommand subprotocol). | ||||||
| 	streamTranslatorRequestsTotal = metrics.NewCounterVec( | 	streamTranslatorRequestsTotal = metrics.NewCounterVec( | ||||||
| 		&metrics.CounterOpts{ | 		&metrics.CounterOpts{ | ||||||
| 			Subsystem:      subsystem, | 			Subsystem:      subsystem, | ||||||
| @@ -43,19 +43,37 @@ var ( | |||||||
| 		}, | 		}, | ||||||
| 		[]string{statuscode}, | 		[]string{statuscode}, | ||||||
| 	) | 	) | ||||||
|  | 	// streamTunnelRequestsTotal counts the number of requests that were handled by | ||||||
|  | 	// the StreamTunnelProxy (PortForward subprotocol). | ||||||
|  | 	streamTunnelRequestsTotal = metrics.NewCounterVec( | ||||||
|  | 		&metrics.CounterOpts{ | ||||||
|  | 			Subsystem:      subsystem, | ||||||
|  | 			Name:           "stream_tunnel_requests_total", | ||||||
|  | 			Help:           "Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2", | ||||||
|  | 			StabilityLevel: metrics.ALPHA, | ||||||
|  | 		}, | ||||||
|  | 		[]string{statuscode}, | ||||||
|  | 	) | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Register() { | func Register() { | ||||||
| 	registerMetricsOnce.Do(func() { | 	registerMetricsOnce.Do(func() { | ||||||
| 		legacyregistry.MustRegister(streamTranslatorRequestsTotal) | 		legacyregistry.MustRegister(streamTranslatorRequestsTotal) | ||||||
|  | 		legacyregistry.MustRegister(streamTunnelRequestsTotal) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func ResetForTest() { | func ResetForTest() { | ||||||
| 	streamTranslatorRequestsTotal.Reset() | 	streamTranslatorRequestsTotal.Reset() | ||||||
|  | 	streamTunnelRequestsTotal.Reset() | ||||||
| } | } | ||||||
|  |  | ||||||
| // IncStreamTranslatorRequest increments the # of requests handled by the StreamTranslatorProxy. | // IncStreamTranslatorRequest increments the # of requests handled by the StreamTranslatorProxy. | ||||||
| func IncStreamTranslatorRequest(ctx context.Context, status string) { | func IncStreamTranslatorRequest(ctx context.Context, status string) { | ||||||
| 	streamTranslatorRequestsTotal.WithContext(ctx).WithLabelValues(status).Add(1) | 	streamTranslatorRequestsTotal.WithContext(ctx).WithLabelValues(status).Add(1) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // IncStreamTunnelRequest increments the # of requests handled by the StreamTunnelProxy. | ||||||
|  | func IncStreamTunnelRequest(ctx context.Context, status string) { | ||||||
|  | 	streamTunnelRequestsTotal.WithContext(ctx).WithLabelValues(status).Add(1) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -19,10 +19,12 @@ package proxy | |||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -34,6 +36,7 @@ import ( | |||||||
| 	"k8s.io/apimachinery/pkg/util/httpstream/wsstream" | 	"k8s.io/apimachinery/pkg/util/httpstream/wsstream" | ||||||
| 	utilnet "k8s.io/apimachinery/pkg/util/net" | 	utilnet "k8s.io/apimachinery/pkg/util/net" | ||||||
| 	constants "k8s.io/apimachinery/pkg/util/portforward" | 	constants "k8s.io/apimachinery/pkg/util/portforward" | ||||||
|  | 	"k8s.io/apiserver/pkg/util/proxy/metrics" | ||||||
| 	"k8s.io/client-go/tools/portforward" | 	"k8s.io/client-go/tools/portforward" | ||||||
| 	"k8s.io/klog/v2" | 	"k8s.io/klog/v2" | ||||||
| ) | ) | ||||||
| @@ -61,6 +64,7 @@ func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |||||||
|  |  | ||||||
| 	spdyProtocols := spdyProtocolsFromWebsocketProtocols(req) | 	spdyProtocols := spdyProtocolsFromWebsocketProtocols(req) | ||||||
| 	if len(spdyProtocols) == 0 { | 	if len(spdyProtocols) == 0 { | ||||||
|  | 		metrics.IncStreamTunnelRequest(req.Context(), strconv.Itoa(http.StatusBadRequest)) | ||||||
| 		http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest) | 		http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -326,6 +330,7 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R | |||||||
| 		if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) { | 		if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) { | ||||||
| 			klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header) | 			klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header) | ||||||
| 			u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response") | 			u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response") | ||||||
|  | 			metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError)) | ||||||
| 			http.Error(u.w, u.err.Error(), http.StatusInternalServerError) | 			http.Error(u.w, u.err.Error(), http.StatusInternalServerError) | ||||||
| 			return u.err | 			return u.err | ||||||
| 		} | 		} | ||||||
| @@ -347,16 +352,20 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R | |||||||
| 		conn, err := upgrader.Upgrade(u.w, u.req, nil) | 		conn, err := upgrader.Upgrade(u.w, u.req, nil) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			klog.Errorf("error upgrading websocket connection: %v", err) | 			klog.Errorf("error upgrading websocket connection: %v", err) | ||||||
|  | 			metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError)) | ||||||
| 			u.err = err | 			u.err = err | ||||||
| 			return u.err | 			return u.err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol()) | 		klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol()) | ||||||
|  | 		metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusSwitchingProtocols)) | ||||||
| 		u.conn = portforward.NewTunnelingConnection("server", conn) | 		u.conn = portforward.NewTunnelingConnection("server", conn) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// anything other than an upgrade should pass through the backend response | 	// anything other than an upgrade should pass through the backend response | ||||||
|  | 	klog.Errorf("SPDY upgrade failed: %s", backendResponse.Status) | ||||||
|  | 	metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(backendResponse.StatusCode)) | ||||||
|  |  | ||||||
| 	// try to hijack | 	// try to hijack | ||||||
| 	conn, _, err = u.w.(http.Hijacker).Hijack() | 	conn, _, err = u.w.(http.Hijacker).Hijack() | ||||||
|   | |||||||
| @@ -40,11 +40,17 @@ import ( | |||||||
| 	"k8s.io/apimachinery/pkg/util/proxy" | 	"k8s.io/apimachinery/pkg/util/proxy" | ||||||
| 	"k8s.io/apimachinery/pkg/util/wait" | 	"k8s.io/apimachinery/pkg/util/wait" | ||||||
| 	"k8s.io/apiserver/pkg/registry/rest" | 	"k8s.io/apiserver/pkg/registry/rest" | ||||||
|  | 	"k8s.io/apiserver/pkg/util/proxy/metrics" | ||||||
| 	restconfig "k8s.io/client-go/rest" | 	restconfig "k8s.io/client-go/rest" | ||||||
| 	"k8s.io/client-go/tools/portforward" | 	"k8s.io/client-go/tools/portforward" | ||||||
|  | 	"k8s.io/component-base/metrics/legacyregistry" | ||||||
|  | 	"k8s.io/component-base/metrics/testutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { | func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { | ||||||
|  | 	metrics.Register() | ||||||
|  | 	metrics.ResetForTest() | ||||||
|  | 	t.Cleanup(metrics.ResetForTest) | ||||||
| 	// Create fake upstream SPDY server, with channel receiving SPDY streams. | 	// Create fake upstream SPDY server, with channel receiving SPDY streams. | ||||||
| 	streamChan := make(chan httpstream.Stream) | 	streamChan := make(chan httpstream.Stream) | ||||||
| 	defer close(streamChan) | 	defer close(streamChan) | ||||||
| @@ -106,6 +112,157 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { | |||||||
| 		t.Fatalf("timeout waiting for spdy stream to arrive on channel.") | 		t.Fatalf("timeout waiting for spdy stream to arrive on channel.") | ||||||
| 	} | 	} | ||||||
| 	assert.Equal(t, randomData, actual, "error validating tunneled random data") | 	assert.Equal(t, randomData, actual, "error validating tunneled random data") | ||||||
|  |  | ||||||
|  | 	// Validate the streamtunnel metrics; should be one 101 Switching Protocols. | ||||||
|  | 	metricNames := []string{"apiserver_stream_tunnel_requests_total"} | ||||||
|  | 	expected := ` | ||||||
|  | # HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2 | ||||||
|  | # TYPE apiserver_stream_tunnel_requests_total counter | ||||||
|  | apiserver_stream_tunnel_requests_total{code="101"} 1 | ||||||
|  | ` | ||||||
|  | 	if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestTunnelingHandler_BadRequestWithoutProtcols(t *testing.T) { | ||||||
|  | 	metrics.Register() | ||||||
|  | 	metrics.ResetForTest() | ||||||
|  | 	t.Cleanup(metrics.ResetForTest) | ||||||
|  | 	// Create TunnelingHandler with empty upstream URL and fake transport. An error should | ||||||
|  | 	// be returned before the upstream proxying to SPDY occurs, so a test SPDY server is not needed. | ||||||
|  | 	transport, err := fakeTransport() | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	upgradeHandler := proxy.NewUpgradeAwareHandler(&url.URL{}, transport, false, true, proxy.NewErrorResponder(&fakeResponder{})) | ||||||
|  | 	tunnelingHandler := NewTunnelingHandler(upgradeHandler) | ||||||
|  | 	tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||||
|  | 		tunnelingHandler.ServeHTTP(w, req) | ||||||
|  | 	})) | ||||||
|  | 	defer tunnelingServer.Close() | ||||||
|  | 	// Create SPDY client connection containing a TunnelingConnection by upgrading | ||||||
|  | 	// a request to TunnelingHandler using new portforward version 2. | ||||||
|  | 	tunnelingURL, err := url.Parse(tunnelingServer.URL) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	// Request without subprotocols--causing a bad request to be returned. | ||||||
|  | 	_, protocol, err := dialer.Dial("") | ||||||
|  | 	require.Error(t, err) | ||||||
|  | 	assert.Equal(t, "", protocol) | ||||||
|  |  | ||||||
|  | 	// Validate the streamtunnel metrics; should be one 400 failure. | ||||||
|  | 	metricNames := []string{"apiserver_stream_tunnel_requests_total"} | ||||||
|  | 	expected := ` | ||||||
|  | # HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2 | ||||||
|  | # TYPE apiserver_stream_tunnel_requests_total counter | ||||||
|  | apiserver_stream_tunnel_requests_total{code="400"} 1 | ||||||
|  | ` | ||||||
|  | 	if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestTunnelingHandler_BadHandshakeError(t *testing.T) { | ||||||
|  | 	metrics.Register() | ||||||
|  | 	metrics.ResetForTest() | ||||||
|  | 	t.Cleanup(metrics.ResetForTest) | ||||||
|  | 	// Create fake upstream SPDY server, returning forbidden for bad handshake. | ||||||
|  | 	spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||||
|  | 		// Handshake fails. | ||||||
|  | 		_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) | ||||||
|  | 		require.Error(t, err, "handshake should have returned an error") | ||||||
|  | 		assert.True(t, strings.Contains(err.Error(), "unable to negotiate protocol")) | ||||||
|  | 		w.WriteHeader(http.StatusForbidden) | ||||||
|  | 	})) | ||||||
|  | 	defer spdyServer.Close() | ||||||
|  | 	// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then | ||||||
|  | 	// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer. | ||||||
|  | 	url, err := url.Parse(spdyServer.URL) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	transport, err := fakeTransport() | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{})) | ||||||
|  | 	tunnelingHandler := NewTunnelingHandler(upgradeHandler) | ||||||
|  | 	tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||||
|  | 		tunnelingHandler.ServeHTTP(w, req) | ||||||
|  | 	})) | ||||||
|  | 	defer tunnelingServer.Close() | ||||||
|  | 	// Create SPDY client connection containing a TunnelingConnection by upgrading | ||||||
|  | 	// a request to TunnelingHandler using new portforward version 2. | ||||||
|  | 	tunnelingURL, err := url.Parse(tunnelingServer.URL) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	// Handshake will fail, returning a 400-level response. | ||||||
|  | 	_, protocol, err := dialer.Dial("UNKNOWN_SUBPROTOCOL") | ||||||
|  | 	require.Error(t, err) | ||||||
|  | 	assert.Equal(t, "", protocol) | ||||||
|  |  | ||||||
|  | 	// Validate the streamtunnel metrics; should be one 400 failure. | ||||||
|  | 	metricNames := []string{"apiserver_stream_tunnel_requests_total"} | ||||||
|  | 	expected := ` | ||||||
|  | # HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2 | ||||||
|  | # TYPE apiserver_stream_tunnel_requests_total counter | ||||||
|  | apiserver_stream_tunnel_requests_total{code="400"} 1 | ||||||
|  | ` | ||||||
|  | 	if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestTunnelingHandler_UpstreamSPDYServerErrorPropagated(t *testing.T) { | ||||||
|  | 	metrics.Register() | ||||||
|  | 	metrics.ResetForTest() | ||||||
|  | 	t.Cleanup(metrics.ResetForTest) | ||||||
|  |  | ||||||
|  | 	// Validate that various 500-level errors are propagated and incremented in metrics. | ||||||
|  | 	for statusCode, codeStr := range map[int]string{ | ||||||
|  | 		http.StatusInternalServerError: "500", | ||||||
|  | 		http.StatusBadGateway:          "502", | ||||||
|  | 		http.StatusServiceUnavailable:  "503", | ||||||
|  | 	} { | ||||||
|  | 		// Create fake upstream SPDY server, which returns a 500-level error. | ||||||
|  | 		spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||||
|  | 			_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) | ||||||
|  | 			require.NoError(t, err, "handshake should have succeeded") | ||||||
|  | 			// Returned status code should be incremented in metrics. | ||||||
|  | 			w.WriteHeader(statusCode) | ||||||
|  | 		})) | ||||||
|  | 		defer spdyServer.Close() | ||||||
|  | 		// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then | ||||||
|  | 		// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer. | ||||||
|  | 		url, err := url.Parse(spdyServer.URL) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		transport, err := fakeTransport() | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{})) | ||||||
|  | 		tunnelingHandler := NewTunnelingHandler(upgradeHandler) | ||||||
|  | 		tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||||
|  | 			tunnelingHandler.ServeHTTP(w, req) | ||||||
|  | 		})) | ||||||
|  | 		defer tunnelingServer.Close() | ||||||
|  | 		// Create SPDY client connection containing a TunnelingConnection by upgrading | ||||||
|  | 		// a request to TunnelingHandler using new portforward version 2. | ||||||
|  | 		tunnelingURL, err := url.Parse(tunnelingServer.URL) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host}) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		_, protocol, err := dialer.Dial(constants.PortForwardV1Name) | ||||||
|  | 		require.Error(t, err) | ||||||
|  | 		assert.Equal(t, "", protocol) | ||||||
|  |  | ||||||
|  | 		// Validate the streamtunnel metrics are incrementing 500-level status codes. | ||||||
|  | 		metricNames := []string{"apiserver_stream_tunnel_requests_total"} | ||||||
|  | 		expected := ` | ||||||
|  | # HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2 | ||||||
|  | # TYPE apiserver_stream_tunnel_requests_total counter | ||||||
|  | apiserver_stream_tunnel_requests_total{code="` + codeStr + `"} 1 | ||||||
|  | ` | ||||||
|  | 		if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		metrics.ResetForTest() | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestTunnelingResponseWriter_Hijack(t *testing.T) { | func TestTunnelingResponseWriter_Hijack(t *testing.T) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Sean Sullivan
					Sean Sullivan