普通文本  |  520行  |  18.56 KB

#!/usr/bin/python
"""
Client for file transfer services offered by RSS (Remote Shell Server).

@author: Michael Goldish (mgoldish@redhat.com)
@copyright: 2008-2010 Red Hat Inc.
"""

import socket, struct, time, sys, os, glob

# Globals
# Number of bytes read from a local file for each packet.  The same value is
# sent to the server during the handshake, and a file packet shorter than
# CHUNKSIZE marks the end of the file being transferred.
CHUNKSIZE = 65536

# Protocol message constants
# Messages are exchanged as 4-byte unsigned integers packed with the struct
# format "=I".
RSS_MAGIC           = 0x525353  # 0x52, 0x53, 0x53 are ASCII 'R', 'S', 'S'
RSS_OK              = 1  # Final reply expected by upload() on success
RSS_ERROR           = 2  # Followed by a packet holding an error message
RSS_UPLOAD          = 3  # Sent by FileUploadClient right after connecting
RSS_DOWNLOAD        = 4  # Sent by FileDownloadClient right after connecting
RSS_SET_PATH        = 5  # Followed by a packet holding a path or pattern
RSS_CREATE_FILE     = 6  # Followed by a name packet and the file's chunks
RSS_CREATE_DIR      = 7  # Followed by a name packet; enters that directory
RSS_LEAVE_DIR       = 8  # Returns to the parent directory
RSS_DONE            = 9  # Marks the end of the transfer

# See rss.cpp for protocol details.


class FileTransferError(Exception):
    """
    Base class of all the errors raised by the file transfer clients.

    Besides a message, an instance may carry the underlying error (e) and
    the name of the file being transferred (filename).  Both are appended to
    the message by __str__() when they are set.
    """

    def __init__(self, msg, e=None, filename=None):
        """
        @param msg: Human readable description of the failure
        @param e: Optional underlying error (e.g. a socket.error instance)
        @param filename: Optional name of the file involved
        """
        Exception.__init__(self, msg, e, filename)
        self.msg, self.e, self.filename = msg, e, filename

    def __str__(self):
        """
        Return the message followed by whatever details are available.
        """
        cause, name = self.e, self.filename
        if cause and name:
            details = "    (error: %s,    filename: %s)" % (cause, name)
        elif cause:
            details = "    (%s)" % cause
        elif name:
            details = "    (filename: %s)" % name
        else:
            # Nothing to add -- the bare message is the whole text
            return self.msg
        return self.msg + details


class FileTransferConnectError(FileTransferError):
    """
    Raised when the connection to the server cannot be established, or when
    the magic number expected right after connecting is wrong or late.
    """


class FileTransferTimeoutError(FileTransferError):
    """
    Raised when sending or receiving data does not complete within the
    allowed time.
    """


class FileTransferProtocolError(FileTransferError):
    """
    Raised when the conversation with the server breaks the protocol, e.g.
    the connection is closed while data is expected or an unexpected message
    is received.
    """


class FileTransferSocketError(FileTransferError):
    """
    Raised when a socket error occurs while sending data to the server or
    receiving data from it.
    """


class FileTransferServerError(FileTransferError):
    """
    Raised when the server reports a failure and supplies an error message.

    The server's message is stored in the e attribute; msg is left unset.
    """

    def __init__(self, errmsg):
        """
        @param errmsg: The error message received from the server
        """
        FileTransferError.__init__(self, None, errmsg)

    def __str__(self):
        """
        Return the server's message, plus the file name when one is known.
        """
        text = "Server said: %r" % self.e
        if not self.filename:
            return text
        return text + "    (filename: %s)" % self.filename


class FileTransferNotFoundError(FileTransferError):
    """
    Raised when the source pattern of a transfer does not match any files or
    directories.
    """


class FileTransferClient(object):
    """
    Connect to a RSS (remote shell server) and transfer files.

    This base class owns the socket and provides the building blocks shared
    by the upload and download clients: raw send/receive with timeouts,
    length-prefixed packets, 4-byte protocol messages, streaming of files in
    chunks and reporting of transfer statistics.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server.

        After connecting, wait for the server's magic number and reply with
        the chunk size used by this client (CHUNKSIZE).

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails, if a
                wrong magic number is received or if the magic number does
                not arrive in time
        @note: Other FileTransferError subclasses raised while exchanging the
                handshake data are propagated unchanged.
        """
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.settimeout(timeout)
            try:
                self._socket.connect((address, port))
            except socket.error, e:
                raise FileTransferConnectError("Cannot connect to server at "
                                               "%s:%s" % (address, port), e)
            try:
                if self._receive_msg(timeout) != RSS_MAGIC:
                    raise FileTransferConnectError("Received wrong magic "
                                                   "number")
            except FileTransferTimeoutError:
                raise FileTransferConnectError("Timeout expired while waiting "
                                               "to receive magic number")
            self._send(struct.pack("=i", CHUNKSIZE))
        except:
            # The caller never gets hold of the instance when the handshake
            # fails, so release the socket now instead of relying on __del__
            self._socket.close()
            raise
        self._log_func = log_func
        self._last_time = time.time()
        self._last_transferred = 0
        self.transferred = 0


    def __del__(self):
        self.close()


    def close(self):
        """
        Close the connection.
        """
        # _socket does not exist if creating the socket failed in __init__;
        # close() is still reached through __del__ in that case
        sock = getattr(self, "_socket", None)
        if sock is not None:
            sock.close()


    def _send(self, str, timeout=60):
        """
        Send raw data to the server.

        @param str: The data to send
        @param timeout: Time duration in seconds to wait for the data to be
                sent; a non-positive value is treated as already expired
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferSocketError: Raised if a socket error occurs
        """
        try:
            if timeout <= 0:
                raise socket.timeout
            self._socket.settimeout(timeout)
            self._socket.sendall(str)
        except socket.timeout:
            raise FileTransferTimeoutError("Timeout expired while sending "
                                           "data to server")
        except socket.error, e:
            raise FileTransferSocketError("Could not send data to server", e)


    def _receive(self, size, timeout=60):
        """
        Receive exactly size bytes from the server.

        @param size: Number of bytes to receive
        @param timeout: Time duration in seconds to wait for all the data
        @return: The received data
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferProtocolError: Raised if the connection is closed
                before all the data has arrived
        @raise FileTransferSocketError: Raised if a socket error occurs
        """
        strs = []
        end_time = time.time() + timeout
        try:
            while size > 0:
                # Every recv() only gets the time left until the deadline
                timeout = end_time - time.time()
                if timeout <= 0:
                    raise socket.timeout
                self._socket.settimeout(timeout)
                data = self._socket.recv(size)
                if not data:
                    raise FileTransferProtocolError("Connection closed "
                                                    "unexpectedly while "
                                                    "receiving data from "
                                                    "server")
                strs.append(data)
                size -= len(data)
        except socket.timeout:
            raise FileTransferTimeoutError("Timeout expired while receiving "
                                           "data from server")
        except socket.error, e:
            raise FileTransferSocketError("Error receiving data from server",
                                          e)
        return "".join(strs)


    def _report_stats(self, str):
        """
        Pass the transfer statistics to the log function, if there is one.

        Nothing is reported unless at least one second has passed since the
        previous report.

        @param str: Prefix of the reported line (e.g. "Sent")
        """
        if self._log_func:
            dt = time.time() - self._last_time
            if dt >= 1:
                transferred = self.transferred / 1048576.
                speed = (self.transferred - self._last_transferred) / dt
                speed /= 1048576.
                self._log_func("%s %.3f MB (%.3f MB/sec)" %
                               (str, transferred, speed))
                self._last_time = time.time()
                self._last_transferred = self.transferred


    def _send_packet(self, str, timeout=60):
        """
        Send a packet: a 4-byte length followed by the data itself.

        @param str: The data to send
        @param timeout: Time duration in seconds to wait for the whole packet
                (length and data) to be sent
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferSocketError: Raised if a socket error occurs
        """
        # The length header and the payload share a single deadline
        end_time = time.time() + timeout
        self._send(struct.pack("=I", len(str)), timeout)
        self._send(str, end_time - time.time())
        self.transferred += len(str) + 4
        self._report_stats("Sent")


    def _receive_packet(self, timeout=60):
        """
        Receive a packet: a 4-byte length followed by the data itself.

        @param timeout: Time duration in seconds to wait for the whole packet
                (length and data) to arrive
        @return: The data carried by the packet
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferProtocolError: Raised if the connection is closed
                before the whole packet has arrived
        @raise FileTransferSocketError: Raised if a socket error occurs
        """
        # The length header and the payload share a single deadline
        end_time = time.time() + timeout
        size = struct.unpack("=I", self._receive(4, timeout))[0]
        str = self._receive(size, end_time - time.time())
        self.transferred += len(str) + 4
        self._report_stats("Received")
        return str


    def _send_file_chunks(self, filename, timeout=60):
        """
        Send the contents of a local file as a sequence of packets.

        @param filename: Path of the local file to send
        @param timeout: Time duration in seconds to wait for the whole file
                to be sent
        @raise FileTransferError: Raised if sending fails; the exception's
                filename attribute is set to filename
        """
        if self._log_func:
            self._log_func("Sending file %s" % filename)
        f = open(filename, "rb")
        try:
            try:
                end_time = time.time() + timeout
                while True:
                    data = f.read(CHUNKSIZE)
                    self._send_packet(data, end_time - time.time())
                    # A short (possibly empty) packet ends the file
                    if len(data) < CHUNKSIZE:
                        break
            except FileTransferError, e:
                e.filename = filename
                raise
        finally:
            f.close()


    def _receive_file_chunks(self, filename, timeout=60):
        """
        Receive a sequence of packets and write them to a local file.

        @param filename: Path of the local file to write
        @param timeout: Time duration in seconds to wait for the whole file
                to arrive
        @raise FileTransferError: Raised if receiving fails; the exception's
                filename attribute is set to filename
        """
        if self._log_func:
            self._log_func("Receiving file %s" % filename)
        f = open(filename, "wb")
        try:
            try:
                end_time = time.time() + timeout
                while True:
                    data = self._receive_packet(end_time - time.time())
                    f.write(data)
                    # A short (possibly empty) packet ends the file
                    if len(data) < CHUNKSIZE:
                        break
            except FileTransferError, e:
                e.filename = filename
                raise
        finally:
            f.close()


    def _send_msg(self, msg, timeout=60):
        """
        Send a protocol message (one of the RSS_* constants).

        @param msg: The message to send
        @param timeout: Time duration in seconds to wait for the message to
                be sent
        """
        self._send(struct.pack("=I", msg), timeout)


    def _receive_msg(self, timeout=60):
        """
        Receive a protocol message.

        @param timeout: Time duration in seconds to wait for the message
        @return: The received message, as an integer
        """
        s = self._receive(4, timeout)
        return struct.unpack("=I", s)[0]


    def _handle_transfer_error(self):
        """
        Replace the exception being handled by the server's error message, if
        the server sent one.

        This method reads sys.exc_info(), so it is meant to be called while
        an exception is being handled.  It never returns normally.

        @raise FileTransferServerError: Raised if the server sent RSS_ERROR
                followed by an error message
        @note: Otherwise the exception being handled is raised again.
        """
        # Save original exception
        e = sys.exc_info()
        try:
            # See if we can get an error message
            msg = self._receive_msg()
        except FileTransferError:
            # No error message -- re-raise original exception
            raise e[0], e[1], e[2]
        if msg == RSS_ERROR:
            errmsg = self._receive_packet()
            raise FileTransferServerError(errmsg)
        raise e[0], e[1], e[2]


class FileUploadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and upload files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server and announce an upload session.

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails or if
                an incorrect magic number is received
        @raise FileTransferSocketError: Raised if the RSS_UPLOAD message cannot
                be sent to the server
        """
        super(FileUploadClient, self).__init__(address, port, log_func, timeout)
        self._send_msg(RSS_UPLOAD)


    def _upload_file(self, path, end_time):
        """
        Send a single file, or a directory tree recursively, to the server.

        Paths that are neither regular files nor directories are skipped.

        @param path: Local path of the file or directory to send
        @param end_time: Time (as returned by time.time()) at which the
                transfer of file contents times out
        """
        if os.path.isfile(path):
            self._send_msg(RSS_CREATE_FILE)
            self._send_packet(os.path.basename(path))
            self._send_file_chunks(path, end_time - time.time())
            return
        if not os.path.isdir(path):
            return
        # Directory: enter it, send every entry, then leave it
        self._send_msg(RSS_CREATE_DIR)
        self._send_packet(os.path.basename(path))
        for entry in os.listdir(path):
            self._upload_file(os.path.join(path, entry), end_time)
        self._send_msg(RSS_LEAVE_DIR)


    def upload(self, src_pattern, dst_path, timeout=600):
        """
        Send files or directory trees to the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\'
                (uploads a single file)
            src_pattern='/usr/', dst_path='C:\\Windows\\'
                (uploads a directory tree recursively)
            src_pattern='/usr/*', dst_path='C:\\Windows\\'
                (uploads all files and directory trees under /usr/)
        The following is not OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\Windows\\*'
                (wildcards are only allowed in src_pattern)

        The connection is closed if the transfer fails for any reason.

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories to send to the server
        @param dst_path: A path in the server's filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @raise FileTransferNotFoundError: Raised if src_pattern does not match
                any local files or directories
        @note: Other exceptions can be raised.
        """
        deadline = time.time() + timeout
        try:
            try:
                self._send_msg(RSS_SET_PATH)
                self._send_packet(dst_path)
                matches = glob.glob(src_pattern)
                for match in matches:
                    self._upload_file(os.path.abspath(match), deadline)
                self._send_msg(RSS_DONE)
            except FileTransferTimeoutError:
                raise
            except FileTransferError:
                # Prefer the server's own error message when it sent one
                self._handle_transfer_error()
            else:
                # Sending nothing at all counts as a failure
                if not matches:
                    raise FileTransferNotFoundError("Pattern %s does not "
                                                    "match any files or "
                                                    "directories" %
                                                    src_pattern)
                # The server answers RSS_DONE with RSS_OK or RSS_ERROR
                reply = self._receive_msg(deadline - time.time())
                if reply == RSS_ERROR:
                    raise FileTransferServerError(self._receive_packet())
                if reply != RSS_OK:
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # Whatever went wrong, do not leave the connection open
            self.close()
            raise


class FileDownloadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and download files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server and announce a download session.

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails or if
                an incorrect magic number is received
        @raise FileTransferSocketError: Raised if the RSS_DOWNLOAD message
                cannot be sent to the server
        """
        super(FileDownloadClient, self).__init__(address, port, log_func, timeout)
        self._send_msg(RSS_DOWNLOAD)


    def _check_name(self, name):
        """
        Make sure a file or directory name received from the server cannot
        lead outside the directory it is about to be created in.

        @param name: The name received from the server
        @return: name, unchanged
        @raise FileTransferProtocolError: Raised if name is the parent
                directory, starts with a drive specification or contains a
                path separator of the local platform
        """
        separators = [sep for sep in (os.sep, os.altsep) if sep]
        if (name == os.pardir or os.path.splitdrive(name)[0] or
            [sep for sep in separators if sep in name]):
            raise FileTransferProtocolError("Received illegal name from "
                                            "server", filename=name)
        return name


    def download(self, src_pattern, dst_path, timeout=600):
        """
        Receive files or directory trees from the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='C:\\foo.txt', dst_path='/tmp'
                (downloads a single file)
            src_pattern='C:\\Windows', dst_path='/tmp'
                (downloads a directory tree recursively)
            src_pattern='C:\\Windows\\*', dst_path='/tmp'
                (downloads all files and directory trees under C:\\Windows)
        The following is not OK:
            src_pattern='C:\\Windows', dst_path='/tmp/*'
                (wildcards are only allowed in src_pattern)

        The connection is closed if the transfer fails for any reason.

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories, in the server's filesystem, that will be sent to
                the client
        @param dst_path: A path in the local filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @raise FileTransferNotFoundError: Raised if the server completes the
                transfer without sending any file or directory
        @raise FileTransferProtocolError: Raised if an unexpected message or
                an illegal file or directory name is received
        @note: Other exceptions can be raised.
        """
        dst_path = os.path.abspath(dst_path)
        end_time = time.time() + timeout
        file_count = 0
        dir_count = 0
        try:
            try:
                self._send_msg(RSS_SET_PATH)
                self._send_packet(src_pattern)
            except FileTransferError:
                self._handle_transfer_error()
            while True:
                # Every wait is bounded by the overall deadline, so the
                # transfer as a whole cannot outlast timeout
                msg = self._receive_msg(end_time - time.time())
                if msg == RSS_CREATE_FILE:
                    # Receive filename and file contents
                    filename = self._check_name(
                            self._receive_packet(end_time - time.time()))
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, filename)
                    self._receive_file_chunks(dst_path, end_time - time.time())
                    dst_path = os.path.dirname(dst_path)
                    file_count += 1
                elif msg == RSS_CREATE_DIR:
                    # Receive dirname and create the directory
                    dirname = self._check_name(
                            self._receive_packet(end_time - time.time()))
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, dirname)
                    if not os.path.isdir(dst_path):
                        os.mkdir(dst_path)
                    dir_count += 1
                elif msg == RSS_LEAVE_DIR:
                    # Return to parent dir
                    dst_path = os.path.dirname(dst_path)
                elif msg == RSS_DONE:
                    # Transfer complete
                    if not file_count and not dir_count:
                        raise FileTransferNotFoundError("Pattern %s does not "
                                                        "match any files or "
                                                        "directories that "
                                                        "could be downloaded" %
                                                        src_pattern)
                    break
                elif msg == RSS_ERROR:
                    # Receive error message and abort
                    errmsg = self._receive_packet()
                    raise FileTransferServerError(errmsg)
                else:
                    # Unexpected msg
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # In any case, if the transfer failed, close the connection
            self.close()
            raise


