diff --git a/connection/ssh_executor.py b/connection/ssh_executor.py index 237420f..9917df7 100644 --- a/connection/ssh_executor.py +++ b/connection/ssh_executor.py @@ -6,19 +6,21 @@ import socket import subprocess import paramiko +import os from datetime import timedelta, datetime from connection.base_executor import BaseExecutor -from core.test_run import TestRun +from core.test_run import TestRun, Blocked from test_utils.output import Output class SshExecutor(BaseExecutor): - def __init__(self, ip, username, port=22): - self.ip = ip + def __init__(self, host, username, port=22): + self.host = host self.user = username self.port = port self.ssh = paramiko.SSHClient() + self.ssh_config = None self._check_config_for_reboot_timeout() def __del__(self): @@ -26,26 +28,61 @@ class SshExecutor(BaseExecutor): def connect(self, user=None, port=None, timeout: timedelta = timedelta(seconds=30)): + hostname = self.host user = user or self.user port = port or self.port self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + config, sock, key_filename = None, None, None + # search for 'host' in SSH config try: - self.ssh.connect(self.ip, username=user, + path = os.path.expanduser('~/.ssh/config') + config = paramiko.SSHConfig.from_path(path) + except FileNotFoundError: + pass + + if config is not None: + target = config.lookup(self.host) + hostname = target['hostname'] + key_filename = target.get('identityfile', None) + user = target.get('user', user) + port = target.get('port', port) + if target.get('proxyjump', None) is not None: + proxy = config.lookup(target['proxyjump']) + jump = paramiko.SSHClient() + jump.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + jump.connect(proxy['hostname'], username=proxy['user'], + port=int(proxy.get('port', 22)), key_filename=proxy.get('identityfile', None)) + transport = jump.get_transport() + local_addr = (proxy['hostname'], int(proxy.get('port', 22))) + dest_addr = (hostname, port) + sock = transport.open_channel("direct-tcpip", dest_addr, local_addr) + except Exception as e: + raise ConnectionError(f"An exception of type '{type(e)}' occurred while trying to " + f"connect to proxy '{proxy['hostname']}'.\n {e}") + + if user is None: + TestRun.block("There is no user given in config.") + + try: + self.ssh.connect(hostname, username=user, port=port, timeout=timeout.total_seconds(), - banner_timeout=timeout.total_seconds()) + banner_timeout=timeout.total_seconds(), + sock=sock, key_filename=key_filename) + self.ssh_config = config except paramiko.AuthenticationException as e: raise paramiko.AuthenticationException( f"Authentication exception occurred while trying to connect to DUT. " f"Please check your SSH key-based authentication.\n{e}") except (paramiko.SSHException, socket.timeout) as e: raise ConnectionError(f"An exception of type '{type(e)}' occurred while trying to " - f"connect to {self.ip}.\n {e}") + f"connect to {hostname}.\n {e}") def disconnect(self): try: self.ssh.close() except Exception: - raise Exception(f"An exception occurred while trying to disconnect from {self.ip}") + raise Exception(f"An exception occurred while trying to disconnect from {self.host}") def _execute(self, command, timeout): try: @@ -53,7 +90,7 @@ class SshExecutor(BaseExecutor): timeout=timeout.total_seconds()) except paramiko.SSHException as e: raise ConnectionError(f"An exception occurred while executing command '{command}' on" - f" {self.ip}\n{e}") + f" {self.host}\n{e}") return Output(stdout.read(), stderr.read(), stdout.channel.recv_exit_status()) @@ -71,8 +108,8 @@ class SshExecutor(BaseExecutor): for exclude in exclude_list: options.append(f"--exclude {exclude}") - src_to_dst = f"{self.user}@{self.ip}:{src} {dst} " if dut_to_controller else\ - f"{src} {self.user}@{self.ip}:{dst} " + src_to_dst = f"{self.user}@{self.host}:{src} {dst} " if dut_to_controller else\ + f"{src} {self.user}@{self.host}:{dst} " try: completed_process = subprocess.run( @@ -124,7 +161,7 @@ class SshExecutor(BaseExecutor): try: self.connect() return - except paramiko.AuthenticationException: + except (paramiko.AuthenticationException, Blocked): raise except Exception: continue diff --git a/core/test_run_utils.py b/core/test_run_utils.py index a27d018..ab0c26f 100644 --- a/core/test_run_utils.py +++ b/core/test_run_utils.py @@ -133,19 +133,20 @@ def __presetup(cls): if cls.config['type'] == 'ssh': try: IP(cls.config['ip']) + cls.config['host'] = cls.config['ip'] except ValueError: TestRun.block("IP address from config is in invalid format.") + except KeyError: + if 'host' not in cls.config: + TestRun.block("No IP address or host defined in config") port = cls.config.get('port', 22) - if 'user' in cls.config: - cls.executor = SshExecutor( - cls.config['ip'], - cls.config['user'], - port - ) - else: - TestRun.block("There is no user given in config.") + cls.executor = SshExecutor( + cls.config['host'], + cls.config.get('user', None), + port + ) elif cls.config['type'] == 'local': cls.executor = LocalExecutor() else: