#!/usr/bin/env python
#
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""Tests for msgutil module."""


import array
import Queue
import struct
import unittest
import zlib

import set_sys_path  # Update sys.path to locate mod_pywebsocket module.

from mod_pywebsocket import common
from mod_pywebsocket.extensions import DeflateFrameExtensionProcessor
from mod_pywebsocket.extensions import PerFrameCompressionExtensionProcessor
from mod_pywebsocket.extensions import PerMessageCompressionExtensionProcessor
from mod_pywebsocket import msgutil
from mod_pywebsocket.stream import InvalidUTF8Exception
from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamHixie75
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util
from test import mock


# Masking key used by _mask_hybi for every masked frame built in this file.
# We use one fixed nonce for testing instead of a cryptographically secure
# PRNG, which keeps the masked test frames deterministic.
_MASKING_NONCE = 'ABCD'


def _mask_hybi(frame):
    """Masks frame with the fixed test nonce.

    Every byte of frame is XORed with the nonce byte at the same position,
    cycling through the nonce. Returns the nonce followed by the masked
    bytes.
    """

    key_octets = [ord(char) for char in _MASKING_NONCE]
    key_length = len(key_octets)
    masked = array.array('B')
    masked.fromstring(frame)
    # The array is only modified in place (its length never changes), so it
    # is safe to update elements while enumerating it.
    for index, octet in enumerate(masked):
        masked[index] = octet ^ key_octets[index % key_length]
    return _MASKING_NONCE + masked.tostring()


def _install_extension_processor(processor, request, stream_options):
    """Registers processor on request if it accepts the extension.

    When processor.get_extension_response() returns None nothing is changed.
    Otherwise the processor gets to adjust stream_options and is appended to
    request.ws_extension_processors.
    """

    if processor.get_extension_response() is None:
        return
    processor.setup_stream_options(stream_options)
    request.ws_extension_processors.append(processor)


def _create_request_from_rawdata(
    read_data, deflate_stream=False, deflate_frame_request=None,
    perframe_compression_request=None, permessage_compression_request=None):
    """Creates a MockRequest whose connection serves read_data verbatim.

    read_data is joined into one string and used as the data of the mock
    connection. At most one extension processor is created: the first of
    deflate_frame_request, perframe_compression_request and
    permessage_compression_request that is not None wins. The request gets a
    Stream built from the resulting stream options.
    """

    request = mock.MockRequest(connection=mock.MockConn(''.join(read_data)))
    request.ws_version = common.VERSION_HYBI_LATEST
    options = StreamOptions()
    options.deflate_stream = deflate_stream
    request.ws_extension_processors = []

    # Pick the processor first, then install it in a single place.
    extension_processor = None
    if deflate_frame_request is not None:
        extension_processor = DeflateFrameExtensionProcessor(
            deflate_frame_request)
    elif perframe_compression_request is not None:
        extension_processor = PerFrameCompressionExtensionProcessor(
            perframe_compression_request)
    elif permessage_compression_request is not None:
        extension_processor = PerMessageCompressionExtensionProcessor(
            permessage_compression_request)
    if extension_processor is not None:
        _install_extension_processor(extension_processor, request, options)

    request.ws_stream = Stream(request, options)
    return request


def _create_request(*frames):
    """Creates MockRequest using data given as frames.

    Each frame is a (header, body) pair. The body is masked with _mask_hybi
    and appended to the header; the resulting data will be returned on
    calling request.connection.read() where request is the MockRequest
    returned by this function.
    """

    masked_frames = [header + _mask_hybi(body) for (header, body) in frames]
    return _create_request_from_rawdata(masked_frames)


def _create_blocking_request():
    """Creates a MockRequest on top of a MockBlockingConn.

    Data written to a MockRequest can be read out by calling
    request.connection.written_data().
    """

    request = mock.MockRequest(connection=mock.MockBlockingConn())
    request.ws_version = common.VERSION_HYBI_LATEST
    # Default stream options: no extension is configured here.
    request.ws_stream = Stream(request, StreamOptions())
    return request


def _create_request_hixie75(read_data=''):
    """Creates a MockRequest serving read_data through a StreamHixie75."""

    request = mock.MockRequest(connection=mock.MockConn(read_data))
    request.ws_stream = StreamHixie75(request)
    return request


def _create_blocking_request_hixie75():
    """Creates a MockRequest with a MockBlockingConn and a StreamHixie75."""

    request = mock.MockRequest(connection=mock.MockBlockingConn())
    request.ws_stream = StreamHixie75(request)
    return request


def _create_raw_deflater(window_bits=zlib.MAX_WBITS):
    """Creates the zlib compressor used to build expected and input data.

    window_bits is negated before being handed to zlib.compressobj; the zlib
    module documents negative wbits values as selecting a raw DEFLATE stream
    without header or trailer.
    """

    return zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits)


def _deflate_with_sync_flush(compress, payload):
    """Compresses payload, sync-flushes, and drops the last four bytes.

    compress keeps its state across calls, so passing the same object again
    continues the same compression context. The four dropped bytes are
    presumably the 00 00 ff ff tail that Z_SYNC_FLUSH emits and that the
    compression extensions leave off the wire -- confirm against the
    extension specification.
    """

    compressed = compress.compress(payload)
    compressed += compress.flush(zlib.Z_SYNC_FLUSH)
    return compressed[:-4]


class MessageTest(unittest.TestCase):
    """Tests for Stream.

    Frames are written out as raw byte strings. In the expectations below
    the first header byte is 0x81 for an unfragmented text message,
    0x01 / 0x00 / 0x80 for the first / intermediate / final fragment, 0x88
    for close, 0x89 for ping, 0x8a for pong and 0xc1 for a text frame
    produced by a compression extension. Frames fed to the server are
    masked with _mask_hybi and carry 0x80 in the length byte. See RFC 6455
    for the authoritative frame layout.
    """

    def test_send_message(self):
        request = _create_request()
        msgutil.send_message(request, 'Hello')
        self.assertEqual('\x81\x05Hello', request.connection.written_data())

        # 125 bytes is the longest payload sent with a one byte length here.
        payload = 'a' * 125
        request = _create_request()
        msgutil.send_message(request, payload)
        self.assertEqual('\x81\x7d' + payload,
                         request.connection.written_data())

    def test_send_medium_message(self):
        # Both ends of the range that is sent with a 16 bit length field.
        payload = 'a' * 126
        request = _create_request()
        msgutil.send_message(request, payload)
        self.assertEqual('\x81\x7e\x00\x7e' + payload,
                         request.connection.written_data())

        payload = 'a' * ((1 << 16) - 1)
        request = _create_request()
        msgutil.send_message(request, payload)
        self.assertEqual('\x81\x7e\xff\xff' + payload,
                         request.connection.written_data())

    def test_send_large_message(self):
        # Smallest payload that is sent with a 64 bit length field.
        payload = 'a' * (1 << 16)
        request = _create_request()
        msgutil.send_message(request, payload)
        self.assertEqual('\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + payload,
                         request.connection.written_data())

    def test_send_message_unicode(self):
        request = _create_request()
        msgutil.send_message(request, u'\u65e5')
        # U+65e5 is encoded as e6,97,a5 in UTF-8
        self.assertEqual('\x81\x03\xe6\x97\xa5',
                         request.connection.written_data())

    def test_send_message_fragments(self):
        # The third argument of send_message tells whether this is the final
        # fragment of the message.
        request = _create_request()
        msgutil.send_message(request, 'Hello', False)
        msgutil.send_message(request, ' ', False)
        msgutil.send_message(request, 'World', False)
        msgutil.send_message(request, '!', True)
        self.assertEqual('\x01\x05Hello\x00\x01 \x00\x05World\x80\x01!',
                         request.connection.written_data())

    def test_send_fragments_immediate_zero_termination(self):
        # A message may be terminated by an empty final fragment.
        request = _create_request()
        msgutil.send_message(request, 'Hello World!', False)
        msgutil.send_message(request, '', True)
        self.assertEqual('\x01\x0cHello World!\x80\x00',
                         request.connection.written_data())

    def test_send_message_deflate_stream(self):
        compress = _create_raw_deflater()

        request = _create_request_from_rawdata('', deflate_stream=True)
        msgutil.send_message(request, 'Hello')
        # With deflate-stream the whole frame, header included, is
        # compressed and nothing is stripped from the flushed output.
        expected = compress.compress('\x81\x05Hello')
        expected += compress.flush(zlib.Z_SYNC_FLUSH)
        self.assertEqual(expected, request.connection.written_data())

    def test_send_message_deflate_frame(self):
        compress = _create_raw_deflater()

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        msgutil.send_message(request, 'Hello')
        msgutil.send_message(request, 'World')

        expected = ''

        # Both payloads go through the same compressor, i.e. the second
        # frame is expected to reuse the context of the first one.
        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        expected += '\xc1%c' % len(compressed_hello)
        expected += compressed_hello

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        expected += '\xc1%c' % len(compressed_world)
        expected += compressed_world

        self.assertEqual(expected, request.connection.written_data())

    def test_send_message_deflate_frame_comp_bit(self):
        compress = _create_raw_deflater()

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        self.assertEqual(1, len(request.ws_extension_processors))
        deflate_frame_processor = request.ws_extension_processors[0]
        msgutil.send_message(request, 'Hello')
        deflate_frame_processor.disable_outgoing_compression()
        msgutil.send_message(request, 'Hello')
        deflate_frame_processor.enable_outgoing_compression()
        msgutil.send_message(request, 'Hello')

        expected = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        expected += '\xc1%c' % len(compressed_hello)
        expected += compressed_hello

        # The frame sent while compression was disabled is a plain text
        # frame and does not touch the compression context.
        expected += '\x81\x05Hello'

        compressed_2nd_hello = _deflate_with_sync_flush(compress, 'Hello')
        expected += '\xc1%c' % len(compressed_2nd_hello)
        expected += compressed_2nd_hello

        self.assertEqual(expected, request.connection.written_data())

    def test_send_message_deflate_frame_no_context_takeover_parameter(self):
        compress = _create_raw_deflater()

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        extension.add_parameter('no_context_takeover', None)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        for _ in xrange(3):
            msgutil.send_message(request, 'Hello')

        # All three frames are expected to be byte-identical to a frame
        # compressed from a fresh context.
        compressed_message = _deflate_with_sync_flush(compress, 'Hello')
        expected = '\xc1%c' % len(compressed_message)
        expected += compressed_message

        self.assertEqual(
            expected + expected + expected, request.connection.written_data())

    def test_deflate_frame_bad_request_parameters(self):
        """Tests that if there's anything wrong with deflate-frame extension
        request, deflate-frame is rejected.
        """

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        # max_window_bits less than 8 is illegal.
        extension.add_parameter('max_window_bits', '7')
        processor = DeflateFrameExtensionProcessor(extension)
        self.assertEqual(None, processor.get_extension_response())

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        # max_window_bits greater than 15 is illegal.
        extension.add_parameter('max_window_bits', '16')
        processor = DeflateFrameExtensionProcessor(extension)
        self.assertEqual(None, processor.get_extension_response())

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        # Non integer max_window_bits is illegal.
        extension.add_parameter('max_window_bits', 'foobar')
        processor = DeflateFrameExtensionProcessor(extension)
        self.assertEqual(None, processor.get_extension_response())

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        # no_context_takeover must not have any value.
        extension.add_parameter('no_context_takeover', 'foobar')
        processor = DeflateFrameExtensionProcessor(extension)
        self.assertEqual(None, processor.get_extension_response())

    def test_deflate_frame_response_parameters(self):
        # Values configured on the processor must show up as parameters of
        # the extension response.
        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        processor = DeflateFrameExtensionProcessor(extension)
        processor.set_response_window_bits(8)
        response = processor.get_extension_response()
        self.assertTrue(response.has_parameter('max_window_bits'))
        self.assertEqual('8', response.get_parameter_value('max_window_bits'))

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        processor = DeflateFrameExtensionProcessor(extension)
        processor.set_response_no_context_takeover(True)
        response = processor.get_extension_response()
        self.assertTrue(response.has_parameter('no_context_takeover'))
        self.assertTrue(
            response.get_parameter_value('no_context_takeover') is None)

    def test_send_message_perframe_compress_deflate(self):
        compress = _create_raw_deflater()
        extension = common.ExtensionParameter(
            common.PERFRAME_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            '', perframe_compression_request=extension)
        msgutil.send_message(request, 'Hello')
        msgutil.send_message(request, 'World')

        expected = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        expected += '\xc1%c' % len(compressed_hello)
        expected += compressed_hello

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        expected += '\xc1%c' % len(compressed_world)
        expected += compressed_world

        self.assertEqual(expected, request.connection.written_data())

    def test_send_message_permessage_compress_deflate(self):
        compress = _create_raw_deflater()
        extension = common.ExtensionParameter(
            common.PERMESSAGE_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            '', permessage_compression_request=extension)
        msgutil.send_message(request, 'Hello')
        msgutil.send_message(request, 'World')

        expected = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        expected += '\xc1%c' % len(compressed_hello)
        expected += compressed_hello

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        expected += '\xc1%c' % len(compressed_world)
        expected += compressed_world

        self.assertEqual(expected, request.connection.written_data())

    def test_receive_message(self):
        request = _create_request(
            ('\x81\x85', 'Hello'), ('\x81\x86', 'World!'))
        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertEqual('World!', msgutil.receive_message(request))

        payload = 'a' * 125
        request = _create_request(('\x81\xfd', payload))
        self.assertEqual(payload, msgutil.receive_message(request))

    def test_receive_medium_message(self):
        payload = 'a' * 126
        request = _create_request(('\x81\xfe\x00\x7e', payload))
        self.assertEqual(payload, msgutil.receive_message(request))

        payload = 'a' * ((1 << 16) - 1)
        request = _create_request(('\x81\xfe\xff\xff', payload))
        self.assertEqual(payload, msgutil.receive_message(request))

    def test_receive_large_message(self):
        payload = 'a' * (1 << 16)
        request = _create_request(
            ('\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00', payload))
        self.assertEqual(payload, msgutil.receive_message(request))

    def test_receive_length_not_encoded_using_minimal_number_of_bytes(self):
        # Log warning on receiving bad payload length field that doesn't use
        # minimal number of bytes but continue processing.

        payload = 'a'
        # 1 byte can be represented without extended payload length field.
        request = _create_request(
            ('\x81\xff\x00\x00\x00\x00\x00\x00\x00\x01', payload))
        self.assertEqual(payload, msgutil.receive_message(request))

    def test_receive_message_unicode(self):
        request = _create_request(('\x81\x83', '\xe6\x9c\xac'))
        # U+672c is encoded as e6,9c,ac in UTF-8
        self.assertEqual(u'\u672c', msgutil.receive_message(request))

    def test_receive_message_erroneous_unicode(self):
        # \x80 and \x81 are invalid as UTF-8.
        request = _create_request(('\x81\x82', '\x80\x81'))
        # Invalid characters should raise InvalidUTF8Exception
        self.assertRaises(InvalidUTF8Exception,
                          msgutil.receive_message,
                          request)

    def test_receive_fragments(self):
        # receive_message returns only once the final fragment has arrived
        # and hands back the concatenated payload.
        request = _create_request(
            ('\x01\x85', 'Hello'),
            ('\x00\x81', ' '),
            ('\x00\x85', 'World'),
            ('\x80\x81', '!'))
        self.assertEqual('Hello World!', msgutil.receive_message(request))

    def test_receive_fragments_unicode(self):
        # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97.
        # The fragment boundaries deliberately fall inside the characters.
        request = _create_request(
            ('\x01\x82', '\xe6\xbc'),
            ('\x00\x82', '\xa2\xe5'),
            ('\x80\x82', '\xad\x97'))
        self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request))

    def test_receive_fragments_immediate_zero_termination(self):
        request = _create_request(
            ('\x01\x8c', 'Hello World!'), ('\x80\x80', ''))
        self.assertEqual('Hello World!', msgutil.receive_message(request))

    def test_receive_fragments_duplicate_start(self):
        # A second first-fragment before the message was finished.
        request = _create_request(
            ('\x01\x85', 'Hello'), ('\x01\x85', 'World'))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)

    def test_receive_fragments_intermediate_but_not_started(self):
        request = _create_request(('\x00\x85', 'Hello'))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)

    def test_receive_fragments_end_but_not_started(self):
        request = _create_request(('\x80\x85', 'Hello'))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)

    def test_receive_message_discard(self):
        # Frames starting with 0x8f raise UnsupportedFrameException, but the
        # stream stays usable for the frames that follow.
        request = _create_request(
            ('\x8f\x86', 'IGNORE'), ('\x81\x85', 'Hello'),
            ('\x8f\x89', 'DISREGARD'), ('\x81\x86', 'World!'))
        self.assertRaises(msgutil.UnsupportedFrameException,
                          msgutil.receive_message, request)
        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertRaises(msgutil.UnsupportedFrameException,
                          msgutil.receive_message, request)
        self.assertEqual('World!', msgutil.receive_message(request))

    def test_receive_close(self):
        # The close payload is a 16 bit status code in network byte order
        # followed by the reason text.
        request = _create_request(
            ('\x88\x8a', struct.pack('!H', 1000) + 'Good bye'))
        self.assertEqual(None, msgutil.receive_message(request))
        self.assertEqual(1000, request.ws_close_code)
        self.assertEqual('Good bye', request.ws_close_reason)

    def test_receive_message_deflate_stream(self):
        compress = _create_raw_deflater()

        data = compress.compress('\x81\x85' + _mask_hybi('Hello'))
        data += compress.flush(zlib.Z_SYNC_FLUSH)
        data += compress.compress('\x81\x89' + _mask_hybi('WebSocket'))
        # Z_FINISH ends the first compressed stream; the remaining frames
        # come from a brand new compressor.
        data += compress.flush(zlib.Z_FINISH)

        compress = _create_raw_deflater()

        data += compress.compress('\x81\x85' + _mask_hybi('World'))
        data += compress.flush(zlib.Z_SYNC_FLUSH)
        # Close frame
        data += compress.compress(
            '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye'))
        data += compress.flush(zlib.Z_SYNC_FLUSH)

        request = _create_request_from_rawdata(data, deflate_stream=True)
        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertEqual('WebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertFalse(request.drain_received_data_called)

        self.assertEqual(None, msgutil.receive_message(request))

        # Receiving the close frame is what triggers draining.
        self.assertTrue(request.drain_received_data_called)

    def test_receive_message_deflate_frame(self):
        compress = _create_raw_deflater()

        data = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        data += '\xc1%c' % (len(compressed_hello) | 0x80)
        data += _mask_hybi(compressed_hello)

        # This frame is finished with Z_FINISH instead of a sync flush, so
        # nothing is stripped and a single zero byte is appended.
        compressed_websocket = compress.compress('WebSocket')
        compressed_websocket += compress.flush(zlib.Z_FINISH)
        compressed_websocket += '\x00'
        data += '\xc1%c' % (len(compressed_websocket) | 0x80)
        data += _mask_hybi(compressed_websocket)

        compress = _create_raw_deflater()

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        data += '\xc1%c' % (len(compressed_world) | 0x80)
        data += _mask_hybi(compressed_world)

        # Close frame
        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertEqual('WebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))

    def test_receive_message_deflate_frame_client_using_smaller_window(self):
        """Test that frames coming from a client which is using smaller window
        size than the server are correctly received.
        """

        # Using the smallest window bits of 8 for generating input frames.
        # NOTE(review): newer zlib releases may refuse an 8 bit window for
        # raw DEFLATE -- confirm on the target platform.
        compress = _create_raw_deflater(8)

        data = ''

        # Use a frame whose content is bigger than the clients' DEFLATE window
        # size before compression. The content mainly consists of 'a' but
        # repetition of 'b' is put at the head and tail so that if the window
        # size is big, the head is back-referenced but if small, not.
        payload = 'b' * 64 + 'a' * 1024 + 'b' * 64
        compressed_hello = _deflate_with_sync_flush(compress, payload)
        data += '\xc1%c' % (len(compressed_hello) | 0x80)
        data += _mask_hybi(compressed_hello)

        # Close frame
        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        self.assertEqual(payload, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))

    def test_receive_message_deflate_frame_comp_bit(self):
        compress = _create_raw_deflater()

        data = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        data += '\xc1%c' % (len(compressed_hello) | 0x80)
        data += _mask_hybi(compressed_hello)

        # An uncompressed text frame in between the two compressed ones.
        data += '\x81\x85' + _mask_hybi('Hello')

        compress = _create_raw_deflater()

        compressed_2nd_hello = _deflate_with_sync_flush(compress, 'Hello')
        data += '\xc1%c' % (len(compressed_2nd_hello) | 0x80)
        data += _mask_hybi(compressed_2nd_hello)

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        for _ in xrange(3):
            self.assertEqual('Hello', msgutil.receive_message(request))

    def test_receive_message_perframe_compression_frame(self):
        compress = _create_raw_deflater()

        data = ''

        compressed_hello = _deflate_with_sync_flush(compress, 'Hello')
        data += '\xc1%c' % (len(compressed_hello) | 0x80)
        data += _mask_hybi(compressed_hello)

        # This frame is finished with Z_FINISH instead of a sync flush, so
        # nothing is stripped and a single zero byte is appended.
        compressed_websocket = compress.compress('WebSocket')
        compressed_websocket += compress.flush(zlib.Z_FINISH)
        compressed_websocket += '\x00'
        data += '\xc1%c' % (len(compressed_websocket) | 0x80)
        data += _mask_hybi(compressed_websocket)

        compress = _create_raw_deflater()

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        data += '\xc1%c' % (len(compressed_world) | 0x80)
        data += _mask_hybi(compressed_world)

        # Close frame
        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')

        extension = common.ExtensionParameter(
            common.PERFRAME_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            data, perframe_compression_request=extension)
        self.assertEqual('Hello', msgutil.receive_message(request))
        self.assertEqual('WebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))

    def test_receive_message_permessage_deflate_compression(self):
        compress = _create_raw_deflater()

        data = ''

        # One compressed message delivered as two fragments: the first one
        # starts with 0x41, the final one with 0x80.
        compressed_hello = _deflate_with_sync_flush(
            compress, 'HelloWebSocket')
        # Floor division: the split position is used as a slice index.
        split_position = len(compressed_hello) // 2
        data += '\x41%c' % (split_position | 0x80)
        data += _mask_hybi(compressed_hello[:split_position])

        data += '\x80%c' % ((len(compressed_hello) - split_position) | 0x80)
        data += _mask_hybi(compressed_hello[split_position:])

        compress = _create_raw_deflater()

        compressed_world = _deflate_with_sync_flush(compress, 'World')
        data += '\xc1%c' % (len(compressed_world) | 0x80)
        data += _mask_hybi(compressed_world)

        # Close frame
        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')

        extension = common.ExtensionParameter(
            common.PERMESSAGE_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            data, permessage_compression_request=extension)
        self.assertEqual('HelloWebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))

    def test_send_longest_close(self):
        # 2 bytes of status code plus 123 bytes of reason make 125 bytes.
        reason = 'a' * 123
        request = _create_request(
            ('\x88\xfd',
             struct.pack('!H', common.STATUS_NORMAL_CLOSURE) + reason))
        request.ws_stream.close_connection(common.STATUS_NORMAL_CLOSURE,
                                           reason)
        self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE)
        self.assertEqual(request.ws_close_reason, reason)

    def test_send_close_too_long(self):
        # One byte more than the reason accepted by test_send_longest_close.
        request = _create_request()
        self.assertRaises(msgutil.BadOperationException,
                          Stream.close_connection,
                          request.ws_stream,
                          common.STATUS_NORMAL_CLOSURE,
                          'a' * 124)

    def test_send_close_inconsistent_code_and_reason(self):
        request = _create_request()
        # reason parameter must not be specified when code is None.
        self.assertRaises(msgutil.BadOperationException,
                          Stream.close_connection,
                          request.ws_stream,
                          None,
                          'a')

    def test_send_ping(self):
        request = _create_request()
        msgutil.send_ping(request, 'Hello World!')
        self.assertEqual('\x89\x0cHello World!',
                         request.connection.written_data())

    def test_send_longest_ping(self):
        request = _create_request()
        msgutil.send_ping(request, 'a' * 125)
        self.assertEqual('\x89\x7d' + 'a' * 125,
                         request.connection.written_data())

    def test_send_ping_too_long(self):
        # One byte more than the body accepted by test_send_longest_ping.
        request = _create_request()
        self.assertRaises(msgutil.BadOperationException,
                          msgutil.send_ping,
                          request,
                          'a' * 126)

    def test_receive_ping(self):
        """Tests receiving a ping control frame."""

        def handler(request, message):
            request.called = True

        # Stream automatically respond to ping with pong without any action
        # by application layer.
        request = _create_request(
            ('\x89\x85', 'Hello'), ('\x81\x85', 'World'))
        self.assertEqual('World', msgutil.receive_message(request))
        self.assertEqual('\x8a\x05Hello',
                         request.connection.written_data())

        # With on_ping_handler set, the handler is invoked for the ping.
        request = _create_request(
            ('\x89\x85', 'Hello'), ('\x81\x85', 'World'))
        request.on_ping_handler = handler
        self.assertEqual('World', msgutil.receive_message(request))
        self.assertTrue(request.called)

    def test_receive_longest_ping(self):
        request = _create_request(
            ('\x89\xfd', 'a' * 125), ('\x81\x85', 'World'))
        self.assertEqual('World', msgutil.receive_message(request))
        self.assertEqual('\x8a\x7d' + 'a' * 125,
                         request.connection.written_data())

    def test_receive_ping_too_long(self):
        request = _create_request(('\x89\xfe\x00\x7e', 'a' * 126))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)

    def test_receive_pong(self):
        """Tests receiving a pong control frame."""

        def handler(request, message):
            request.called = True

        request = _create_request(
            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
        request.on_pong_handler = handler
        msgutil.send_ping(request, 'Hello')
        self.assertEqual('\x89\x05Hello',
                         request.connection.written_data())
        # Valid pong is received, but receive_message won't return for it.
        self.assertEqual('World', msgutil.receive_message(request))
        # Check that nothing was written after receive_message call.
        self.assertEqual('\x89\x05Hello',
                         request.connection.written_data())

        self.assertTrue(request.called)

    def test_receive_unsolicited_pong(self):
        # Unsolicited pong is allowed from HyBi 07.
        # There are no assertions here: the test passes as long as
        # receive_message does not raise.
        request = _create_request(
            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
        msgutil.receive_message(request)

        request = _create_request(
            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
        msgutil.send_ping(request, 'Jumbo')
        # Body mismatch.
        msgutil.receive_message(request)

    def test_ping_cannot_be_fragmented(self):
        # 0x09 is the ping header byte 0x89 with its top bit cleared.
        request = _create_request(('\x09\x85', 'Hello'))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)

    def test_ping_with_too_long_payload(self):
        request = _create_request(('\x89\xfe\x01\x00', 'a' * 256))
        self.assertRaises(msgutil.InvalidFrameException,
                          msgutil.receive_message,
                          request)


class MessageTestHixie75(unittest.TestCase):
    """Tests msgutil messaging on requests from _create_request_hixie75.

    Text messages in these tests are framed as a 0x00 byte, the UTF-8
    encoded payload and a terminating 0xff byte.
    """

    def test_send_message(self):
        """Sends a str and checks the framed bytes that were written."""

        request = _create_request_hixie75()
        msgutil.send_message(request, 'Hello')
        written = request.connection.written_data()
        self.assertEqual('\x00Hello\xff', written)

    def test_send_message_unicode(self):
        """Sends a unicode string and checks that it goes out as UTF-8."""

        request = _create_request_hixie75()
        msgutil.send_message(request, u'\u65e5')
        written = request.connection.written_data()
        # U+65e5 is encoded as e6,97,a5 in UTF-8
        self.assertEqual('\x00\xe6\x97\xa5\xff', written)

    def test_receive_message(self):
        """Receives two back-to-back frames in the order they arrive."""

        request = _create_request_hixie75('\x00Hello\xff\x00World!\xff')
        for expected in ('Hello', 'World!'):
            self.assertEqual(expected, msgutil.receive_message(request))

    def test_receive_message_unicode(self):
        """Receives a UTF-8 payload and expects it decoded to unicode."""

        # U+672c is encoded as e6,9c,ac in UTF-8
        request = _create_request_hixie75('\x00\xe6\x9c\xac\xff')
        decoded = msgutil.receive_message(request)
        self.assertEqual(u'\u672c', decoded)

    def test_receive_message_erroneous_unicode(self):
        """Receives bytes that are not valid UTF-8."""

        # \x80 and \x81 are invalid as UTF-8.
        request = _create_request_hixie75('\x00\x80\x81\xff')
        decoded = msgutil.receive_message(request)
        # Invalid characters should be replaced with
        # U+fffd REPLACEMENT CHARACTER
        self.assertEqual(u'\ufffd\ufffd', decoded)

    def test_receive_message_discard(self):
        """Checks that only the frames starting with 0x00 are returned.

        The input also holds one frame starting with 0x80 and one starting
        with 0x01; neither of them may be handed to the caller.
        """

        incoming = ('\x80\x06IGNORE\x00Hello\xff'
                    '\x01DISREGARD\xff\x00World!\xff')
        request = _create_request_hixie75(incoming)
        for expected in ('Hello', 'World!'):
            self.assertEqual(expected, msgutil.receive_message(request))


class MessageReceiverTest(unittest.TestCase):
    """Tests the Stream class using MessageReceiver."""

    # Upper bound, in seconds, on waiting for the onmessage callback.
    _QUEUE_TIMEOUT_SEC = 10

    def test_queue(self):
        """Polls an idle receiver, then receives one masked text frame."""

        request = _create_blocking_request()
        receiver = msgutil.MessageReceiver(request)

        # No bytes have been fed to the connection yet.
        self.assertEqual(None, receiver.receive_nowait())

        request.connection.put_bytes('\x81\x86' + _mask_hybi('Hello!'))
        self.assertEqual('Hello!', receiver.receive())

    def test_onmessage(self):
        """Checks that a received message reaches the onmessage handler."""

        onmessage_queue = Queue.Queue()

        def onmessage_handler(message):
            onmessage_queue.put(message)

        request = _create_blocking_request()
        receiver = msgutil.MessageReceiver(request, onmessage_handler)

        request.connection.put_bytes('\x81\x86' + _mask_hybi('Hello!'))
        # Bound the wait: if the handler is never called, get() raises
        # Queue.Empty and the test errors out instead of hanging forever.
        self.assertEqual(
            'Hello!',
            onmessage_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))


class MessageReceiverHixie75Test(unittest.TestCase):
    """Tests the StreamHixie75 class using MessageReceiver."""

    # Upper bound, in seconds, on waiting for the onmessage callback.
    _QUEUE_TIMEOUT_SEC = 10

    def test_queue(self):
        """Polls an idle receiver, then receives one 0x00...0xff frame."""

        request = _create_blocking_request_hixie75()
        receiver = msgutil.MessageReceiver(request)

        # No bytes have been fed to the connection yet.
        self.assertEqual(None, receiver.receive_nowait())

        request.connection.put_bytes('\x00Hello!\xff')
        self.assertEqual('Hello!', receiver.receive())

    def test_onmessage(self):
        """Checks that a received message reaches the onmessage handler."""

        onmessage_queue = Queue.Queue()

        def onmessage_handler(message):
            onmessage_queue.put(message)

        request = _create_blocking_request_hixie75()
        receiver = msgutil.MessageReceiver(request, onmessage_handler)

        request.connection.put_bytes('\x00Hello!\xff')
        # Bound the wait: if the handler is never called, get() raises
        # Queue.Empty and the test errors out instead of hanging forever.
        self.assertEqual(
            'Hello!',
            onmessage_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))


class MessageSenderTest(unittest.TestCase):
    """Tests the Stream class using MessageSender."""

    # Upper bound, in seconds, on waiting for the sender to write a frame.
    _QUEUE_TIMEOUT_SEC = 10

    def test_send(self):
        """Sends one message with send() and checks the written frame."""

        request = _create_blocking_request()
        sender = msgutil.MessageSender(request)

        sender.send('World')
        self.assertEqual('\x81\x05World', request.connection.written_data())

    def test_send_nowait(self):
        """Queues two messages with send_nowait() and checks both frames.

        The frames must be written in the order the messages were queued.
        """

        # Use a queue to check the bytes written by MessageSender.
        # request.connection.written_data() cannot be used here because
        # MessageSender runs in a separate thread.
        send_queue = Queue.Queue()

        # The parameter is deliberately not called "bytes" so that the
        # builtin of that name is not shadowed.
        def write(data):
            send_queue.put(data)

        request = _create_blocking_request()
        request.connection.write = write

        sender = msgutil.MessageSender(request)

        sender.send_nowait('Hello')
        sender.send_nowait('World')
        # Bound each wait: if nothing is written, get() raises Queue.Empty
        # and the test errors out instead of hanging forever.
        self.assertEqual(
            '\x81\x05Hello',
            send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
        self.assertEqual(
            '\x81\x05World',
            send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))


class MessageSenderHixie75Test(unittest.TestCase):
    """Tests the StreamHixie75 class using MessageSender."""

    # Upper bound, in seconds, on waiting for the sender to write a frame.
    _QUEUE_TIMEOUT_SEC = 10

    def test_send(self):
        """Sends one message with send() and checks the written frame."""

        request = _create_blocking_request_hixie75()
        sender = msgutil.MessageSender(request)

        sender.send('World')
        self.assertEqual('\x00World\xff', request.connection.written_data())

    def test_send_nowait(self):
        """Queues two messages with send_nowait() and checks both frames.

        The frames must be written in the order the messages were queued.
        """

        # Use a queue to check the bytes written by MessageSender.
        # request.connection.written_data() cannot be used here because
        # MessageSender runs in a separate thread.
        send_queue = Queue.Queue()

        # The parameter is deliberately not called "bytes" so that the
        # builtin of that name is not shadowed.
        def write(data):
            send_queue.put(data)

        request = _create_blocking_request_hixie75()
        request.connection.write = write

        sender = msgutil.MessageSender(request)

        sender.send_nowait('Hello')
        sender.send_nowait('World')
        # Bound each wait: if nothing is written, get() raises Queue.Empty
        # and the test errors out instead of hanging forever.
        self.assertEqual(
            '\x00Hello\xff',
            send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
        self.assertEqual(
            '\x00World\xff',
            send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))


# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()


# vi:sts=4 sw=4 et