# pylint: disable-msg=C0111
import os, time, signal, socket, re, fnmatch, logging, threading
import paramiko

from autotest_lib.client.common_lib import utils, error, global_config
from autotest_lib.server import subcommand
from autotest_lib.server.hosts import abstract_ssh


class ParamikoHost(abstract_ssh.AbstractSSHHost):
    KEEPALIVE_TIMEOUT_SECONDS = 30
    CONNECT_TIMEOUT_SECONDS = 30
    CONNECT_TIMEOUT_RETRIES = 3
    BUFFSIZE = 2**16

    def _initialize(self, hostname, *args, **dargs):
        super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)

        # paramiko is very noisy, tone down the logging
        paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)

        self.keys = self.get_user_keys(hostname)
        self.pid = None


    @staticmethod
    def _load_key(path):
        """Given a path to a private key file, load the appropriate keyfile.

        Tries to load the file as both an RSAKey and a DSAKey. If the file
        cannot be loaded as either type, returns None."""
        try:
            return paramiko.DSSKey.from_private_key_file(path)
        except paramiko.SSHException:
            try:
                return paramiko.RSAKey.from_private_key_file(path)
            except paramiko.SSHException:
                return None


    @staticmethod
    def _parse_config_line(line):
        """Given an ssh config line, return a (key, value) tuple for the
        config value listed in the line, or (None, None)"""
        match = re.match(r"\s*(\w+)\s*=?(.*)\n", line)
        if match:
            return match.groups()
        else:
            return None, None


    @staticmethod
    def get_user_keys(hostname):
        """Returns a mapping of path -> paramiko.PKey entries available for
        this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
        as well as any IdentityFile entries in the standard ssh config files.
        """
        raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
        for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
            config_path = os.path.expanduser(config_path)
            if not os.path.exists(config_path):
                continue
            host_pattern = "*"
            config_lines = open(config_path).readlines()
            for line in config_lines:
                key, value = ParamikoHost._parse_config_line(line)
                if key == "Host":
                    host_pattern = value
                elif (key == "IdentityFile"
                      and fnmatch.fnmatch(hostname, host_pattern)):
                    raw_identity_files.append(value)

        # drop any files that use percent-escapes; we don't support them
        identity_files = []
        UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"]
        for path in raw_identity_files:
            # skip this path if it uses % escapes
            if sum((escape in path) for escape in UNSUPPORTED_ESCAPES):
                continue
            path = os.path.expanduser(path)
            if os.path.exists(path):
                identity_files.append(path)

        # load up all the keys that we can and return them
        user_keys = {}
        for path in identity_files:
            key = ParamikoHost._load_key(path)
            if key:
                user_keys[path] = key

        # load up all the ssh agent keys
        use_sshagent = global_config.global_config.get_config_value(
            'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
        if use_sshagent:
            ssh_agent = paramiko.Agent()
            for i, key in enumerate(ssh_agent.get_keys()):
                user_keys['agent-key-%d' % i] = key

        return user_keys


    def _check_transport_error(self, transport):
        error = transport.get_exception()
        if error:
            transport.close()
            raise error


    def _connect_socket(self):
        """Return a socket for use in instantiating a paramiko transport. Does
        not have to be a literal socket, it can be anything that the
        paramiko.Transport constructor accepts."""
        return self.hostname, self.port


    def _connect_transport(self, pkey):
        for _ in xrange(self.CONNECT_TIMEOUT_RETRIES):
            transport = paramiko.Transport(self._connect_socket())
            completed = threading.Event()
            transport.start_client(completed)
            completed.wait(self.CONNECT_TIMEOUT_SECONDS)
            if completed.isSet():
                self._check_transport_error(transport)
                completed.clear()
                transport.auth_publickey(self.user, pkey, completed)
                completed.wait(self.CONNECT_TIMEOUT_SECONDS)
                if completed.isSet():
                    self._check_transport_error(transport)
                    if not transport.is_authenticated():
                        transport.close()
                        raise paramiko.AuthenticationException()
                    return transport
            logging.warning("SSH negotiation (%s:%d) timed out, retrying",
                         self.hostname, self.port)
            # HACK: we can't count on transport.join not hanging now, either
            transport.join = lambda: None
            transport.close()
        logging.error("SSH negotation (%s:%d) has timed out %s times, "
                      "giving up", self.hostname, self.port,
                      self.CONNECT_TIMEOUT_RETRIES)
        raise error.AutoservSSHTimeout("SSH negotiation timed out")


    def _init_transport(self):
        for path, key in self.keys.iteritems():
            try:
                logging.debug("Connecting with %s", path)
                transport = self._connect_transport(key)
                transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
                self.transport = transport
                self.pid = os.getpid()
                return
            except paramiko.AuthenticationException:
                logging.debug("Authentication failure")
        else:
            raise error.AutoservSshPermissionDeniedError(
                "Permission denied using all keys available to ParamikoHost",
                utils.CmdResult())


    def _open_channel(self, timeout):
        start_time = time.time()
        if os.getpid() != self.pid:
            if self.pid is not None:
                # HACK: paramiko tries to join() on its worker thread
                # and this just hangs on linux after a fork()
                self.transport.join = lambda: None
                self.transport.atfork()
                join_hook = lambda cmd: self._close_transport()
                subcommand.subcommand.register_join_hook(join_hook)
                logging.debug("Reopening SSH connection after a process fork")
            self._init_transport()

        channel = None
        try:
            channel = self.transport.open_session()
        except (socket.error, paramiko.SSHException, EOFError), e:
            logging.warning("Exception occured while opening session: %s", e)
            if time.time() - start_time >= timeout:
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        if not channel:
            # we couldn't get a channel; re-initing transport should fix that
            try:
                self.transport.close()
            except Exception, e:
                logging.debug("paramiko.Transport.close failed with %s", e)
            self._init_transport()
            return self.transport.open_session()
        else:
            return channel


    def _close_transport(self):
        if os.getpid() == self.pid:
            self.transport.close()


    def close(self):
        super(ParamikoHost, self).close()
        self._close_transport()


    @classmethod
    def _exhaust_stream(cls, tee, output_list, recvfunc):
        while True:
            try:
                output_list.append(recvfunc(cls.BUFFSIZE))
            except socket.timeout:
                return
            tee.write(output_list[-1])
            if not output_list[-1]:
                return


    @classmethod
    def __send_stdin(cls, channel, stdin):
        if not stdin or not channel.send_ready():
            # nothing more to send or just no space to send now
            return

        sent = channel.send(stdin[:cls.BUFFSIZE])
        if not sent:
            logging.warning('Could not send a single stdin byte.')
        else:
            stdin = stdin[sent:]
            if not stdin:
                # no more stdin input, close output direction
                channel.shutdown_write()
        return stdin


    def run(self, command, timeout=3600, ignore_status=False,
            stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
            connect_timeout=30, stdin=None, verbose=True, args=(),
            ignore_timeout=False):
        """
        Run a command on the remote host.
        @see common_lib.hosts.host.run()

        @param connect_timeout: connection timeout (in seconds)
        @param options: string with additional ssh command options
        @param verbose: log the commands
        @param ignore_timeout: bool True command timeouts should be
                               ignored.  Will return None on command timeout.

        @raises AutoservRunError: if the command failed
        @raises AutoservSSHTimeout: ssh connection has timed out
        """

        stdout = utils.get_stream_tee_file(
                stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
                prefix=utils.STDOUT_PREFIX)
        stderr = utils.get_stream_tee_file(
                stderr_tee, utils.get_stderr_level(ignore_status),
                prefix=utils.STDERR_PREFIX)

        for arg in args:
            command += ' "%s"' % utils.sh_escape(arg)

        if verbose:
            logging.debug("Running (ssh-paramiko) '%s'", command)

        # start up the command
        start_time = time.time()
        try:
            channel = self._open_channel(timeout)
            channel.exec_command(command)
        except (socket.error, paramiko.SSHException, EOFError), e:
            # This has to match the string from paramiko *exactly*.
            if str(e) != 'Channel closed.':
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        # pull in all the stdout, stderr until the command terminates
        raw_stdout, raw_stderr = [], []
        timed_out = False
        while not channel.exit_status_ready():
            if channel.recv_ready():
                raw_stdout.append(channel.recv(self.BUFFSIZE))
                stdout.write(raw_stdout[-1])
            if channel.recv_stderr_ready():
                raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
                stderr.write(raw_stderr[-1])
            if timeout and time.time() - start_time > timeout:
                timed_out = True
                break
            stdin = self.__send_stdin(channel, stdin)
            time.sleep(1)

        if timed_out:
            exit_status = -signal.SIGTERM
        else:
            exit_status = channel.recv_exit_status()
        channel.settimeout(10)
        self._exhaust_stream(stdout, raw_stdout, channel.recv)
        self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
        channel.close()
        duration = time.time() - start_time

        # create the appropriate results
        stdout = "".join(raw_stdout)
        stderr = "".join(raw_stderr)
        result = utils.CmdResult(command, stdout, stderr, exit_status,
                                 duration)
        if exit_status == -signal.SIGHUP:
            msg = "ssh connection unexpectedly terminated"
            raise error.AutoservRunError(msg, result)
        if timed_out:
            logging.warning('Paramiko command timed out after %s sec: %s', timeout,
                         command)
            if not ignore_timeout:
                raise error.AutoservRunError("command timed out", result)
        if not ignore_status and exit_status:
            raise error.AutoservRunError(command, result)
        return result