Merge pull request #123785 from seans3/streamtunnel-unit-tests
Adds unit tests to `PortForward` streamtunnel
This commit is contained in:
		| @@ -19,6 +19,8 @@ package proxy | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/rand" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| @@ -48,7 +50,6 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { | ||||
| 	defer close(streamChan) | ||||
| 	stopServerChan := make(chan struct{}) | ||||
| 	defer close(stopServerChan) | ||||
| 	// Create fake upstream SPDY server. | ||||
| 	spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||
| 		_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) | ||||
| 		require.NoError(t, err) | ||||
| @@ -107,6 +108,120 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { | ||||
| 	assert.Equal(t, randomData, actual, "error validating tunneled random data") | ||||
| } | ||||
|  | ||||
| func TestTunnelingResponseWriter_Hijack(t *testing.T) { | ||||
| 	// Regular hijack returns connection, nil bufio, and no error. | ||||
| 	trw := &tunnelingResponseWriter{conn: &mockConn{}} | ||||
| 	assert.False(t, trw.hijacked, "hijacked field starts false before Hijack()") | ||||
| 	assert.False(t, trw.written, "written field startes false before Hijack()") | ||||
| 	actual, bufio, err := trw.Hijack() | ||||
| 	assert.NoError(t, err, "Hijack() does not return error") | ||||
| 	assert.NotNil(t, actual, "conn returned from Hijack() is not nil") | ||||
| 	assert.Nil(t, bufio, "bufio returned from Hijack() is always nil") | ||||
| 	assert.True(t, trw.hijacked, "hijacked field becomes true after Hijack()") | ||||
| 	assert.False(t, trw.written, "written field stays false after Hijack()") | ||||
| 	// Hijacking after writing to response writer is an error. | ||||
| 	trw = &tunnelingResponseWriter{written: true} | ||||
| 	_, _, err = trw.Hijack() | ||||
| 	assert.Error(t, err, "Hijack after writing to response writer is error") | ||||
| 	assert.True(t, strings.Contains(err.Error(), "connection has already been written to")) | ||||
| 	// Hijacking after already hijacked is an error. | ||||
| 	trw = &tunnelingResponseWriter{hijacked: true} | ||||
| 	_, _, err = trw.Hijack() | ||||
| 	assert.Error(t, err, "Hijack after writing to response writer is error") | ||||
| 	assert.True(t, strings.Contains(err.Error(), "connection has already been hijacked")) | ||||
| } | ||||
|  | ||||
| func TestTunnelingResponseWriter_DelegateResponseWriter(t *testing.T) { | ||||
| 	// Validate Header() for delegate response writer. | ||||
| 	expectedHeader := http.Header{} | ||||
| 	expectedHeader.Set("foo", "bar") | ||||
| 	trw := &tunnelingResponseWriter{w: &mockResponseWriter{header: expectedHeader}} | ||||
| 	assert.Equal(t, expectedHeader, trw.Header(), "") | ||||
| 	// Validate Write() for delegate response writer. | ||||
| 	expectedWrite := []byte("this is a test write string") | ||||
| 	assert.False(t, trw.written, "written field is before Write()") | ||||
| 	_, err := trw.Write(expectedWrite) | ||||
| 	assert.NoError(t, err, "No error expected after Write() on tunneling response writer") | ||||
| 	assert.True(t, trw.written, "written field is set after writing to tunneling response writer") | ||||
| 	// Writing to response writer after hijacked is an error. | ||||
| 	trw.hijacked = true | ||||
| 	_, err = trw.Write(expectedWrite) | ||||
| 	assert.Error(t, err, "Writing to ResponseWriter after Hijack() is an error") | ||||
| 	assert.True(t, errors.Is(err, http.ErrHijacked), "Hijacked error returned if writing after hijacked") | ||||
| 	// Validate WriteHeader(). | ||||
| 	trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} | ||||
| 	expectedStatusCode := 201 | ||||
| 	assert.False(t, trw.written, "Written field originally false in delegate response writer") | ||||
| 	trw.WriteHeader(expectedStatusCode) | ||||
| 	assert.Equal(t, expectedStatusCode, trw.w.(*mockResponseWriter).statusCode, "Expected written status code is correct") | ||||
| 	assert.True(t, trw.written, "Written field set to true after writing delegate response writer") | ||||
| 	// Response writer already written to does not write status. | ||||
| 	trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} | ||||
| 	trw.written = true | ||||
| 	trw.WriteHeader(expectedStatusCode) | ||||
| 	assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code for previously written response writer") | ||||
| 	// Hijacked response writer does not write status. | ||||
| 	trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} | ||||
| 	trw.hijacked = true | ||||
| 	trw.WriteHeader(expectedStatusCode) | ||||
| 	assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code written to hijacked response writer") | ||||
| 	assert.False(t, trw.written, "Hijacked response writer does not write status") | ||||
| 	// Writing "101 Switching Protocols" status is an error, since it should happen via hijacked connection. | ||||
| 	trw = &tunnelingResponseWriter{w: &mockResponseWriter{header: http.Header{}}} | ||||
| 	trw.WriteHeader(http.StatusSwitchingProtocols) | ||||
| 	assert.Equal(t, http.StatusInternalServerError, trw.w.(*mockResponseWriter).statusCode, "Internal server error written") | ||||
| } | ||||
|  | ||||
| func TestTunnelingWebsocketUpgraderConn_LocalRemoteAddress(t *testing.T) { | ||||
| 	expectedLocalAddr := &net.TCPAddr{ | ||||
| 		IP:   net.IPv4(127, 0, 0, 1), | ||||
| 		Port: 80, | ||||
| 	} | ||||
| 	expectedRemoteAddr := &net.TCPAddr{ | ||||
| 		IP:   net.IPv4(127, 0, 0, 2), | ||||
| 		Port: 443, | ||||
| 	} | ||||
| 	tc := &tunnelingWebsocketUpgraderConn{ | ||||
| 		conn: &mockConn{ | ||||
| 			localAddr:  expectedLocalAddr, | ||||
| 			remoteAddr: expectedRemoteAddr, | ||||
| 		}, | ||||
| 	} | ||||
| 	assert.Equal(t, expectedLocalAddr, tc.LocalAddr(), "LocalAddr() returns expected TCPAddr") | ||||
| 	assert.Equal(t, expectedRemoteAddr, tc.RemoteAddr(), "RemoteAddr() returns expected TCPAddr") | ||||
| 	// Connection nil, returns empty address | ||||
| 	tc.conn = nil | ||||
| 	assert.Equal(t, noopAddr{}, tc.LocalAddr(), "nil connection, LocalAddr() returns noopAddr") | ||||
| 	assert.Equal(t, noopAddr{}, tc.RemoteAddr(), "nil connection, RemoteAddr() returns noopAddr") | ||||
| 	// Validate the empty strings from noopAddr | ||||
| 	assert.Equal(t, "", noopAddr{}.Network(), "noopAddr Network() returns empty string") | ||||
| 	assert.Equal(t, "", noopAddr{}.String(), "noopAddr String() returns empty string") | ||||
| } | ||||
|  | ||||
| func TestTunnelingWebsocketUpgraderConn_SetDeadline(t *testing.T) { | ||||
| 	tc := &tunnelingWebsocketUpgraderConn{conn: &mockConn{}} | ||||
| 	expected := time.Now() | ||||
| 	assert.Nil(t, tc.SetDeadline(expected), "SetDeadline does not return error") | ||||
| 	assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "SetDeadline() sets read deadline") | ||||
| 	assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "SetDeadline() sets write deadline") | ||||
| 	expected = time.Now() | ||||
| 	assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline does not return error") | ||||
| 	assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "Expected write deadline set") | ||||
| 	expected = time.Now() | ||||
| 	assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline does not return error") | ||||
| 	assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "Expected read deadline set") | ||||
| 	expectedErr := fmt.Errorf("deadline error") | ||||
| 	tc = &tunnelingWebsocketUpgraderConn{conn: &mockConn{deadlineErr: expectedErr}} | ||||
| 	expected = time.Now() | ||||
| 	actualErr := tc.SetDeadline(expected) | ||||
| 	assert.Equal(t, expectedErr, actualErr, "SetDeadline() expected error returned") | ||||
| 	// Connection nil, returns nil error. | ||||
| 	tc.conn = nil | ||||
| 	assert.Nil(t, tc.SetDeadline(expected), "SetDeadline() with nil connection always returns nil error") | ||||
| 	assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline() with nil connection always returns nil error") | ||||
| 	assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline() with nil connection always returns nil error") | ||||
| } | ||||
|  | ||||
| var expectedContentLengthHeaders = http.Header{ | ||||
| 	"Content-Length": []string{"25"}, | ||||
| 	"Date":           []string{"Sun, 25 Feb 2024 08:09:25 GMT"}, | ||||
| @@ -330,7 +445,12 @@ func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, ba | ||||
| var _ net.Conn = &mockConn{} | ||||
|  | ||||
| type mockConn struct { | ||||
| 	written []byte | ||||
| 	written       []byte | ||||
| 	localAddr     *net.TCPAddr | ||||
| 	remoteAddr    *net.TCPAddr | ||||
| 	readDeadline  time.Time | ||||
| 	writeDeadline time.Time | ||||
| 	deadlineErr   error | ||||
| } | ||||
|  | ||||
| func (mc *mockConn) Write(p []byte) (int, error) { | ||||
| @@ -338,13 +458,31 @@ func (mc *mockConn) Write(p []byte) (int, error) { | ||||
| 	return len(p), nil | ||||
| } | ||||
|  | ||||
| func (mc *mockConn) Read(p []byte) (int, error)         { return 0, nil } | ||||
| func (mc *mockConn) Close() error                       { return nil } | ||||
| func (mc *mockConn) LocalAddr() net.Addr                { return &net.TCPAddr{} } | ||||
| func (mc *mockConn) RemoteAddr() net.Addr               { return &net.TCPAddr{} } | ||||
| func (mc *mockConn) SetDeadline(t time.Time) error      { return nil } | ||||
| func (mc *mockConn) SetReadDeadline(t time.Time) error  { return nil } | ||||
| func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil } | ||||
| func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } | ||||
| func (mc *mockConn) Close() error               { return nil } | ||||
| func (mc *mockConn) LocalAddr() net.Addr        { return mc.localAddr } | ||||
| func (mc *mockConn) RemoteAddr() net.Addr       { return mc.remoteAddr } | ||||
| func (mc *mockConn) SetDeadline(t time.Time) error { | ||||
| 	mc.SetReadDeadline(t)  //nolint:errcheck | ||||
| 	mc.SetWriteDeadline(t) // nolint:errcheck | ||||
| 	return mc.deadlineErr | ||||
| } | ||||
| func (mc *mockConn) SetReadDeadline(t time.Time) error  { mc.readDeadline = t; return mc.deadlineErr } | ||||
| func (mc *mockConn) SetWriteDeadline(t time.Time) error { mc.writeDeadline = t; return mc.deadlineErr } | ||||
|  | ||||
| // mockResponseWriter implements "http.ResponseWriter" interface | ||||
| type mockResponseWriter struct { | ||||
| 	header     http.Header | ||||
| 	written    []byte | ||||
| 	statusCode int | ||||
| } | ||||
|  | ||||
| func (mrw *mockResponseWriter) Header() http.Header { return mrw.header } | ||||
| func (mrw *mockResponseWriter) Write(p []byte) (int, error) { | ||||
| 	mrw.written = append(mrw.written, p...) | ||||
| 	return len(p), nil | ||||
| } | ||||
| func (mrw *mockResponseWriter) WriteHeader(statusCode int) { mrw.statusCode = statusCode } | ||||
|  | ||||
| // fakeResponder implements "rest.Responder" interface. | ||||
| var _ rest.Responder = &fakeResponder{} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Kubernetes Prow Robot
					Kubernetes Prow Robot