def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60,
           connect_timeout=20):
    """
    Connect to server and upload files.

    Convenience wrapper that creates a FileUploadClient, performs a single
    upload and closes the connection.

    @param address: The server's address
    @param port: The server's port
    @param src_pattern: A path or wildcard pattern specifying the local files
            or directories to send
    @param dst_path: A path in the server's filesystem where the files will
            be saved
    @param log_func: If provided, transfer stats will be passed to this
            function during the transfer
    @param timeout: Time duration in seconds to wait for the transfer to
            complete
    @param connect_timeout: Time duration to wait for connection to succeed
    @see: FileUploadClient
    """
    uploader = FileUploadClient(address, port, log_func=log_func,
                                timeout=connect_timeout)
    uploader.upload(src_pattern, dst_path, timeout=timeout)
    uploader.close()


def download(address, port, src_pattern, dst_path, log_func=None, timeout=60,
             connect_timeout=20):
    """
    Connect to server and download files.

    Convenience wrapper that creates a FileDownloadClient, performs a single
    download and closes the connection.

    @param address: The server's address
    @param port: The server's port
    @param src_pattern: A path or wildcard pattern specifying the files or
            directories, in the server's filesystem, to receive
    @param dst_path: A path in the local filesystem where the files will be
            saved
    @param log_func: If provided, transfer stats will be passed to this
            function during the transfer
    @param timeout: Time duration in seconds to wait for the transfer to
            complete
    @param connect_timeout: Time duration to wait for connection to succeed
    @see: FileDownloadClient
    """
    downloader = FileDownloadClient(address, port, log_func=log_func,
                                    timeout=connect_timeout)
    downloader.download(src_pattern, dst_path, timeout=timeout)
    downloader.close()


