# pylint: disable-msg=C0111 import os, time, signal, socket, re, fnmatch, logging, threading import paramiko from autotest_lib.client.common_lib import utils, error, global_config from autotest_lib.server import subcommand from autotest_lib.server.hosts import abstract_ssh class ParamikoHost(abstract_ssh.AbstractSSHHost): KEEPALIVE_TIMEOUT_SECONDS = 30 CONNECT_TIMEOUT_SECONDS = 30 CONNECT_TIMEOUT_RETRIES = 3 BUFFSIZE = 2**16 def _initialize(self, hostname, *args, **dargs): super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs) # paramiko is very noisy, tone down the logging paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR) self.keys = self.get_user_keys(hostname) self.pid = None @staticmethod def _load_key(path): """Given a path to a private key file, load the appropriate keyfile. Tries to load the file as both an RSAKey and a DSAKey. If the file cannot be loaded as either type, returns None.""" try: return paramiko.DSSKey.from_private_key_file(path) except paramiko.SSHException: try: return paramiko.RSAKey.from_private_key_file(path) except paramiko.SSHException: return None @staticmethod def _parse_config_line(line): """Given an ssh config line, return a (key, value) tuple for the config value listed in the line, or (None, None)""" match = re.match(r"\s*(\w+)\s*=?(.*)\n", line) if match: return match.groups() else: return None, None @staticmethod def get_user_keys(hostname): """Returns a mapping of path -> paramiko.PKey entries available for this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa) as well as any IdentityFile entries in the standard ssh config files. """ raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"] for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"): config_path = os.path.expanduser(config_path) if not os.path.exists(config_path): continue host_pattern = "*" config_lines = open(config_path).readlines() for line in config_lines: key, value = ParamikoHost._parse_config_line(line) if key == "Host": host_pattern = value elif (key == "IdentityFile" and fnmatch.fnmatch(hostname, host_pattern)): raw_identity_files.append(value) # drop any files that use percent-escapes; we don't support them identity_files = [] UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"] for path in raw_identity_files: # skip this path if it uses % escapes if sum((escape in path) for escape in UNSUPPORTED_ESCAPES): continue path = os.path.expanduser(path) if os.path.exists(path): identity_files.append(path) # load up all the keys that we can and return them user_keys = {} for path in identity_files: key = ParamikoHost._load_key(path) if key: user_keys[path] = key # load up all the ssh agent keys use_sshagent = global_config.global_config.get_config_value( 'AUTOSERV', 'use_sshagent_with_paramiko', type=bool) if use_sshagent: ssh_agent = paramiko.Agent() for i, key in enumerate(ssh_agent.get_keys()): user_keys['agent-key-%d' % i] = key return user_keys def _check_transport_error(self, transport): error = transport.get_exception() if error: transport.close() raise error def _connect_socket(self): """Return a socket for use in instantiating a paramiko transport. Does not have to be a literal socket, it can be anything that the paramiko.Transport constructor accepts.""" return self.hostname, self.port def _connect_transport(self, pkey): for _ in xrange(self.CONNECT_TIMEOUT_RETRIES): transport = paramiko.Transport(self._connect_socket()) completed = threading.Event() transport.start_client(completed) completed.wait(self.CONNECT_TIMEOUT_SECONDS) if completed.isSet(): self._check_transport_error(transport) completed.clear() transport.auth_publickey(self.user, pkey, completed) completed.wait(self.CONNECT_TIMEOUT_SECONDS) if completed.isSet(): self._check_transport_error(transport) if not transport.is_authenticated(): transport.close() raise paramiko.AuthenticationException() return transport logging.warning("SSH negotiation (%s:%d) timed out, retrying", self.hostname, self.port) # HACK: we can't count on transport.join not hanging now, either transport.join = lambda: None transport.close() logging.error("SSH negotation (%s:%d) has timed out %s times, " "giving up", self.hostname, self.port, self.CONNECT_TIMEOUT_RETRIES) raise error.AutoservSSHTimeout("SSH negotiation timed out") def _init_transport(self): for path, key in self.keys.iteritems(): try: logging.debug("Connecting with %s", path) transport = self._connect_transport(key) transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS) self.transport = transport self.pid = os.getpid() return except paramiko.AuthenticationException: logging.debug("Authentication failure") else: raise error.AutoservSshPermissionDeniedError( "Permission denied using all keys available to ParamikoHost", utils.CmdResult()) def _open_channel(self, timeout): start_time = time.time() if os.getpid() != self.pid: if self.pid is not None: # HACK: paramiko tries to join() on its worker thread # and this just hangs on linux after a fork() self.transport.join = lambda: None self.transport.atfork() join_hook = lambda cmd: self._close_transport() subcommand.subcommand.register_join_hook(join_hook) logging.debug("Reopening SSH connection after a process fork") self._init_transport() channel = None try: channel = self.transport.open_session() except (socket.error, paramiko.SSHException, EOFError), e: logging.warning("Exception occured while opening session: %s", e) if time.time() - start_time >= timeout: raise error.AutoservSSHTimeout("ssh failed: %s" % e) if not channel: # we couldn't get a channel; re-initing transport should fix that try: self.transport.close() except Exception, e: logging.debug("paramiko.Transport.close failed with %s", e) self._init_transport() return self.transport.open_session() else: return channel def _close_transport(self): if os.getpid() == self.pid: self.transport.close() def close(self): super(ParamikoHost, self).close() self._close_transport() @classmethod def _exhaust_stream(cls, tee, output_list, recvfunc): while True: try: output_list.append(recvfunc(cls.BUFFSIZE)) except socket.timeout: return tee.write(output_list[-1]) if not output_list[-1]: return @classmethod def __send_stdin(cls, channel, stdin): if not stdin or not channel.send_ready(): # nothing more to send or just no space to send now return sent = channel.send(stdin[:cls.BUFFSIZE]) if not sent: logging.warning('Could not send a single stdin byte.') else: stdin = stdin[sent:] if not stdin: # no more stdin input, close output direction channel.shutdown_write() return stdin def run(self, command, timeout=3600, ignore_status=False, stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS, connect_timeout=30, stdin=None, verbose=True, args=(), ignore_timeout=False): """ Run a command on the remote host. @see common_lib.hosts.host.run() @param connect_timeout: connection timeout (in seconds) @param options: string with additional ssh command options @param verbose: log the commands @param ignore_timeout: bool True command timeouts should be ignored. Will return None on command timeout. @raises AutoservRunError: if the command failed @raises AutoservSSHTimeout: ssh connection has timed out """ stdout = utils.get_stream_tee_file( stdout_tee, utils.DEFAULT_STDOUT_LEVEL, prefix=utils.STDOUT_PREFIX) stderr = utils.get_stream_tee_file( stderr_tee, utils.get_stderr_level(ignore_status), prefix=utils.STDERR_PREFIX) for arg in args: command += ' "%s"' % utils.sh_escape(arg) if verbose: logging.debug("Running (ssh-paramiko) '%s'", command) # start up the command start_time = time.time() try: channel = self._open_channel(timeout) channel.exec_command(command) except (socket.error, paramiko.SSHException, EOFError), e: # This has to match the string from paramiko *exactly*. if str(e) != 'Channel closed.': raise error.AutoservSSHTimeout("ssh failed: %s" % e) # pull in all the stdout, stderr until the command terminates raw_stdout, raw_stderr = [], [] timed_out = False while not channel.exit_status_ready(): if channel.recv_ready(): raw_stdout.append(channel.recv(self.BUFFSIZE)) stdout.write(raw_stdout[-1]) if channel.recv_stderr_ready(): raw_stderr.append(channel.recv_stderr(self.BUFFSIZE)) stderr.write(raw_stderr[-1]) if timeout and time.time() - start_time > timeout: timed_out = True break stdin = self.__send_stdin(channel, stdin) time.sleep(1) if timed_out: exit_status = -signal.SIGTERM else: exit_status = channel.recv_exit_status() channel.settimeout(10) self._exhaust_stream(stdout, raw_stdout, channel.recv) self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr) channel.close() duration = time.time() - start_time # create the appropriate results stdout = "".join(raw_stdout) stderr = "".join(raw_stderr) result = utils.CmdResult(command, stdout, stderr, exit_status, duration) if exit_status == -signal.SIGHUP: msg = "ssh connection unexpectedly terminated" raise error.AutoservRunError(msg, result) if timed_out: logging.warning('Paramiko command timed out after %s sec: %s', timeout, command) if not ignore_timeout: raise error.AutoservRunError("command timed out", result) if not ignore_status and exit_status: raise error.AutoservRunError(command, result) return result