Merge pull request #28942 from kubernetes/revert-28805-ssh-dial-timeout
Revert "Add a customized ssh dialer that will timeout"
This commit is contained in:
		| @@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err | ||||
|  | ||||
| func (s *SSHTunnel) Open() error { | ||||
| 	var err error | ||||
| 	s.client, err = defaultTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) | ||||
| 	s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) | ||||
| 	tunnelOpenCounter.Inc() | ||||
| 	if err != nil { | ||||
| 		tunnelOpenFailCounter.Inc() | ||||
| @@ -154,9 +154,21 @@ type sshDialer interface { | ||||
| 	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) | ||||
| } | ||||
|  | ||||
| // timeoutDialer implements a Dial() method that will timeout. The golang | ||||
| // Real implementation of sshDialer | ||||
| type realSSHDialer struct{} | ||||
|  | ||||
| var _ sshDialer = &realSSHDialer{} | ||||
|  | ||||
| func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { | ||||
| 	return ssh.Dial(network, addr, config) | ||||
| } | ||||
|  | ||||
| // timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang | ||||
| // ssh library can hang indefinitely inside the Dial() call (see issue #23835). | ||||
| // Wrapping all Dial() calls with a conservative timeout provides safety against | ||||
| // getting stuck on that. | ||||
| type timeoutDialer struct { | ||||
| 	dialer  sshDialer | ||||
| 	timeout time.Duration | ||||
| } | ||||
|  | ||||
| @@ -164,32 +176,30 @@ type timeoutDialer struct { | ||||
| // seconds). This timeout is only intended to catch otherwise uncaught hangs. | ||||
| const sshDialTimeout = 150 * time.Second | ||||
|  | ||||
| var defaultTimeoutDialer sshDialer = &timeoutDialer{sshDialTimeout} | ||||
| var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout} | ||||
|  | ||||
| func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { | ||||
| 	conn, err := net.Dial(network, addr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	conn.SetDeadline(time.Now().Add(d.timeout)) | ||||
| 	// set to 0 so that conn will not time out after Dial. | ||||
| 	defer func() { | ||||
| 		conn.SetDeadline(time.Time{}) | ||||
| 	var client *ssh.Client | ||||
| 	errCh := make(chan error, 1) | ||||
| 	go func() { | ||||
| 		defer runtime.HandleCrash() | ||||
| 		var err error | ||||
| 		client, err = d.dialer.Dial(network, addr, config) | ||||
| 		errCh <- err | ||||
| 	}() | ||||
| 	// if conn times out, the NewClientConn will close it, so we will not end up | ||||
| 	// with hanging goroutines or open file descriptors. | ||||
| 	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	select { | ||||
| 	case err := <-errCh: | ||||
| 		return client, err | ||||
| 	case <-time.After(d.timeout): | ||||
| 		return nil, fmt.Errorf("timed out dialing %s:%s", network, addr) | ||||
| 	} | ||||
| 	return ssh.NewClient(c, chans, reqs), nil | ||||
| } | ||||
|  | ||||
| // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on | ||||
| // host as specific user, along with any SSH-level error. | ||||
| // If user=="", it will default (like SSH) to os.Getenv("USER") | ||||
| func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) { | ||||
| 	return runSSHCommand(defaultTimeoutDialer, cmd, user, host, signer, true) | ||||
| 	return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true) | ||||
| } | ||||
|  | ||||
| // Internal implementation of runSSHCommand, for testing | ||||
|   | ||||
| @@ -329,49 +329,38 @@ func TestSSHUser(t *testing.T) { | ||||
|  | ||||
| } | ||||
|  | ||||
| type slowDialer struct { | ||||
| 	delay time.Duration | ||||
| 	err   error | ||||
| } | ||||
|  | ||||
| func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { | ||||
| 	time.Sleep(s.delay) | ||||
| 	if s.err != nil { | ||||
| 		return nil, s.err | ||||
| 	} | ||||
| 	return &ssh.Client{}, nil | ||||
| } | ||||
|  | ||||
| func TestTimeoutDialer(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		delay             time.Duration | ||||
| 		timeout           time.Duration | ||||
| 		err               error | ||||
| 		expectedErrString string | ||||
| 	}{ | ||||
| 		// should cause ssh.Dial to timeout. | ||||
| 		{0, "i/o timeout"}, | ||||
| 		// should succeed | ||||
| 		{1 * time.Second, ""}, | ||||
| 		// delay > timeout should cause ssh.Dial to timeout. | ||||
| 		{1 * time.Second, 0, nil, "timed out dialing"}, | ||||
| 		// delay < timeout should return the result of the call to the dialer. | ||||
| 		{0, 1 * time.Second, nil, ""}, | ||||
| 		{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"}, | ||||
| 	} | ||||
| 	for _, tc := range testCases { | ||||
| 		// setup | ||||
| 		private, _, err := GenerateKey(2048) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("unexpected error: %v", err) | ||||
| 			t.FailNow() | ||||
| 		} | ||||
| 		server, err := runTestSSHServer("foo", "bar") | ||||
| 		if err != nil { | ||||
| 			t.Errorf("unexpected error: %v", err) | ||||
| 			t.FailNow() | ||||
| 		} | ||||
| 		privateData := EncodePrivateKey(private) | ||||
| 		tunnel, err := NewSSHTunnelFromBytes("foo", privateData, server.Host) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("unexpected error: %v", err) | ||||
| 			t.FailNow() | ||||
| 		} | ||||
| 		tunnel.SSHPort = server.Port | ||||
|  | ||||
| 		// test the dialer | ||||
| 		dialer := &timeoutDialer{tc.timeout} | ||||
| 		client, err := dialer.Dial("tcp", net.JoinHostPort(tunnel.Host, tunnel.SSHPort), tunnel.Config) | ||||
| 		dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout} | ||||
| 		_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{}) | ||||
| 		if len(tc.expectedErrString) == 0 && err != nil || | ||||
| 			!strings.Contains(fmt.Sprint(err), tc.expectedErrString) { | ||||
| 			t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err) | ||||
| 		} | ||||
| 		if len(tc.expectedErrString) == 0 { | ||||
| 			// verify the connection doesn't timeout after the handshake is done. | ||||
| 			time.Sleep(tc.timeout + 1*time.Second) | ||||
| 			if _, _, err := client.OpenChannel("direct-tcpip", nil); err != nil { | ||||
| 				t.Errorf("unexpected error %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Marek Grabowski
					Marek Grabowski