# Plain text  |  744 lines  |  27.05 KB

#!/usr/bin/env python
#
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""WebSocket client utility for testing mux extension.

This code should be independent from mod_pywebsocket. See the comment of
client_for_testing.py.

NOTE: This code is far from robust like client_for_testing.py.
"""



import Queue
import base64
import collections
import email
import email.parser
import logging
import math
import os
import random
import socket
import struct
import threading

from mod_pywebsocket import util

from test import client_for_testing


# Frames whose channel id is 0 carry mux control blocks; MuxClient parses
# them with _parse_control_blocks() instead of queueing them as data.
_CONTROL_CHANNEL_ID = 0
# The logical channel that MuxClient.connect() registers up front.
_DEFAULT_CHANNEL_ID = 1

# Opcodes of mux control blocks. The opcode is stored in the top three bits
# of a control block's first byte (see _parse_control_blocks()).
_MUX_OPCODE_ADD_CHANNEL_REQUEST = 0
_MUX_OPCODE_ADD_CHANNEL_RESPONSE = 1
_MUX_OPCODE_FLOW_CONTROL = 2
_MUX_OPCODE_DROP_CHANNEL = 3
_MUX_OPCODE_NEW_CHANNEL_SLOT = 4


class _ControlBlock:
    """One parsed mux control block.

    Only the opcode is set here. _parse_control_blocks() attaches the
    opcode-specific attributes (channel_id, send_quota, ...) afterwards.
    """

    def __init__(self, opcode):
        # One of the _MUX_OPCODE_* values.
        self.opcode = opcode


def _parse_handshake_response(response):
    """Check the Status-Line of a handshake response and parse its headers.

    Args:
        response: the response text; the Status-Line comes first and is
            terminated by CRLF.

    Returns:
        The object produced by email.parser.Parser().parsestr() for
        everything that follows the Status-Line.

    Raises:
        ValueError: if the Status-Line has fewer than three space-separated
            words, the version is not 'HTTP/1.1', or the code is not '101'.
            Also raised by the unpacking below when response contains no
            CRLF at all.
    """
    status_line, remainder = response.split('\r\n', 1)

    tokens = status_line.split(' ')
    if len(tokens) < 3:
        raise ValueError('Bad Status-Line syntax %r' % status_line)

    version = tokens[0]
    response_code = tokens[1]
    if version != 'HTTP/1.1':
        raise ValueError('Bad response version %r' % version)
    if response_code != '101':
        raise ValueError('Bad response code %r ' % response_code)

    return email.parser.Parser().parsestr(remainder)


def _parse_channel_id(data, offset=0):
    """Decode a variable-length channel id found at data[offset].

    The top bits of the first byte select the encoded width:
        0xxxxxxx  1 byte
        10xxxxxx  2 bytes, id is the low 14 bits
        110xxxxx  3 bytes, id is the low 21 bits
        111xxxxx  4 bytes, id is the low 29 bits

    Args:
        data: byte string holding the encoded id.
        offset: index of the first byte of the id.

    Returns:
        A (channel_id, number_of_bytes_consumed) tuple.

    Raises:
        Exception: if there is no byte at offset, or fewer bytes remain than
            the first byte announces.
    """
    available = len(data) - offset
    if available <= 0:
        raise Exception('No channel id found')

    lead = ord(data[offset])

    # Test the longest prefix first; the shorter prefixes are subsets of it.
    if lead & 0xe0 == 0xe0:
        width = 4
    elif lead & 0xc0 == 0xc0:
        width = 3
    elif lead & 0x80 == 0x80:
        width = 2
    else:
        return lead, 1

    if available < width:
        raise Exception('Invalid channel id format')

    if width == 4:
        (raw,) = struct.unpack('!L', data[offset:offset+4])
        return raw & 0x1fffffff, 4
    if width == 3:
        (low_word,) = struct.unpack('!H', data[offset+1:offset+3])
        return ((lead & 0x1f) << 16) + low_word, 3
    (raw,) = struct.unpack('!H', data[offset:offset+2])
    return raw & 0x3fff, 2


def _read_number(data, size_of_size, offset=0):
    """Read a big-endian unsigned integer of 1 to 4 bytes from data.

    Args:
        data: byte string to read from.
        size_of_size: width of the integer in bytes (1, 2, 3 or 4).
        offset: index of the integer's first byte.

    Returns:
        The decoded integer.

    Raises:
        Exception: if size_of_size is not 1, 2, 3 or 4.
    """
    if size_of_size == 1:
        return ord(data[offset])
    if size_of_size == 2:
        (value,) = struct.unpack('!H', data[offset:offset+2])
        return value
    if size_of_size == 3:
        # No struct format exists for 24 bits: combine the top byte with a
        # 16-bit read of the remaining two.
        top = ord(data[offset])
        (low_word,) = struct.unpack('!H', data[offset+1:offset+3])
        return (top << 16) + low_word
    if size_of_size == 4:
        (value,) = struct.unpack('!L', data[offset:offset+4])
        return value
    raise Exception('Invalid "size of size" in control block')


def _parse_control_block_specific_data(data, size_of_size, offset=0):
    """Extract a length-prefixed field from a control block.

    The field starts at data[offset] with a size_of_size-byte length,
    followed by that many bytes of content.

    Args:
        data: byte string containing the control block.
        size_of_size: width of the length prefix in bytes.
        offset: index at which the length prefix starts.

    Returns:
        A (content, bytes_consumed) tuple, where bytes_consumed counts both
        the length prefix and the content.

    Raises:
        Exception: if the prefix or the announced content does not fit in
            data.
    """
    total = len(data)
    if total - offset < size_of_size:
        raise Exception('Invalid control block received')

    content_size = _read_number(data, size_of_size, offset)

    content_start = offset + size_of_size
    content_end = content_start + content_size
    if content_end > total:
        raise Exception('Invalid size of control block (%d < %d)' % (
                total, content_end))

    content = data[content_start:content_end]
    return content, size_of_size + content_size


def _parse_control_blocks(data):
    """Parse the payload of a control-channel frame into control blocks.

    The top three bits of each block's first byte are the opcode; the
    meaning of the remaining bits depends on the opcode. Attributes set on
    the returned _ControlBlock objects, by opcode:

        AddChannelResponse: encode, rejected, channel_id, encoded_handshake
        DropChannel:        mux_error, channel_id, reason
        FlowControl:        channel_id, send_quota
        NewChannelSlot:     slots, send_quota

    Args:
        data: byte string holding zero or more concatenated control blocks.

    Returns:
        A list of _ControlBlock objects in the order they appear in data.
        Empty when data is empty.

    Raises:
        Exception: if a block has an opcode other than the four above, or a
            block is malformed.
    """
    blocks = []
    length = len(data)
    pos = 0

    while pos < length:
        first_byte = ord(data[pos])
        pos += 1
        opcode = (first_byte >> 5) & 0x7
        block = _ControlBlock(opcode)

        # TODO(bashi): Support more opcode
        if opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE:
            block.encode = (first_byte >> 2) & 3
            block.rejected = (first_byte >> 4) & 1

            block.channel_id, advance = _parse_channel_id(data, pos)
            pos += advance

            size_of_size = (first_byte & 3) + 1
            block.encoded_handshake, advance = (
                _parse_control_block_specific_data(data, size_of_size, pos))
            pos += advance
        elif opcode == _MUX_OPCODE_DROP_CHANNEL:
            block.mux_error = (first_byte >> 4) & 1

            block.channel_id, advance = _parse_channel_id(data, pos)
            pos += advance

            # The low two bits hold the width of the size field minus one,
            # exactly as in the other blocks (and as the encoders in this
            # file emit via _size_of_number_in_bytes_minus_1()). Without
            # the "+ 1" a one-byte size field was announced as width 0 and
            # always rejected.
            size_of_size = (first_byte & 3) + 1
            block.reason, advance = _parse_control_block_specific_data(
                data, size_of_size, pos)
            pos += advance
        elif opcode == _MUX_OPCODE_FLOW_CONTROL:
            block.channel_id, advance = _parse_channel_id(data, pos)
            pos += advance

            size_of_quota = (first_byte & 3) + 1
            block.send_quota = _read_number(data, size_of_quota, pos)
            pos += size_of_quota
        elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
            size_of_slots = ((first_byte >> 2) & 3) + 1
            size_of_quota = (first_byte & 3) + 1

            block.slots = _read_number(data, size_of_slots, pos)
            pos += size_of_slots
            block.send_quota = _read_number(data, size_of_quota, pos)
            pos += size_of_quota
        else:
            raise Exception(
                'Unsupported mux opcode %d received' % opcode)

        blocks.append(block)

    return blocks


def _encode_channel_id(channel_id):
    """Encode a channel id in the shortest variable-length form.

    Ids below 2**7 take one byte, below 2**14 two bytes (prefix bits 10),
    below 2**21 three bytes (prefix bits 110) and below 2**29 four bytes
    (prefix bits 111).

    Args:
        channel_id: non-negative integer smaller than 2**29.

    Returns:
        The encoded id as a byte string.

    Raises:
        ValueError: if channel_id is negative or not smaller than 2**29.
    """
    if channel_id < 0:
        raise ValueError('Channel id %d must not be negative' % channel_id)
    if channel_id >= 0x20000000:
        raise ValueError('Channel id %d is too large' % channel_id)

    if channel_id < 0x80:
        return chr(channel_id)
    if channel_id < 0x4000:
        return struct.pack('!H', 0x8000 + channel_id)
    if channel_id < 0x200000:
        # Three bytes: the prefix plus the top 5 bits, then the low 16 bits.
        lead = chr(0xc0 + (channel_id >> 16))
        return lead + struct.pack('!H', channel_id & 0xffff)
    return struct.pack('!L', 0xe0000000 + channel_id)


def _size_of_number_in_bytes_minus_1(number):
    """Return how many bytes are needed to store number, minus one.

    Args:
        number: non-negative integer smaller than 2**32.

    Returns:
        0 for values below 2**8, 1 below 2**16, 2 below 2**24, 3 below
        2**32.

    Raises:
        ValueError: if number is negative or does not fit in four bytes.
    """
    if number < 0:
        raise ValueError('Invalid number: %d' % number)

    # extra_bytes is the answer; the limit for it is 256 ** (extra_bytes + 1).
    for extra_bytes in range(4):
        if number < 1 << (8 * (extra_bytes + 1)):
            return extra_bytes

    raise ValueError('Invalid number %d' % number)


def _encode_number(number):
    """Encode number as a big-endian unsigned integer of minimal width.

    Values below 2**8 take one byte, below 2**16 two bytes, below 2**24
    three bytes; anything else is packed into four bytes.

    Args:
        number: the integer to encode.

    Returns:
        The encoded number as a byte string.
    """
    if number >= 2 ** 24:
        return struct.pack('!L', number)
    if number >= 2 ** 16:
        # No struct format exists for 24 bits: emit the top byte by hand.
        return chr(number >> 16) + struct.pack('!H', number & 0xffff)
    if number >= 2 ** 8:
        return struct.pack('!H', number)
    return chr(number)


def _create_add_channel_request(channel_id, encoded_handshake,
                                encoding=0):
    """Build an AddChannelRequest control block.

    Layout: one header byte (opcode in the top three bits, encoding shifted
    left by two, width of the size field minus one in the low two bits),
    the encoded channel id, the handshake size, then the handshake itself.

    Args:
        channel_id: id of the logical channel to open.
        encoded_handshake: the opening handshake to embed.
        encoding: value placed in the header's encoding bits.

    Returns:
        The control block as a byte string.
    """
    handshake_size = len(encoded_handshake)
    size_width_minus_1 = _size_of_number_in_bytes_minus_1(handshake_size)

    header = (_MUX_OPCODE_ADD_CHANNEL_REQUEST << 5) | (encoding << 2)
    header |= size_width_minus_1
    size_field = _encode_number(handshake_size)

    block = chr(header) + _encode_channel_id(channel_id)
    block = block + size_field
    return block + encoded_handshake


def _create_flow_control(channel_id, replenished_quota):
    """Build a FlowControl control block.

    Layout: one header byte (opcode in the top three bits, width of the
    quota field minus one in the low bits), the encoded channel id, then
    the quota.

    Args:
        channel_id: id of the logical channel the quota applies to.
        replenished_quota: the quota value to put in the block.

    Returns:
        The control block as a byte string.
    """
    quota_width_minus_1 = _size_of_number_in_bytes_minus_1(replenished_quota)
    header = chr((_MUX_OPCODE_FLOW_CONTROL << 5) | quota_width_minus_1)
    encoded_id = _encode_channel_id(channel_id)
    return header + encoded_id + _encode_number(replenished_quota)


class _MuxReaderThread(threading.Thread):
    """Mux reader thread.

    Reads frames and passes them to the mux client. This thread accesses
    private functions/variables of the mux client.
    """

    def __init__(self, mux):
        """Create a daemon thread that reads on behalf of mux.

        Args:
            mux: the mux client. Its _socket and _incoming_frame_filter are
                read, and its _dispatch_frame(), _process_control_message()
                and _notify_reader_done() are called.
        """
        threading.Thread.__init__(self)
        self.setDaemon(True)
        self._mux = mux
        self._stop_requested = False

    def _receive_message(self):
        """Read frames from the physical connection until a message ends.

        Returns:
            A (opcode, message) tuple: the opcode of the message's first
            frame and the concatenated payload of all its frames, or
            (None, None) if a stop has been requested.

        Raises:
            Exception: on a text frame, a continuation frame without a
                started message, an unknown opcode, or a non-continuation
                frame in the middle of a fragmented message.
        """
        first_opcode = None
        pending_payload = []
        while not self._stop_requested:
            fin, rsv1, rsv2, rsv3, opcode, payload_length = (
                client_for_testing.read_frame_header(self._mux._socket))

            if not first_opcode:
                if opcode == client_for_testing.OPCODE_TEXT:
                    raise Exception('Received a text message on physical '
                                    'connection')
                if opcode == client_for_testing.OPCODE_CONTINUATION:
                    raise Exception('Received an intermediate frame but '
                                    'fragmentation was not started')
                # The original listed OPCODE_PONG twice here; the second
                # entry is taken to have meant OPCODE_PING, so that a ping
                # is handed to the mux client instead of being rejected as
                # an undefined opcode.
                if opcode not in (client_for_testing.OPCODE_BINARY,
                                  client_for_testing.OPCODE_PING,
                                  client_for_testing.OPCODE_PONG,
                                  client_for_testing.OPCODE_CLOSE):
                    raise Exception('Received an undefined opcode frame: %d' %
                                    opcode)
                first_opcode = opcode

            elif opcode != client_for_testing.OPCODE_CONTINUATION:
                raise Exception('Received a new opcode before '
                                'terminating fragmentation')

            payload = client_for_testing.receive_bytes(
                self._mux._socket, payload_length)

            if self._mux._incoming_frame_filter is not None:
                payload = self._mux._incoming_frame_filter.filter(payload)

            pending_payload.append(payload)

            if fin:
                break

        if self._stop_requested:
            return None, None

        message = ''.join(pending_payload)
        return first_opcode, message

    def request_stop(self):
        """Ask the thread to stop; checked before the next frame is read."""
        self._stop_requested = True

    def run(self):
        """Receive messages and hand them to the mux client until stopped.

        Binary messages are split into channel id and payload and passed to
        _dispatch_frame(); everything else goes to
        _process_control_message(). The mux client is always notified when
        this method exits, including on an exception.
        """
        try:
            while not self._stop_requested:
                # opcode is OPCODE_BINARY or control opcodes when a message
                # is successfully received.
                opcode, message = self._receive_message()
                if not opcode:
                    return
                if opcode == client_for_testing.OPCODE_BINARY:
                    channel_id, advance = _parse_channel_id(message)
                    self._mux._dispatch_frame(channel_id, message[advance:])
                else:
                    self._mux._process_control_message(opcode, message)
        finally:
            self._mux._notify_reader_done()


class _InnerFrame(object):
    """A frame received on a logical channel.

    Holds the FIN bit, the three RSV bits, the opcode and the payload.
    """

    def __init__(self, fin, rsv1, rsv2, rsv3, opcode, payload):
        self.fin, self.opcode, self.payload = fin, opcode, payload
        self.rsv1, self.rsv2, self.rsv3 = rsv1, rsv2, rsv3


class _LogicalChannelData(object):
    """Per-channel state: a queue of received frames and both quotas."""

    def __init__(self):
        # Both quotas start at zero.
        self.send_quota = self.receive_quota = 0
        # Frames received on this channel, waiting to be consumed.
        self.queue = Queue.Queue()


class MuxClient(object):
    """WebSocket mux client.

    Note that this class is NOT thread-safe. Do not access an instance of this
    class from multiple threads at a same time.
    """

    def __init__(self, options):
        """Prepare a client; nothing is sent until connect() is called.

        Args:
            options: client options object. enable_mux() is called on it
                and it is handed to client_for_testing.WebSocketHandshake.
                connect() reads socket_timeout, server_host, server_port
                and use_tls from it.
        """
        self._logger = util.get_class_logger(self)

        self._options = options
        self._options.enable_mux()
        self._stream = None
        self._socket = None
        self._handshake = client_for_testing.WebSocketHandshake(self._options)
        self._incoming_frame_filter = None
        self._outgoing_frame_filter = None

        self._is_active = False
        self._read_thread = None
        # Held while _control_blocks is modified; notified by the reader
        # thread whenever a control frame has been processed.
        self._control_blocks_condition = threading.Condition()
        self._control_blocks = []
        # One entry per available channel slot; the value is the send quota
        # announced by the NewChannelSlot block that created the slot.
        self._channel_slots = collections.deque()
        # Held while _logical_channels or a channel's quota is modified.
        self._logical_channels_condition = threading.Condition()
        self._logical_channels = {}
        # Seconds to wait for a frame or a control block.
        self._timeout = 2
        self._physical_connection_close_event = None
        self._physical_connection_close_message = None

    def _parse_inner_frame(self, data):
        """Split an encapsulated frame into header bits, opcode and payload.

        Args:
            data: the frame as received on a logical channel, starting with
                the header byte (FIN, RSV1-3, then a four-bit opcode).

        Returns:
            An _InnerFrame.

        Raises:
            Exception: if data is empty.
        """
        if len(data) == 0:
            raise Exception('Invalid encapsulated frame received')

        first_byte = ord(data[0])
        # FIN and RSV1-3 are the four high bits; they have to be shifted
        # down (not up) before masking, otherwise they always read as 0.
        fin = (first_byte >> 7) & 1
        rsv1 = (first_byte >> 6) & 1
        rsv2 = (first_byte >> 5) & 1
        rsv3 = (first_byte >> 4) & 1
        opcode = first_byte & 0xf

        # NOTE(review): received data is passed through the *outgoing*
        # filter here -- confirm this is intended.
        if self._outgoing_frame_filter:
            payload = self._outgoing_frame_filter.filter(
                data[1:])
        else:
            payload = data[1:]

        return _InnerFrame(fin, rsv1, rsv2, rsv3, opcode, payload)

    def _process_mux_control_blocks(self):
        """Apply and consume pending FlowControl/NewChannelSlot blocks.

        Called from _dispatch_frame() while _control_blocks_condition is
        held. Blocks handled here are removed from _control_blocks so that
        they are not applied again when the next control frame arrives.
        All other blocks (e.g. AddChannelResponse, which add_channel()
        consumes) stay in the list.

        Raises:
            Exception: if a FlowControl block names an unknown channel.
        """
        unconsumed = []
        for block in self._control_blocks:
            if block.opcode == _MUX_OPCODE_FLOW_CONTROL:
                with self._logical_channels_condition:
                    if block.channel_id not in self._logical_channels:
                        raise Exception('Invalid flow control received for '
                                        'channel id %d' % block.channel_id)
                    self._logical_channels[block.channel_id].send_quota += (
                        block.send_quota)
                    self._logical_channels_condition.notify()
            elif block.opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
                self._channel_slots.extend([block.send_quota] * block.slots)
            else:
                unconsumed.append(block)
        self._control_blocks[:] = unconsumed

    def _dispatch_frame(self, channel_id, payload):
        """Route a received frame. Called by the reader thread.

        Frames on the control channel are parsed into control blocks and
        processed. Frames on any other channel are charged against that
        channel's receive quota and put on its queue.

        Raises:
            Exception: if the channel is not established or the frame
                exceeds the channel's receive quota.
        """
        if channel_id == _CONTROL_CHANNEL_ID:
            with self._control_blocks_condition:
                self._control_blocks += _parse_control_blocks(payload)
                self._process_mux_control_blocks()
                self._control_blocks_condition.notify()
            return

        with self._logical_channels_condition:
            if channel_id not in self._logical_channels:
                raise Exception('Received logical frame on channel id '
                                '%d, which is not established' %
                                channel_id)

            channel = self._logical_channels[channel_id]
            inner_frame = self._parse_inner_frame(payload)
            channel.receive_quota -= len(inner_frame.payload)
            if channel.receive_quota < 0:
                raise Exception('The server violates quota on '
                                'channel id %d' % channel_id)
        channel.queue.put(inner_frame)

    def _process_control_message(self, opcode, message):
        """Handle a control message received on the physical connection.

        Only close is acted upon: the message is recorded, a close is sent
        back and the reader is stopped if the client is active, and the
        close event (if any) is set.
        """
        # Ping/Pong are not supported.
        if opcode == client_for_testing.OPCODE_CLOSE:
            self._physical_connection_close_message = message
            if self._is_active:
                self._stream.send_close(
                    code=client_for_testing.STATUS_NORMAL_CLOSURE, reason='')
                self._read_thread.request_stop()

            if self._physical_connection_close_event:
                self._physical_connection_close_event.set()

    def _notify_reader_done(self):
        """Called by the reader thread when it exits; closes the socket."""
        self._logger.debug('Read thread terminated.')
        self.close_socket()

    def _assert_channel_slot_available(self):
        """Wait up to _timeout for a channel slot; raise if none arrives."""
        with self._control_blocks_condition:
            if len(self._channel_slots) == 0:
                # Wait once
                self._control_blocks_condition.wait(timeout=self._timeout)

        if len(self._channel_slots) == 0:
            raise Exception('Failed to receive NewChannelSlot')

    def _assert_send_quota_available(self, channel_id):
        """Wait up to _timeout for send quota on channel_id; raise if none."""
        with self._logical_channels_condition:
            if self._logical_channels[channel_id].send_quota == 0:
                # Wait once
                self._logical_channels_condition.wait(timeout=self._timeout)

        if self._logical_channels[channel_id].send_quota == 0:
            raise Exception('Failed to receive FlowControl for channel id %d' %
                            channel_id)

    def connect(self):
        """Open the physical connection and start the reader thread.

        Performs the opening handshake, registers the default logical
        channel, then waits for a channel slot and for send quota on the
        default channel.

        Raises:
            Exception: if no channel slot or no send quota is received
                within the timeout.
        """
        self._socket = socket.socket()
        self._socket.settimeout(self._options.socket_timeout)

        self._socket.connect((self._options.server_host,
                              self._options.server_port))
        if self._options.use_tls:
            # _TLSSocket is not defined or imported in this module, so the
            # bare name raised NameError. NOTE(review): assumes the wrapper
            # lives in client_for_testing -- confirm.
            self._socket = client_for_testing._TLSSocket(self._socket)

        self._handshake.handshake(self._socket)
        self._stream = client_for_testing.WebSocketStream(
            self._socket, self._handshake)

        self._logical_channels[_DEFAULT_CHANNEL_ID] = _LogicalChannelData()

        self._read_thread = _MuxReaderThread(self)
        self._read_thread.start()

        self._assert_channel_slot_available()
        self._assert_send_quota_available(_DEFAULT_CHANNEL_ID)

        self._is_active = True
        self._logger.info('Connection established')

    def add_channel(self, channel_id, options):
        """Open a logical channel by sending an AddChannelRequest.

        Consumes one channel slot, sends the request on the control channel,
        waits for the matching AddChannelResponse and validates the
        handshake it carries.

        Args:
            channel_id: id for the new channel; must not be in use.
            options: provides resource, server_host, server_port, origin
                and extensions for the encapsulated handshake.

        Raises:
            Exception: if the client is not active, the id is in use, no
                slot is available, the client becomes inactive while
                waiting, the server rejects the request, or the response
                handshake is invalid.
        """
        if not self._is_active:
            raise Exception('Mux client is not active')

        if channel_id in self._logical_channels:
            raise Exception('Channel id %d already exists' % channel_id)

        try:
            send_quota = self._channel_slots.popleft()
        except IndexError as e:
            raise Exception('No channel slots: %r' % e)

        # Create AddChannel request
        request_line = 'GET %s HTTP/1.1\r\n' % options.resource
        fields = []
        fields.append('Upgrade: websocket\r\n')
        fields.append('Connection: Upgrade\r\n')
        if options.server_port == client_for_testing.DEFAULT_PORT:
            fields.append('Host: %s\r\n' % options.server_host.lower())
        else:
            fields.append('Host: %s:%d\r\n' % (options.server_host.lower(),
                                               options.server_port))
        fields.append('Origin: %s\r\n' % options.origin.lower())

        key = base64.b64encode(os.urandom(16))
        fields.append('Sec-WebSocket-Key: %s\r\n' % key)

        fields.append('Sec-WebSocket-Version: 13\r\n')

        if len(options.extensions) > 0:
            fields.append('Sec-WebSocket-Extensions: %s\r\n' %
                          ', '.join(options.extensions))

        handshake = request_line + ''.join(fields) + '\r\n'
        add_channel_request = _create_add_channel_request(
            channel_id, handshake)
        payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + add_channel_request
        self._stream.send_binary(payload)

        # Wait AddChannelResponse
        self._logger.debug('Waiting AddChannelResponse for the request...')
        response = None
        with self._control_blocks_condition:
            while True:
                for block in self._control_blocks:
                    if (block.opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE and
                        block.channel_id == channel_id):
                        response = block
                        break
                if response is not None:
                    self._control_blocks.remove(response)
                    break
                self._control_blocks_condition.wait(self._timeout)
                # The wait is repeated until a response shows up; it is
                # abandoned only once the client has become inactive
                # (close_socket() clears _is_active, e.g. when the reader
                # thread exits).
                if not self._is_active:
                    raise Exception('AddChannelRequest timed out')

        # Validate AddChannelResponse
        if response.rejected:
            raise Exception('The server rejected AddChannelRequest')

        fields = _parse_handshake_response(response.encoded_handshake)

        if 'upgrade' not in fields:
            raise Exception('No Upgrade header')
        if fields['upgrade'] != 'websocket':
            raise Exception('Wrong Upgrade header')
        if 'connection' not in fields:
            raise Exception('No Connection header')
        if fields['connection'] != 'Upgrade':
            raise Exception('Wrong Connection header')
        if 'sec-websocket-accept' not in fields:
            raise Exception('No Sec-WebSocket-Accept header')

        accept = fields['sec-websocket-accept']
        try:
            decoded_accept = base64.b64decode(accept)
        except TypeError:
            raise Exception(
                'Illegal value for header Sec-WebSocket-Accept: ' + accept)

        if len(decoded_accept) != 20:
            raise Exception(
                'Decoded value of Sec-WebSocket-Accept is not 20-byte long')

        original_expected_accept = util.sha1_hash(
            key + client_for_testing.WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(original_expected_accept)

        if accept != expected_accept:
            # The values are listed in the order the message labels them
            # (they used to be swapped).
            raise Exception(
                'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
                '(actual)' % (expected_accept, accept))

        with self._logical_channels_condition:
            channel = _LogicalChannelData()
            channel.send_quota = send_quota
            self._logical_channels[channel_id] = channel

        self._logger.debug('Logical channel %d established' % channel_id)

    def _check_logical_channel_is_opened(self, channel_id):
        """Raise unless the client is active and channel_id is established."""
        if not self._is_active:
            raise Exception('Mux client is not active')

        if channel_id not in self._logical_channels:
            raise Exception('Logical channel %d is not established.' %
                            channel_id)

    def drop_channel(self, channel_id):
        """Not implemented; does nothing."""
        # TODO(bashi): Implement
        pass

    def send_flow_control(self, channel_id, replenished_quota):
        """Send a FlowControl block and raise the local receive quota.

        Args:
            channel_id: an established logical channel.
            replenished_quota: amount granted to the peer; also added to
                the channel's receive quota.
        """
        self._check_logical_channel_is_opened(channel_id)
        flow_control = _create_flow_control(channel_id, replenished_quota)
        payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + flow_control
        # Replenish receive quota
        with self._logical_channels_condition:
            self._logical_channels[channel_id].receive_quota += (
                replenished_quota)
        self._stream.send_binary(payload)

    def send_message(self, channel_id, message, end=True, binary=False):
        """Send a message frame on a logical channel.

        Args:
            channel_id: an established logical channel.
            message: the payload; encoded as UTF-8 unless binary is set.
            end: value of the FIN bit of the encapsulated frame.
            binary: send a binary frame instead of a text frame.

        Raises:
            Exception: if the channel's send quota is smaller than the
                payload.
        """
        self._check_logical_channel_is_opened(channel_id)

        if binary:
            first_byte = (end << 7) | client_for_testing.OPCODE_BINARY
        else:
            first_byte = (end << 7) | client_for_testing.OPCODE_TEXT
            message = message.encode('utf-8')

        with self._logical_channels_condition:
            channel = self._logical_channels[channel_id]
            if channel.send_quota < len(message):
                raise Exception('Send quota violation: %d < %d' % (
                        channel.send_quota, len(message)))

            channel.send_quota -= len(message)
        payload = _encode_channel_id(channel_id) + chr(first_byte) + message
        self._stream.send_binary(payload)

    def assert_receive(self, channel_id, payload, binary=False):
        """Check that the next frame queued on channel_id is as expected.

        Args:
            channel_id: an established logical channel.
            payload: the expected payload.
            binary: expect a binary frame instead of a text frame.

        Raises:
            Exception: if no frame arrives within the timeout, or its
                opcode or payload differs from what is expected.
        """
        self._check_logical_channel_is_opened(channel_id)

        try:
            inner_frame = self._logical_channels[channel_id].queue.get(
                timeout=self._timeout)
        except Queue.Empty:
            raise Exception('Cannot receive message from channel id %d' %
                            channel_id)

        if binary:
            opcode = client_for_testing.OPCODE_BINARY
        else:
            opcode = client_for_testing.OPCODE_TEXT

        if inner_frame.opcode != opcode:
            # This used to format an undefined name (expected_opcode) and
            # so failed with NameError instead of reporting the mismatch.
            raise Exception('Unexpected opcode received (%r != %r)' %
                            (opcode, inner_frame.opcode))

        if inner_frame.payload != payload:
            raise Exception('Unexpected payload received')

    def send_close(self, channel_id, code=None, reason=''):
        """Send a close frame on a logical channel.

        Args:
            channel_id: an established logical channel.
            code: status code; when None the frame has an empty body.
            reason: text appended (UTF-8 encoded) after the code.
        """
        self._check_logical_channel_is_opened(channel_id)

        if code is not None:
            body = struct.pack('!H', code) + reason.encode('utf-8')
        else:
            body = ''

        first_byte = (1 << 7) | client_for_testing.OPCODE_CLOSE
        payload = _encode_channel_id(channel_id) + chr(first_byte) + body
        self._stream.send_binary(payload)

    def assert_receive_close(self, channel_id):
        """Check that the next frame queued on channel_id is a close frame.

        Raises:
            Exception: if no frame arrives within the timeout or the frame
                is not a close frame.
        """
        self._check_logical_channel_is_opened(channel_id)

        try:
            inner_frame = self._logical_channels[channel_id].queue.get(
                timeout=self._timeout)
        except Queue.Empty:
            raise Exception('Cannot receive message from channel id %d' %
                            channel_id)
        if inner_frame.opcode != client_for_testing.OPCODE_CLOSE:
            raise Exception('Didn\'t receive close frame')

    def send_physical_connection_close(self, code=None, reason=''):
        """Send a close frame on the physical connection.

        Also creates the event that
        assert_physical_connection_receive_close() waits on.
        """
        self._physical_connection_close_event = threading.Event()
        self._stream.send_close(code, reason)

    # This method can be used only after calling
    # send_physical_connection_close().
    def assert_physical_connection_receive_close(
        self, code=client_for_testing.STATUS_NORMAL_CLOSURE, reason=''):
        """Wait for the peer's close on the physical connection.

        The code and reason arguments are accepted but not compared with
        the received close message.

        Raises:
            Exception: if no close with a non-empty body has been received
                within the timeout.
        """
        self._physical_connection_close_event.wait(timeout=self._timeout)
        if (not self._physical_connection_close_event.isSet() or
            not self._physical_connection_close_message):
            raise Exception('Didn\'t receive closing handshake')

    def close_socket(self):
        """Mark the client inactive and close the underlying socket."""
        self._is_active = False
        self._socket.close()