def main():
    """
    Command line entry point: parse the arguments and perform one upload or
    one download.

    Exactly one of -d/--download and -u/--upload must be given, followed by
    four positional arguments: address, port, src_pattern and dst_path.
    Invalid usage, including a port that is not an integer in the range
    1-65535, is reported through the option parser's error() method.
    """
    import optparse

    usage = "usage: %prog [options] address port src_pattern dst_path"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("-d", "--download",
                      action="store_true", dest="download",
                      help="download files from server")
    parser.add_option("-u", "--upload",
                      action="store_true", dest="upload",
                      help="upload files to server")
    parser.add_option("-v", "--verbose",
                      action="store_true", dest="verbose",
                      help="be verbose")
    parser.add_option("-t", "--timeout",
                      type="int", dest="timeout", default=3600,
                      help="transfer timeout")
    options, args = parser.parse_args()
    # Both options are None when neither is given and True when both are
    if options.download == options.upload:
        parser.error("you must specify either -d or -u")
    if len(args) != 4:
        parser.error("incorrect number of arguments")
    address, port, src_pattern, dst_path = args
    # Report a bad port as a usage error rather than with a traceback
    try:
        port = int(port)
    except ValueError:
        parser.error("port must be an integer, got %r" % port)
    if not 0 < port < 65536:
        parser.error("port must be in the range 1-65535, got %d" % port)

    logger = None
    if options.verbose:
        def p(s):
            print(s)
        logger = p

    if options.download:
        download(address, port, src_pattern, dst_path, logger, options.timeout)
    elif options.upload:
        upload(address, port, src_pattern, dst_path, logger, options.timeout)


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()