#!/usr/bin/env python
#
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for mux module."""
import Queue
import copy
import logging
import optparse
import struct
import sys
import unittest
import time
import zlib
import set_sys_path # Update sys.path to locate mod_pywebsocket module.
from mod_pywebsocket import common
from mod_pywebsocket import mux
from mod_pywebsocket._stream_base import ConnectionTerminatedException
from mod_pywebsocket._stream_base import UnsupportedFrameException
from mod_pywebsocket._stream_hybi import Frame
from mod_pywebsocket._stream_hybi import Stream
from mod_pywebsocket._stream_hybi import StreamOptions
from mod_pywebsocket._stream_hybi import create_binary_frame
from mod_pywebsocket._stream_hybi import create_close_frame
from mod_pywebsocket._stream_hybi import create_closing_handshake_body
from mod_pywebsocket._stream_hybi import parse_frame
from mod_pywebsocket.extensions import MuxExtensionProcessor
import mock
# Request headers handed to mock.MockRequest (as headers_in) by
# _create_mock_request().
_TEST_HEADERS = {'Host': 'server.example.com',
                 'Upgrade': 'websocket',
                 'Connection': 'Upgrade',
                 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
                 'Sec-WebSocket-Version': '13',
                 'Origin': 'http://example.com'}
class _OutgoingChannelData(object):
    """Per-channel record kept by _MockMuxConnection.

    Stores the data messages and control messages written to one logical
    channel, together with the builder that assembles inner frames into
    messages.
    """

    def __init__(self):
        self.builder = mux._InnerMessageBuilder()
        self.control_messages = []
        self.messages = []
class _MockMuxConnection(mock.MockBlockingConn):
    """Mock class of mod_python connection for mux.

    Every chunk handed to write() is parsed as a physical frame; the control
    blocks and per-channel messages found in it are recorded so that tests
    can inspect them through the get_written_*() accessors.
    """

    def __init__(self):
        mock.MockBlockingConn.__init__(self)
        self.server_close_code = None
        self._pending_fragments = []
        self._current_opcode = None
        self._channel_data = {}
        self._control_blocks = []

    def write(self, data):
        """Override MockBlockingConn.write."""

        self._current_data = data
        self._position = 0

        def _read_exactly(size):
            # Hand out the next `size` bytes of the chunk being written.
            end = self._position + size
            if end > len(self._current_data):
                raise ConnectionTerminatedException(
                    'Failed to receive %d bytes from encapsulated '
                    'frame' % size)
            chunk = self._current_data[self._position:end]
            self._position = end
            return chunk

        # Parse one physical frame; fragments are accumulated until the
        # frame carrying FIN arrives.
        opcode, payload, fin, rsv1, rsv2, rsv3 = parse_frame(
            _read_exactly, unmask_receive=False)
        self._pending_fragments.append(payload)

        is_continuation = (opcode == common.OPCODE_CONTINUATION)
        if self._current_opcode is None:
            if is_continuation:
                raise Exception('Sending invalid continuation opcode')
            self._current_opcode = opcode
        elif not is_continuation:
            raise Exception('Sending invalid opcode %d' % opcode)

        if not fin:
            return

        physical_message = ''.join(self._pending_fragments)
        self._pending_fragments = []
        self._current_opcode = None

        # Handle a control message on the physical channel.
        # TODO(bashi): Support other opcodes if needed.
        if opcode == common.OPCODE_CLOSE:
            if len(payload) >= 2:
                (self.server_close_code,) = struct.unpack('!H', payload[:2])
            # Answer the server's close with a close frame of our own.
            reply_body = create_closing_handshake_body(
                common.STATUS_NORMAL_CLOSURE, '')
            self.put_bytes(create_close_frame(reply_body, mask=True))
            return

        # Parse the payload of the message on physical channel.
        parser = mux._MuxFramePayloadParser(physical_message)
        channel_id = parser.read_channel_id()
        if channel_id == mux._CONTROL_CHANNEL_ID:
            self._control_blocks.extend(list(parser.read_control_blocks()))
            return

        if channel_id not in self._channel_data:
            self._channel_data[channel_id] = _OutgoingChannelData()
        channel = self._channel_data[channel_id]

        # Parse the logical frame and feed it to the channel's builder,
        # which returns None until a whole inner message is available.
        (logical_fin, logical_rsv1, logical_rsv2, logical_rsv3,
         logical_opcode, logical_payload) = parser.read_inner_frame()
        message = channel.builder.build(
            Frame(logical_fin, logical_rsv1, logical_rsv2, logical_rsv3,
                  logical_opcode, logical_payload))
        if message is None:
            return

        if message.opcode in (common.OPCODE_TEXT, common.OPCODE_BINARY):
            channel.messages.append(message.payload)
            self.on_data_message(message.payload)
            return

        channel.control_messages.append(
            {'opcode': message.opcode,
             'message': message.payload})

    def on_data_message(self, message):
        # Hook called for every complete data message; a no-op here.
        pass

    def get_written_control_blocks(self):
        return self._control_blocks

    def get_written_messages(self, channel_id):
        return self._channel_data[channel_id].messages

    def get_written_control_messages(self, channel_id):
        return self._channel_data[channel_id].control_messages
class _FailOnWriteConnection(_MockMuxConnection):
    """Specialized version of _MockMuxConnection for testing write failures.

    Its write() method raises an exception whenever a data message is
    written, because the on_data_message() hook is overridden to raise.
    """

    def on_data_message(self, message):
        """Override to raise an exception."""

        raise Exception('Intentional failure')
class _ChannelEvent(object):
"""A structure that records channel events."""
def __init__(self):
self.request = None
self.messages = []
self.exception = None
self.client_initiated_closing = False
class _MuxMockDispatcher(object):
"""Mock class of dispatch.Dispatcher for mux."""
def __init__(self):
self.channel_events = {}
def do_extra_handshake(self, request):
if request.ws_requested_protocols is not None:
request.ws_protocol = request.ws_requested_protocols[0]
def _do_echo(self, request, channel_events):
while True:
message = request.ws_stream.receive_message()
if message == None:
channel_events.client_initiated_closing = True
return
if message == 'Goodbye':
return
channel_events.messages.append(message)
# echo back
request.ws_stream.send_message(message)
def _do_ping(self, request, channel_events):
request.ws_stream.send_ping('Ping!')
def _do_ping_while_hello_world(self, request, channel_events):
request.ws_stream.send_message('Hello ', end=False)
request.ws_stream.send_ping('Ping!')
request.ws_stream.send_message('World!', end=True)
def _do_two_ping_while_hello_world(self, request, channel_events):
request.ws_stream.send_message('Hello ', end=False)
request.ws_stream.send_ping('Ping!')
request.ws_stream.send_ping('Pong!')
request.ws_stream.send_message('World!', end=True)
def transfer_data(self, request):
self.channel_events[request.channel_id] = _ChannelEvent()
self.channel_events[request.channel_id].request = request
try:
# Note: more handler will be added.
if request.uri.endswith('echo'):
self._do_echo(request,
self.channel_events[request.channel_id])
elif request.uri.endswith('ping'):
self._do_ping(request,
self.channel_events[request.channel_id])
elif request.uri.endswith('two_ping_while_hello_world'):
self._do_two_ping_while_hello_world(
request, self.channel_events[request.channel_id])
elif request.uri.endswith('ping_while_hello_world'):
self._do_ping_while_hello_world(
request, self.channel_events[request.channel_id])
else:
raise ValueError('Cannot handle path %r' % request.path)
if not request.server_terminated:
request.ws_stream.close_connection()
except ConnectionTerminatedException, e:
self.channel_events[request.channel_id].exception = e
except Exception, e:
self.channel_events[request.channel_id].exception = e
raise
def _create_mock_request(connection=None, logical_channel_extensions=None):
    """Build a mock request for '/echo' with a stream and a mux processor.

    Args:
        connection: connection to attach; a fresh _MockMuxConnection is
            created when None.
        logical_channel_extensions: when not None, passed to the mux
            processor's set_extensions().

    Returns:
        The mock.MockRequest, with ws_stream and mux_processor set and the
        processor's quota set to 8 * 1024.
    """
    if connection is None:
        connection = _MockMuxConnection()

    request = mock.MockRequest(
        uri='/echo', headers_in=_TEST_HEADERS, connection=connection)
    request.ws_stream = Stream(request, options=StreamOptions())

    processor = MuxExtensionProcessor(
        common.ExtensionParameter(common.MUX_EXTENSION))
    request.mux_processor = processor
    if logical_channel_extensions is not None:
        processor.set_extensions(logical_channel_extensions)
    processor.set_quota(8 * 1024)
    return request
def _create_add_channel_request_frame(channel_id, encoding, encoded_handshake):
    """Return a masked binary frame carrying one AddChannelRequest block.

    The encoding value is OR-ed into the first byte without validation, so
    that tests can send invalid encodings.
    """
    control_block = ''.join([
        chr((mux._MUX_OPCODE_ADD_CHANNEL_REQUEST << 5) | encoding),
        mux._encode_channel_id(channel_id),
        mux._encode_number(len(encoded_handshake)),
        encoded_handshake])
    control_channel = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID)
    return create_binary_frame(control_channel + control_block, mask=True)
def _create_drop_channel_frame(channel_id, code=None, message=''):
    """Return a masked binary frame carrying one DropChannel block."""
    control_channel = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID)
    drop_block = mux._create_drop_channel(channel_id, code, message)
    return create_binary_frame(control_channel + drop_block, mask=True)
def _create_flow_control_frame(channel_id, replenished_quota):
    """Return a masked binary frame carrying one FlowControl block."""
    control_channel = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID)
    flow_block = mux._create_flow_control(channel_id, replenished_quota)
    return create_binary_frame(control_channel + flow_block, mask=True)
def _create_logical_frame(channel_id, message, opcode=common.OPCODE_BINARY,
                          fin=True, rsv1=False, rsv2=False, rsv3=False,
                          mask=True):
    """Return a physical binary frame encapsulating one logical frame.

    The physical payload is the encoded channel id, one byte packing
    fin/rsv1/rsv2/rsv3 and the opcode, and then the message.

    Args:
        channel_id: id of the logical channel the frame belongs to.
        message: payload of the logical frame.
        opcode: opcode of the logical frame.
        fin, rsv1, rsv2, rsv3: flag bits of the logical frame.
        mask: whether the physical frame is masked. This used to be
            ignored (the frame was always masked); it is now forwarded to
            create_binary_frame(). The default keeps the old behaviour.
    """
    first_byte = chr(
        (fin << 7) | (rsv1 << 6) | (rsv2 << 5) | (rsv3 << 4) | opcode)
    payload = mux._encode_channel_id(channel_id) + first_byte + message
    return create_binary_frame(payload, mask=mask)
def _create_request_header(path='/echo', extensions=None):
headers = (
'GET %s HTTP/1.1\r\n'
'Host: server.example.com\r\n'
'Connection: Upgrade\r\n'
'Origin: http://example.com\r\n') % path
if extensions:
headers += '%s: %s' % (
common.SEC_WEBSOCKET_EXTENSIONS_HEADER, extensions)
return headers
class MuxTest(unittest.TestCase):
    """A unittest for mux module."""

    def test_channel_id_decode(self):
        # Five channel ids encoded back to back in 1, 1, 2, 3 and 4 bytes.
        data = '\x00\x01\xbf\xff\xdf\xff\xff\xff\xff\xff\xff'
        parser = mux._MuxFramePayloadParser(data)
        for expected_id in (0, 1, 2 ** 14 - 1, 2 ** 21 - 1, 2 ** 29 - 1):
            self.assertEqual(expected_id, parser.read_channel_id())
        # Everything must have been consumed.
        self.assertEqual(len(data), parser._read_position)

    def test_channel_id_encode(self):
        expectations = (
            (0, '\x00'),
            (2 ** 14 - 1, '\xbf\xff'),
            (2 ** 14, '\xc0@\x00'),
            (2 ** 21 - 1, '\xdf\xff\xff'),
            (2 ** 21, '\xe0 \x00\x00'),
            (2 ** 29 - 1, '\xff\xff\xff\xff'),
        )
        for channel_id, expected in expectations:
            self.assertEqual(expected, mux._encode_channel_id(channel_id))

        # channel_id is too large
        self.assertRaises(ValueError, mux._encode_channel_id, 2 ** 29)

    def test_read_multiple_control_blocks(self):
        # Use AddChannelRequest because it can contain arbitrary length of
        # data. The four blocks exercise every size of the length field.
        data = ('\x00\x01\x01a'
                '\x00\x02\x7d%s'
                '\x00\x03\x7e\xff\xff%s'
                '\x00\x04\x7f\x00\x00\x00\x00\x00\x01\x00\x00%s') % (
            'a' * 0x7d, 'b' * 0xffff, 'c' * 0x10000)
        parser = mux._MuxFramePayloadParser(data)
        blocks = list(parser.read_control_blocks())
        self.assertEqual(4, len(blocks))

        handshake_lengths = (1, 0x7d, 0xffff, 0x10000)
        for index, handshake_length in enumerate(handshake_lengths):
            block = blocks[index]
            self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST,
                             block.opcode)
            # Channel ids are 1 to 4, in order.
            self.assertEqual(index + 1, block.channel_id)
            self.assertEqual(handshake_length, len(block.encoded_handshake))

        self.assertEqual(len(data), parser._read_position)

    def test_read_add_channel_request(self):
        parser = mux._MuxFramePayloadParser('\x00\x01\x01a')
        block = list(parser.read_control_blocks())[0]
        self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, block.opcode)
        self.assertEqual(1, block.channel_id)
        self.assertEqual(1, len(block.encoded_handshake))

    def test_read_drop_channel(self):
        # Without code and message.
        parser = mux._MuxFramePayloadParser('\x60\x01\x00')
        blocks = list(parser.read_control_blocks())
        self.assertEqual(1, len(blocks))
        block = blocks[0]
        self.assertEqual(1, block.channel_id)
        self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, block.opcode)
        self.assertEqual(None, block.drop_code)
        self.assertEqual(0, len(block.drop_message))

        # With code 1000 and a message.
        parser = mux._MuxFramePayloadParser('\x60\x02\x09\x03\xe8Success')
        blocks = list(parser.read_control_blocks())
        self.assertEqual(1, len(blocks))
        block = blocks[0]
        self.assertEqual(2, block.channel_id)
        self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, block.opcode)
        self.assertEqual(1000, block.drop_code)
        self.assertEqual('Success', block.drop_message)

        # Reason is too short.
        parser = mux._MuxFramePayloadParser('\x60\x01\x01\x00')
        self.assertRaises(mux.PhysicalConnectionError,
                          lambda: list(parser.read_control_blocks()))

    def test_read_flow_control(self):
        parser = mux._MuxFramePayloadParser('\x40\x01\x02')
        blocks = list(parser.read_control_blocks())
        self.assertEqual(1, len(blocks))
        block = blocks[0]
        self.assertEqual(1, block.channel_id)
        self.assertEqual(mux._MUX_OPCODE_FLOW_CONTROL, block.opcode)
        self.assertEqual(2, block.send_quota)

    def test_read_new_channel_slot(self):
        parser = mux._MuxFramePayloadParser('\x80\x01\x02\x02\x03')
        # TODO(bashi): Implement
        self.assertRaises(mux.PhysicalConnectionError,
                          lambda: list(parser.read_control_blocks()))

    def test_read_invalid_number_field_in_control_block(self):
        invalid_number_fields = (
            # No number field.
            '',
            # The last two bytes are missing.
            '\x7e',
            # Missing the last one byte.
            '\x7f\x00\x00\x00\x00\x00\x01\x00',
            # The length of number field is too large.
            '\x7f\xff\xff\xff\xff\xff\xff\xff\xff',
            # The msb of the first byte is set.
            '\x80',
            # Using 3 bytes encoding for 125.
            '\x7e\x00\x7d',
            # Using 9 bytes encoding for 0xffff
            '\x7f\x00\x00\x00\x00\x00\x00\xff\xff',
        )
        for data in invalid_number_fields:
            parser = mux._MuxFramePayloadParser(data)
            self.assertRaises(ValueError, parser._read_number)

    def test_read_invalid_size_and_contents(self):
        # Only contain number field.
        parser = mux._MuxFramePayloadParser('\x01')
        self.assertRaises(mux.PhysicalConnectionError,
                          parser._read_size_and_contents)

    def test_create_add_channel_response(self):
        accepted = mux._create_add_channel_response(
            channel_id=1, encoded_handshake='FooBar', encoding=0,
            rejected=False)
        self.assertEqual('\x20\x01\x06FooBar', accepted)

        rejected = mux._create_add_channel_response(
            channel_id=2, encoded_handshake='Hello', encoding=1,
            rejected=True)
        self.assertEqual('\x31\x02\x05Hello', rejected)

    def test_create_drop_channel(self):
        without_reason = mux._create_drop_channel(channel_id=1)
        self.assertEqual('\x60\x01\x00', without_reason)

        with_reason = mux._create_drop_channel(
            channel_id=1, code=2000, message='error')
        self.assertEqual('\x60\x01\x07\x07\xd0error', with_reason)

        # reason must be empty if code is None
        self.assertRaises(ValueError,
                          mux._create_drop_channel,
                          1, None, 'FooBar')

    def test_parse_request_text(self):
        command, path, version, headers = mux._parse_request_text(
            _create_request_header())
        self.assertEqual('GET', command)
        self.assertEqual('/echo', path)
        self.assertEqual('HTTP/1.1', version)
        self.assertEqual(3, len(headers))
        self.assertEqual('server.example.com', headers['Host'])
        self.assertEqual('http://example.com', headers['Origin'])
class MuxHandlerTest(unittest.TestCase):
def test_add_channel(self):
    # Opens channels 2 and 3 with identity-encoded handshakes, echoes one
    # message on each and then closes all three channels.
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    mux_handler = mux._MuxHandler(request, dispatcher)
    mux_handler.start()
    mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                                  mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    for channel_id in (2, 3):
        put_bytes(_create_add_channel_request_frame(
            channel_id=channel_id, encoding=0,
            encoded_handshake=_create_request_header(path='/echo')))
        # Quota of 6 for the new channel.
        put_bytes(_create_flow_control_frame(
            channel_id=channel_id, replenished_quota=6))

    incoming = ((2, 'Hello'), (3, 'World'),
                (1, 'Goodbye'), (2, 'Goodbye'), (3, 'Goodbye'))
    for channel_id, text in incoming:
        put_bytes(_create_logical_frame(channel_id=channel_id, message=text))

    self.assertTrue(mux_handler.wait_until_done(timeout=2))

    events = dispatcher.channel_events
    self.assertEqual([], events[1].messages)
    self.assertEqual(['Hello'], events[2].messages)
    self.assertEqual(['World'], events[3].messages)

    # Channels 2 and 3 each echoed exactly one message.
    for channel_id, echoed in ((2, 'Hello'), (3, 'World')):
        messages = request.connection.get_written_messages(channel_id)
        self.assertEqual(1, len(messages))
        self.assertEqual(echoed, messages[0])

    control_blocks = request.connection.get_written_control_blocks()
    # There should be 9 control blocks:
    #   - 1 NewChannelSlot
    #   - 2 AddChannelResponses for channel id 2 and 3
    #   - 6 FlowControls for channel id 1 (initialize), 'Hello', 'World',
    #     and 3 'Goodbye's
    self.assertEqual(9, len(control_blocks))
def test_physical_connection_write_failure(self):
    # Use _FailOnWriteConnection.
    request = _create_mock_request(connection=_FailOnWriteConnection())
    handler = mux._MuxHandler(request, _MuxMockDispatcher())
    handler.start()

    # 'Hello' lets the worker echo it back, which makes
    # _FailOnWriteConnection raise an exception.
    # 'Goodbye' lets the worker exit. This will be unnecessary when
    # _LogicalConnection.write() is changed to throw an exception if
    # woke up by on_writer_done.
    for text in ('Hello', 'Goodbye'):
        request.connection.put_bytes(
            _create_logical_frame(channel_id=1, message=text))

    # All threads should be done.
    self.assertTrue(handler.wait_until_done(timeout=2))
def test_send_blocked(self):
    request = _create_mock_request()
    handler = mux._MuxHandler(request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    # Channel 2 is opened but no send quota is ever granted for it.
    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))

    # On receiving this 'Hello', the server tries to echo back 'Hello',
    # but it will be blocked since there's no send quota available for the
    # channel 2.
    put_bytes(_create_logical_frame(channel_id=2, message='Hello'))

    # Wait until the worker is blocked due to send quota shortage.
    time.sleep(1)

    # Close the channel 2 first: the worker should be notified of the end
    # of writer thread and stop waiting for send quota to be replenished.
    # Then make sure the channel 1 is also closed.
    for channel_id in (2, 1):
        put_bytes(_create_drop_channel_frame(channel_id=channel_id))

    # All threads should be done.
    self.assertTrue(handler.wait_until_done(timeout=2))
def test_add_channel_delta_encoding(self):
    # Opens channel 2 with a delta-encoded (encoding=1) handshake and
    # echoes one message on it.
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=1,
        encoded_handshake='GET /echo HTTP/1.1\r\n\r\n'))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=6))

    for channel_id, text in ((2, 'Hello'), (1, 'Goodbye'), (2, 'Goodbye')):
        put_bytes(_create_logical_frame(channel_id=channel_id, message=text))

    self.assertTrue(handler.wait_until_done(timeout=2))

    self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
    written = request.connection.get_written_messages(2)
    self.assertEqual(1, len(written))
    self.assertEqual('Hello', written[0])
def test_add_channel_delta_encoding_override(self):
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    # Override Sec-WebSocket-Protocol.
    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=1,
        encoded_handshake=('GET /echo HTTP/1.1\r\n'
                           'Sec-WebSocket-Protocol: x-foo\r\n'
                           '\r\n')))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=6))

    for channel_id, text in ((2, 'Hello'), (1, 'Goodbye'), (2, 'Goodbye')):
        put_bytes(_create_logical_frame(channel_id=channel_id, message=text))

    self.assertTrue(handler.wait_until_done(timeout=2))

    channel_two = dispatcher.channel_events[2]
    self.assertEqual(['Hello'], channel_two.messages)
    written = request.connection.get_written_messages(2)
    self.assertEqual(1, len(written))
    self.assertEqual('Hello', written[0])
    # The protocol given in the delta is the one that was selected.
    self.assertEqual('x-foo',
                     dispatcher.channel_events[2].request.ws_protocol)
def test_add_channel_delta_after_identity(self):
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    # Sec-WebSocket-Protocol is different from client's opening handshake
    # of the physical connection.
    # TODO(bashi): Remove Upgrade, Connection, Sec-WebSocket-Key and
    # Sec-WebSocket-Version.
    identity_handshake = (
        'GET /echo HTTP/1.1\r\n'
        'Host: server.example.com\r\n'
        'Sec-WebSocket-Protocol: x-foo\r\n'
        'Connection: Upgrade\r\n'
        'Origin: http://example.com\r\n'
        '\r\n')
    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=identity_handshake))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=6))

    # Channel 3 is then opened with a delta-encoded handshake.
    put_bytes(_create_add_channel_request_frame(
        channel_id=3, encoding=1,
        encoded_handshake='GET /echo HTTP/1.1\r\n\r\n'))
    put_bytes(_create_flow_control_frame(channel_id=3,
                                         replenished_quota=6))

    incoming = ((2, 'Hello'), (3, 'World'),
                (1, 'Goodbye'), (2, 'Goodbye'), (3, 'Goodbye'))
    for channel_id, text in incoming:
        put_bytes(_create_logical_frame(channel_id=channel_id, message=text))

    self.assertTrue(handler.wait_until_done(timeout=2))

    events = dispatcher.channel_events
    self.assertEqual([], events[1].messages)
    self.assertEqual(['Hello'], events[2].messages)
    self.assertEqual(['World'], events[3].messages)

    # Channels 2 and 3 each echoed exactly one message.
    for channel_id, echoed in ((2, 'Hello'), (3, 'World')):
        written = request.connection.get_written_messages(channel_id)
        self.assertEqual(1, len(written))
        self.assertEqual(echoed, written[0])

    # Handshake base should be updated.
    base_headers = handler._handshake_base._headers
    self.assertEqual('x-foo', base_headers['Sec-WebSocket-Protocol'])
def test_add_channel_delta_remove_header(self):
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    # Override handshake delta base.
    base_handshake = (
        'GET /echo HTTP/1.1\r\n'
        'Host: server.example.com\r\n'
        'Sec-WebSocket-Protocol: x-foo\r\n'
        'Connection: Upgrade\r\n'
        'Origin: http://example.com\r\n'
        '\r\n')
    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=base_handshake))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=6))

    # Remove Sec-WebSocket-Protocol header.
    put_bytes(_create_add_channel_request_frame(
        channel_id=3, encoding=1,
        encoded_handshake=('GET /echo HTTP/1.1\r\n'
                           'Sec-WebSocket-Protocol:'
                           '\r\n')))
    put_bytes(_create_flow_control_frame(channel_id=3,
                                         replenished_quota=6))

    incoming = ((2, 'Hello'), (3, 'World'),
                (1, 'Goodbye'), (2, 'Goodbye'), (3, 'Goodbye'))
    for channel_id, text in incoming:
        put_bytes(_create_logical_frame(channel_id=channel_id, message=text))

    self.assertTrue(handler.wait_until_done(timeout=2))

    events = dispatcher.channel_events
    self.assertEqual([], events[1].messages)
    self.assertEqual(['Hello'], events[2].messages)
    self.assertEqual(['World'], events[3].messages)

    # Channels 2 and 3 each echoed exactly one message.
    for channel_id, echoed in ((2, 'Hello'), (3, 'World')):
        written = request.connection.get_written_messages(channel_id)
        self.assertEqual(1, len(written))
        self.assertEqual(echoed, written[0])

    # Channel 2 keeps the protocol; channel 3 has none.
    self.assertEqual('x-foo', events[2].request.ws_protocol)
    self.assertEqual(None, events[3].request.ws_protocol)
def test_add_channel_delta_encoding_permessage_compress(self):
    # Enable permessage compress extension on the implicitly opened channel.
    request = _create_mock_request(
        logical_channel_extensions=common.parse_extensions(
            '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION))
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=1,
        encoded_handshake='GET /echo HTTP/1.1\r\n\r\n'))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=20))

    # Raw deflate of 'Hello' with a sync flush; the last four bytes of the
    # flushed output are dropped.
    deflater = zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
    deflated = deflater.compress('Hello')
    deflated += deflater.flush(zlib.Z_SYNC_FLUSH)
    compressed_hello = deflated[:-4]

    # Send compressed 'Hello' on logical channel 1 and 2.
    for channel_id in (1, 2):
        put_bytes(_create_logical_frame(
            channel_id=channel_id, message=compressed_hello, rsv1=True))
    for channel_id in (1, 2):
        put_bytes(_create_logical_frame(
            channel_id=channel_id, message='Goodbye'))

    self.assertTrue(handler.wait_until_done(timeout=2))

    self.assertEqual(['Hello'], dispatcher.channel_events[1].messages)
    self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)

    # Written 'Hello's should be compressed.
    for channel_id in (1, 2):
        written = request.connection.get_written_messages(channel_id)
        self.assertEqual(1, len(written))
        self.assertEqual(compressed_hello, written[0])
def test_add_channel_delta_encoding_remove_extensions(self):
    # Enable permessage compress extension on the implicitly opened channel.
    request = _create_mock_request(
        logical_channel_extensions=common.parse_extensions(
            '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION))
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    put_bytes = request.connection.put_bytes

    # Remove permessage compress extension.
    put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=1,
        encoded_handshake=('GET /echo HTTP/1.1\r\n'
                           'Sec-WebSocket-Extensions:\r\n'
                           '\r\n')))
    put_bytes(_create_flow_control_frame(channel_id=2,
                                         replenished_quota=20))

    # Send compressed message on logical channel 2. The message should
    # be rejected (since rsv1 is set).
    deflater = zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
    deflated = deflater.compress('Hello')
    deflated += deflater.flush(zlib.Z_SYNC_FLUSH)
    compressed_hello = deflated[:-4]
    put_bytes(_create_logical_frame(
        channel_id=2, message=compressed_hello, rsv1=True))
    put_bytes(_create_logical_frame(channel_id=1, message='Goodbye'))

    self.assertTrue(handler.wait_until_done(timeout=2))

    # The first DropChannel written by the server must be for channel 2.
    drop_block = next(
        block for block in request.connection.get_written_control_blocks()
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_NORMAL_CLOSURE, drop_block.drop_code)
    self.assertEqual(2, drop_block.channel_id)

    # UnsupportedFrameException should be raised on logical channel 2.
    raised = dispatcher.channel_events[2].exception
    self.assertTrue(isinstance(raised, UnsupportedFrameException))
def test_add_channel_invalid_encoding(self):
    # An AddChannelRequest using encoding value 3 must produce a
    # DropChannel with the unknown-request-encoding code, and the server
    # must close with STATUS_INTERNAL_ENDPOINT_ERROR.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=3,
        encoded_handshake=_create_request_header(path='/echo')))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_UNKNOWN_REQUEST_ENCODING,
                     drop_channel.drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     mock_request.connection.server_close_code)
def test_add_channel_incomplete_handshake(self):
    # An AddChannelRequest whose handshake is only a request line (no
    # terminating CRLFs) must not create an event record for channel 2;
    # the implicitly opened channel 1 is still recorded.
    mock_request = _create_mock_request()
    mock_dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(mock_request, mock_dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    truncated_handshake = 'GET /echo HTTP/1.1'
    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=truncated_handshake))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    recorded_channels = mock_dispatcher.channel_events
    self.assertTrue(1 in recorded_channels)
    self.assertFalse(2 in recorded_channels)
def test_add_channel_duplicate_channel_id(self):
    # Two AddChannelRequests for the same channel id (2) must produce a
    # DropChannel with the channel-already-exists code, and the server
    # must close with STATUS_INTERNAL_ENDPOINT_ERROR.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    # Send the very same request twice, rebuilding it each time.
    for _ in range(2):
        handshake = _create_request_header(path='/echo')
        connection.put_bytes(_create_add_channel_request_frame(
            channel_id=2, encoding=0, encoded_handshake=handshake))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_CHANNEL_ALREADY_EXISTS,
                     drop_channel.drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     mock_request.connection.server_close_code)
def test_receive_drop_channel(self):
    # A DropChannel received for channel 2 must surface on that channel
    # as an exception whose class is exactly
    # ConnectionTerminatedException.
    mock_request = _create_mock_request()
    mock_dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(mock_request, mock_dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_drop_channel_frame(channel_id=2))
    # Finish the implicitly opened channel as well.
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    raised_type = mock_dispatcher.channel_events[2].exception.__class__
    self.assertTrue(raised_type == ConnectionTerminatedException)
def test_receive_ping_frame(self):
    # A ping received on logical channel 2 must be answered on that
    # channel with a pong carrying the same payload.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=13))
    connection.put_bytes(_create_logical_frame(
        channel_id=2, message='Hello World!', opcode=common.OPCODE_PING))
    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    pong = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PONG, pong['opcode'])
    self.assertEqual('Hello World!', pong['message'])
def test_receive_fragmented_ping(self):
    # A ping split over two frames on channel 2 must be answered with a
    # single pong carrying the reassembled payload.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=13))

    # (payload, fin, opcode) of the two ping fragments, in sending order.
    ping_fragments = (
        ('Hello ', False, common.OPCODE_PING),
        ('World!', True, common.OPCODE_CONTINUATION),
    )
    for payload, is_final, frame_opcode in ping_fragments:
        connection.put_bytes(_create_logical_frame(
            channel_id=2, message=payload, fin=is_final,
            opcode=frame_opcode))

    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    pong = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PONG, pong['opcode'])
    self.assertEqual('Hello World!', pong['message'])
def test_receive_fragmented_ping_while_receiving_fragmented_message(self):
    # A fragmented ping arrives between the two fragments of a data
    # message on channel 2. Both the data message and the pong must be
    # written back intact.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=19))

    # Frames for channel 2 in sending order: first data fragment, the two
    # ping fragments, then the closing data fragment.
    channel2_frames = [
        dict(message='Hello ', fin=False),
        dict(message='Pi', fin=False, opcode=common.OPCODE_PING),
        dict(message='ng!', fin=True, opcode=common.OPCODE_CONTINUATION),
        dict(message='World!', fin=True,
             opcode=common.OPCODE_CONTINUATION),
    ]
    for frame_args in channel2_frames:
        connection.put_bytes(
            _create_logical_frame(channel_id=2, **frame_args))

    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    data_messages = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello World!'], data_messages)
    pong = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PONG, pong['opcode'])
    self.assertEqual('Ping!', pong['message'])
def test_receive_two_ping_while_receiving_fragmented_message(self):
    # A fragmented ping followed by an unfragmented ping arrive between
    # the two fragments of a data message on channel 2. The data message
    # and both pongs (in order) must be written back.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=25))

    # Frames for channel 2 in sending order: first data fragment, the
    # fragmented ping, the unfragmented ping, the closing data fragment.
    channel2_frames = [
        dict(message='Hello ', fin=False),
        dict(message='Pi', fin=False, opcode=common.OPCODE_PING),
        dict(message='ng!', fin=True, opcode=common.OPCODE_CONTINUATION),
        dict(message='Pong!', fin=True, opcode=common.OPCODE_PING),
        dict(message='World!', fin=True,
             opcode=common.OPCODE_CONTINUATION),
    ]
    for frame_args in channel2_frames:
        connection.put_bytes(
            _create_logical_frame(channel_id=2, **frame_args))

    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    data_messages = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello World!'], data_messages)
    pongs = mock_request.connection.get_written_control_messages(2)
    self.assertEqual(common.OPCODE_PONG, pongs[0]['opcode'])
    self.assertEqual('Ping!', pongs[0]['message'])
    self.assertEqual(common.OPCODE_PONG, pongs[1]['opcode'])
    self.assertEqual('Pong!', pongs[1]['message'])
def test_receive_message_while_receiving_fragmented_ping(self):
    # A complete data message arrives on channel 2 while a fragmented
    # ping is still open. Channel 2 must be dropped and nothing (neither
    # data nor control message) may be written on it.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=19))

    # First fragment of the ping.
    connection.put_bytes(_create_logical_frame(
        channel_id=2, message='Pi', fin=False,
        opcode=common.OPCODE_PING))
    # A data message interleaved before the ping is complete; this is
    # what should get logical channel 2 dropped.
    connection.put_bytes(_create_logical_frame(
        channel_id=2, message='Hello world!', fin=True))
    # Final fragment of the ping.
    connection.put_bytes(_create_logical_frame(
        channel_id=2, message='ng!', fin=True,
        opcode=common.OPCODE_CONTINUATION))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(2, drop_channel.channel_id)

    # Looking up anything written on channel 2 must raise KeyError.
    self.assertRaises(
        KeyError, mock_request.connection.get_written_messages, 2)
    self.assertRaises(
        KeyError, mock_request.connection.get_written_control_messages, 2)
def test_send_ping(self):
    # Channel 2 is opened on the '/ping' resource with 6 bytes of send
    # quota; a ping with payload 'Ping!' must be written on channel 2.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/ping')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=6))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    ping = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PING, ping['opcode'])
    self.assertEqual('Ping!', ping['message'])
def test_send_fragmented_ping(self):
    # The send quota for channel 2 is granted in two steps so that the
    # outgoing 'Ping!' frame has to be fragmented; the collected control
    # message must still read as one ping with payload 'Ping!'.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/ping')))

    # First grant: 3 bytes. The ping payload 'Ping!' is 5 bytes long, so
    # this is too little for the whole frame and forces fragmentation.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=3))
    # Give the worker time to block on the exhausted send quota.
    time.sleep(1)
    # Second grant: the remaining 2 bytes plus 1 byte of extra cost.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=3))

    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    ping = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PING, ping['opcode'])
    self.assertEqual('Ping!', ping['message'])
def test_send_fragmented_ping_while_sending_fragmented_message(self):
    # Channel 2 is opened on '/ping_while_hello_world'. The application
    # is expected to send:
    #   - text message 'Hello ' with fin=0
    #   - ping with 'Ping!' message
    #   - text message 'World!' with fin=1
    # The quota is granted in two steps so that the ping is fragmented.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(
            path='/ping_while_hello_world')))

    # First grant: (6 + 1) + (2 + 1) bytes, which cuts the ping in two on
    # the logical channel.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=10))
    time.sleep(1)
    # Second grant: the remaining 3 + 6 bytes.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=9))

    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    data_messages = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello World!'], data_messages)
    ping = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PING, ping['opcode'])
    self.assertEqual('Ping!', ping['message'])
def test_send_fragmented_two_ping_while_sending_fragmented_message(self):
    # Channel 2 is opened on '/two_ping_while_hello_world'. The
    # application is expected to send:
    #   - text message 'Hello ' with fin=0
    #   - ping with 'Ping!' message
    #   - ping with 'Pong!' message
    #   - text message 'World!' with fin=1
    # The quota is granted in two steps so that only the first ping is
    # fragmented.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(
            path='/two_ping_while_hello_world')))

    # First grant: (6 + 1) + (2 + 1) bytes, which cuts the first ping in
    # two on the logical channel.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=10))
    time.sleep(1)
    # Second grant: the remaining 3 + (5 + 1) + 6 bytes, enough for the
    # second ping to go out unfragmented.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=15))

    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    data_messages = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello World!'], data_messages)
    pings = mock_request.connection.get_written_control_messages(2)
    self.assertEqual(common.OPCODE_PING, pings[0]['opcode'])
    self.assertEqual('Ping!', pings[0]['message'])
    self.assertEqual(common.OPCODE_PING, pings[1]['opcode'])
    self.assertEqual('Pong!', pings[1]['message'])
def test_send_drop_channel(self):
    # A reason-less DropChannel for channel id 1 sent by the client must
    # be answered with a DropChannel for channel 1 carrying the
    # acknowledged drop code.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()

    # Raw control-channel payload encoding that DropChannel.
    client_drop = create_binary_frame('\x00\x60\x01\x00', mask=True)
    mock_request.connection.put_bytes(client_drop)
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_ACKNOWLEDGED, drop_channel.drop_code)
    self.assertEqual(1, drop_channel.channel_id)
def test_two_flow_control(self):
    # The quota needed to echo a 10-byte message on channel 2 arrives in
    # two FlowControl blocks. The echo must be written once, and the
    # server must replenish the client's quota for each received message.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    # First grant: 5 bytes.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=5))
    # A 10-byte message; the server will try to echo all 10 bytes back.
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='HelloWorld'))
    # Second grant: 5 bytes + 1 byte of per-message extra cost.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=6))
    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    echoed = mock_request.connection.get_written_messages(2)
    self.assertEqual(['HelloWorld'], echoed)

    channel2_flow_controls = [
        block
        for block in mock_request.connection.get_written_control_blocks()
        if block.opcode == mux._MUX_OPCODE_FLOW_CONTROL
        and block.channel_id == 2]
    # len('HelloWorld') + 1
    self.assertEqual(11, channel2_flow_controls[0].send_quota)
    # len('Goodbye') + 1
    self.assertEqual(8, channel2_flow_controls[1].send_quota)
def test_no_send_quota_on_server(self):
    # No FlowControl is ever sent for channel 2, so the server has no
    # quota to echo 'HelloWorld': the handler must not finish within the
    # timeout and nothing may be written on channel 2.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='HelloWorld'))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))

    # Allow one second for the server to attempt the echo; the wait is
    # expected to time out.
    self.assertFalse(handler.wait_until_done(timeout=1))

    # Looking up messages written on channel 2 must raise KeyError.
    self.assertRaises(
        KeyError, mock_request.connection.get_written_messages, 2)
def test_no_send_quota_on_server_for_permessage_extra_cost(self):
    # The second grant covers only len('World') bytes and not the extra
    # per-message byte, so only the first echo ('Hello') may be written
    # and the handler must not finish within the timeout.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=6))
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='Hello'))
    # Grant exactly len('World') bytes and nothing more.
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=5))
    # The server should not call back for this message.
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='World'))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))

    # Allow one second for the server to attempt echoing 'World'; the
    # wait is expected to time out.
    self.assertFalse(handler.wait_until_done(timeout=1))

    # Exactly one message may have been written on channel 2.
    echoed = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello'], echoed)
def test_quota_violation_by_client(self):
    # Channel slots are added with a send quota of 0, so the client's
    # 'HelloWorld' on channel 2 exceeds its quota. Five control blocks
    # are expected in total, including a DropChannel with the
    # send-quota-violation code.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 0)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='HelloWorld'))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    self.assertEqual(5, len(written_blocks))
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_SEND_QUOTA_VIOLATION,
                     drop_channel.drop_code)
def test_consume_quota_empty_message(self):
    # With a client quota of 1 byte, an empty message on channel 2 is
    # accepted and 1 byte is replenished; the following 'Goodbye' on
    # channel 2 then violates the quota and the channel is dropped.
    mock_request = _create_mock_request()
    mock_dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(mock_request, mock_dispatcher)
    handler.start()
    # The client gets 1 byte of quota.
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 1)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=2))
    # An empty message; pywebsocket always replenishes 1 byte of quota
    # for an empty message.
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message=''))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    # This one exceeds the quota on channel id 2.
    connection.put_bytes(
        _create_logical_frame(channel_id=2, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    channel2_messages = mock_dispatcher.channel_events[2].messages
    self.assertEqual(1, len(channel2_messages))
    self.assertEqual('', mock_dispatcher.channel_events[2].messages[0])

    channel2_flow_controls = [
        block
        for block in mock_request.connection.get_written_control_blocks()
        if block.opcode == mux._MUX_OPCODE_FLOW_CONTROL
        and block.channel_id == 2]
    self.assertEqual(1, len(channel2_flow_controls))
    self.assertEqual(1, channel2_flow_controls[0].send_quota)

    drop_channel = next(
        block
        for block in mock_request.connection.get_written_control_blocks()
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(2, drop_channel.channel_id)
    self.assertEqual(mux._DROP_CODE_SEND_QUOTA_VIOLATION,
                     drop_channel.drop_code)
def test_consume_quota_fragmented_message(self):
    # The client quota is exactly len('Hello') + len('Goodbye') + 2
    # bytes. 'Hello' is sent in two fragments on channel 2 and must be
    # echoed back as a single message.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 14)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))
    connection.put_bytes(_create_flow_control_frame(
        channel_id=2, replenished_quota=6))

    # (payload, fin, opcode) of the two text fragments, in sending order.
    text_fragments = (
        ('He', False, common.OPCODE_TEXT),
        ('llo', True, common.OPCODE_CONTINUATION),
    )
    for payload, is_final, frame_opcode in text_fragments:
        connection.put_bytes(_create_logical_frame(
            channel_id=2, message=payload, fin=is_final,
            opcode=frame_opcode))

    # Ask channel 1 and then channel 2 to finish.
    for closing_channel in (1, 2):
        connection.put_bytes(_create_logical_frame(
            channel_id=closing_channel, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    echoed = mock_request.connection.get_written_messages(2)
    self.assertEqual(['Hello'], echoed)
def test_fragmented_control_message(self):
    # The 6 bytes of send quota for channel 2 ('/ping' resource) arrive
    # in three FlowControl blocks of 1, 2 and 3 bytes; the collected
    # control message must still be one ping with payload 'Ping!'.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/ping')))
    # Grants of 1 + 2 + 3 = 6 bytes, one FlowControl each.
    for grant in (1, 2, 3):
        connection.put_bytes(_create_flow_control_frame(
            channel_id=2, replenished_quota=grant))
    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    ping = mock_request.connection.get_written_control_messages(2)[0]
    self.assertEqual(common.OPCODE_PING, ping['opcode'])
    self.assertEqual('Ping!', ping['message'])
def test_channel_slot_violation_by_client(self):
    # Only one channel slot is granted. The AddChannelRequest for channel
    # 2 uses it up, so the one for channel 3 must be rejected with a
    # DropChannel carrying the new-channel-slot-violation code, and no
    # events may be recorded for channel 3.
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    mux_handler = mux._MuxHandler(request, dispatcher)
    mux_handler.start()
    mux_handler.add_channel_slots(slots=1,
                                  send_quota=mux._INITIAL_QUOTA_FOR_CLIENT)

    # Open channel 2, grant it send quota and send it one message.
    encoded_handshake = _create_request_header(path='/echo')
    add_channel_request = _create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=encoded_handshake)
    request.connection.put_bytes(add_channel_request)
    flow_control = _create_flow_control_frame(channel_id=2,
                                              replenished_quota=6)
    request.connection.put_bytes(flow_control)
    request.connection.put_bytes(
        _create_logical_frame(channel_id=2, message='Hello'))

    # This request should be rejected: no channel slot is left.
    encoded_handshake = _create_request_header(path='/echo')
    add_channel_request = _create_add_channel_request_frame(
        channel_id=3, encoding=0,
        encoded_handshake=encoded_handshake)
    request.connection.put_bytes(add_channel_request)
    flow_control = _create_flow_control_frame(channel_id=3,
                                              replenished_quota=6)
    request.connection.put_bytes(flow_control)
    request.connection.put_bytes(
        _create_logical_frame(channel_id=3, message='Hello'))

    # Ask channel 1 and then channel 2 to finish.
    request.connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    request.connection.put_bytes(
        _create_logical_frame(channel_id=2, message='Goodbye'))
    self.assertTrue(mux_handler.wait_until_done(timeout=2))

    self.assertEqual([], dispatcher.channel_events[1].messages)
    self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
    # Membership is tested with 'in' rather than the deprecated
    # dict.has_key(), which no longer exists in Python 3.
    self.assertFalse(3 in dispatcher.channel_events)

    drop_channel = next(
        b for b in request.connection.get_written_control_blocks()
        if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(3, drop_channel.channel_id)
    self.assertEqual(mux._DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION,
                     drop_channel.drop_code)
def test_quota_overflow_by_client(self):
    # Replenishing 0x7FFFFFFFFFFFFFFF bytes twice on channel 2 must get
    # the channel dropped with the send-quota-overflow code.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()
    handler.add_channel_slots(slots=1,
                              send_quota=mux._INITIAL_QUOTA_FOR_CLIENT)
    connection = mock_request.connection

    connection.put_bytes(_create_add_channel_request_frame(
        channel_id=2, encoding=0,
        encoded_handshake=_create_request_header(path='/echo')))

    # The same FlowControl frame is delivered two times.
    huge_grant = _create_flow_control_frame(
        channel_id=2,
        replenished_quota=0x7FFFFFFFFFFFFFFF)
    connection.put_bytes(huge_grant)
    connection.put_bytes(huge_grant)

    connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(2, drop_channel.channel_id)
    self.assertEqual(mux._DROP_CODE_SEND_QUOTA_OVERFLOW,
                     drop_channel.drop_code)
def test_invalid_encapsulated_message(self):
    # A control-channel payload delivered in a frame with the text opcode
    # must produce a DropChannel with the invalid-encapsulating-message
    # code, and the server must close with
    # STATUS_INTERNAL_ENDPOINT_ERROR.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()

    # AddChannelRequest block: opcode in the top three bits of the first
    # byte, followed by channel id 1 and the encoded number 0.
    leading_byte = mux._MUX_OPCODE_ADD_CHANNEL_REQUEST << 5
    control_block = chr(leading_byte)
    control_block = control_block + mux._encode_channel_id(1)
    control_block = control_block + mux._encode_number(0)
    frame_payload = (mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) +
                     control_block)
    mock_request.connection.put_bytes(create_binary_frame(
        frame_payload, opcode=common.OPCODE_TEXT, mask=True))
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_INVALID_ENCAPSULATING_MESSAGE,
                     drop_channel.drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     mock_request.connection.server_close_code)
def test_channel_id_truncated(self):
    # A frame whose channel id is cut short must produce a DropChannel
    # with the channel-id-truncated code, and the server must close with
    # STATUS_INTERNAL_ENDPOINT_ERROR.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()

    # Payload '\x80': the final byte of the channel id is absent.
    truncated = create_binary_frame('\x80', mask=True)
    mock_request.connection.put_bytes(truncated)
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_CHANNEL_ID_TRUNCATED,
                     drop_channel.drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     mock_request.connection.server_close_code)
def test_inner_frame_truncated(self):
    # A frame that holds nothing but a channel id must produce a
    # DropChannel with the encapsulated-frame-is-truncated code, and the
    # server must close with STATUS_INTERNAL_ENDPOINT_ERROR.
    mock_request = _create_mock_request()
    handler = mux._MuxHandler(mock_request, _MuxMockDispatcher())
    handler.start()

    # Payload '\x01': channel id 1 and no encapsulated frame after it.
    id_only = create_binary_frame('\x01', mask=True)
    mock_request.connection.put_bytes(id_only)
    self.assertTrue(handler.wait_until_done(timeout=2))

    written_blocks = mock_request.connection.get_written_control_blocks()
    drop_channel = next(
        block for block in written_blocks
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
    self.assertEqual(mux._DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED,
                     drop_channel.drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     mock_request.connection.server_close_code)
def test_unknown_mux_opcode(self):
    """Check handling of a control block with an undefined opcode.

    A masked binary frame with payload '\\x00\\xa0' is fed to the mux
    handler. The second byte has the value 5 in its top three bits
    (0xa0 >> 5 == 5). The test expects a DropChannel control block with
    drop code _DROP_CODE_UNKNOWN_MUX_OPCODE to be written and the
    connection to be closed with STATUS_INTERNAL_ENDPOINT_ERROR.
    """
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    mux_handler = mux._MuxHandler(request, dispatcher)
    mux_handler.start()

    # Undefined opcode 5
    frame = create_binary_frame('\x00\xa0', mask=True)
    request.connection.put_bytes(frame)

    self.assertTrue(mux_handler.wait_until_done(timeout=2))

    # Gather the blocks into a list rather than calling next() on a
    # generator: if no DropChannel was written, the test then fails with
    # a readable assertion instead of erroring out with StopIteration.
    drop_channels = [
        b for b in request.connection.get_written_control_blocks()
        if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL]
    self.assertTrue(drop_channels,
                    'No DropChannel control block was written')
    self.assertEqual(mux._DROP_CODE_UNKNOWN_MUX_OPCODE,
                     drop_channels[0].drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     request.connection.server_close_code)
def test_invalid_mux_control_block(self):
    """Check handling of a malformed control block.

    A masked binary frame with payload '\\x00\\x60\\x00\\x01\\x00' is fed
    to the mux handler. The test expects a DropChannel control block
    with drop code _DROP_CODE_INVALID_MUX_CONTROL_BLOCK to be written and
    the connection to be closed with STATUS_INTERNAL_ENDPOINT_ERROR.
    """
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    mux_handler = mux._MuxHandler(request, dispatcher)
    mux_handler.start()

    # DropChannel contains 1 byte reason
    frame = create_binary_frame('\x00\x60\x00\x01\x00', mask=True)
    request.connection.put_bytes(frame)

    self.assertTrue(mux_handler.wait_until_done(timeout=2))

    # Gather the blocks into a list rather than calling next() on a
    # generator: if no DropChannel was written, the test then fails with
    # a readable assertion instead of erroring out with StopIteration.
    drop_channels = [
        b for b in request.connection.get_written_control_blocks()
        if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL]
    self.assertTrue(drop_channels,
                    'No DropChannel control block was written')
    self.assertEqual(mux._DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                     drop_channels[0].drop_code)
    self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
                     request.connection.server_close_code)
def test_permessage_compress(self):
    """Exercise the permessage compress extension on a logical channel.

    Logical channel 2 is opened with an AddChannelRequest whose handshake
    offers the permessage compression extension with method=deflate.
    'Hello' is then sent twice on that channel, each time compressed with
    the same raw-deflate compressor and flagged with rsv1. The test
    expects the dispatcher to have received the two plain 'Hello'
    messages and the two messages written on channel 2 to equal the
    compressed payloads that were sent.
    """
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)

    # Enable permessage compress extension on logical channel 2.
    offer = '%s; method=deflate' % (
        common.PERMESSAGE_COMPRESSION_EXTENSION)
    handshake = _create_request_header(path='/echo', extensions=offer)
    request.connection.put_bytes(
        _create_add_channel_request_frame(
            channel_id=2, encoding=0, encoded_handshake=handshake))
    request.connection.put_bytes(
        _create_flow_control_frame(channel_id=2, replenished_quota=20))

    # Send compressed 'Hello' twice. A single compressor is shared by
    # both messages; a negative wbits value selects a raw deflate stream.
    deflater = zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
    sent_payloads = []
    for _ in range(2):
        payload = (deflater.compress('Hello') +
                   deflater.flush(zlib.Z_SYNC_FLUSH))
        # The final four bytes of the sync-flushed output are not sent.
        payload = payload[:-4]
        sent_payloads.append(payload)
        request.connection.put_bytes(
            _create_logical_frame(channel_id=2, message=payload,
                                  rsv1=True))

    # Say goodbye on channel 1 first, then on channel 2.
    for channel_id in (1, 2):
        request.connection.put_bytes(
            _create_logical_frame(channel_id=channel_id,
                                  message='Goodbye'))

    self.assertTrue(handler.wait_until_done(timeout=2))
    self.assertEqual(['Hello', 'Hello'],
                     dispatcher.channel_events[2].messages)

    # Written 'Hello's should be compressed.
    written = request.connection.get_written_messages(2)
    self.assertEqual(2, len(written))
    self.assertEqual(sent_payloads[0], written[0])
    self.assertEqual(sent_payloads[1], written[1])
def test_permessage_compress_fragmented_message(self):
    """Exercise permessage compress with a message split in two frames.

    The mock request is created with the permessage compression
    extension (method=deflate) as its logical channel extensions. A
    compressed 'HelloHelloHello' is split in half and sent on channel 1
    as a TEXT frame (fin=False, rsv1=True) followed by a CONTINUATION
    frame (fin=True, rsv1=False). The test expects the dispatcher to
    have received the single plain message and the one message written
    on channel 1 to equal the whole compressed payload.
    """
    extensions = common.parse_extensions(
        '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION)
    request = _create_mock_request(
        logical_channel_extensions=extensions)
    dispatcher = _MuxMockDispatcher()
    mux_handler = mux._MuxHandler(request, dispatcher)
    mux_handler.start()
    mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                                  mux._INITIAL_QUOTA_FOR_CLIENT)

    # Send compressed 'HelloHelloHello' as fragmented message.
    compress = zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
    compressed_hello = compress.compress('HelloHelloHello')
    compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
    # The final four bytes of the sync-flushed output are not sent.
    compressed_hello = compressed_hello[:-4]

    # Use floor division so that the split point is always an int and can
    # be used as a slice index even when true division is in effect.
    m = len(compressed_hello) // 2

    # First fragment: carries the data opcode and the rsv1 flag.
    request.connection.put_bytes(
        _create_logical_frame(channel_id=1,
                              message=compressed_hello[:m],
                              fin=False, rsv1=True,
                              opcode=common.OPCODE_TEXT))
    # Final fragment: a continuation frame sent without rsv1.
    request.connection.put_bytes(
        _create_logical_frame(channel_id=1,
                              message=compressed_hello[m:],
                              fin=True, rsv1=False,
                              opcode=common.OPCODE_CONTINUATION))
    request.connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))

    self.assertTrue(mux_handler.wait_until_done(timeout=2))
    self.assertEqual(['HelloHelloHello'],
                     dispatcher.channel_events[1].messages)

    messages = request.connection.get_written_messages(1)
    self.assertEqual(1, len(messages))
    self.assertEqual(compressed_hello, messages[0])
def test_receive_bad_fragmented_message(self):
    """Check that badly fragmented messages drop their logical channel.

    Three logical channels (2, 3 and 4) are opened in turn and each
    receives an invalid frame sequence. After a 'Goodbye' on channel 1
    the test expects exactly three DropChannel control blocks to have
    been written, all with drop code _DROP_CODE_BAD_FRAGMENTATION.
    """
    request = _create_mock_request()
    dispatcher = _MuxMockDispatcher()
    handler = mux._MuxHandler(request, dispatcher)
    handler.start()
    handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              mux._INITIAL_QUOTA_FOR_CLIENT)

    # Each entry is (channel id, [(message, fin, opcode), ...]).
    bad_sequences = [
        # A frame with fin=False followed by a frame with opcode=TEXT
        # (not CONTINUATION). Logical channel 2 should be dropped.
        (2, [('Hello ', False, common.OPCODE_TEXT),
             ('World!', True, common.OPCODE_TEXT)]),
        # A frame with opcode=CONTINUATION that is not preceded by a
        # frame with fin unset. Logical channel 3 should be dropped.
        (3, [('Hello', True, common.OPCODE_CONTINUATION)]),
        # A frame with opcode=PING and fin=False followed by a frame with
        # opcode=TEXT (not CONTINUATION). Logical channel 4 should be
        # dropped.
        (4, [('Ping', False, common.OPCODE_PING),
             ('Hello', True, common.OPCODE_TEXT)]),
    ]

    for channel_id, fragments in bad_sequences:
        # Open the logical channel before sending its frames.
        handshake = _create_request_header(path='/echo')
        request.connection.put_bytes(
            _create_add_channel_request_frame(
                channel_id=channel_id, encoding=0,
                encoded_handshake=handshake))
        for message, fin, opcode in fragments:
            request.connection.put_bytes(
                _create_logical_frame(channel_id=channel_id,
                                      message=message,
                                      fin=fin,
                                      opcode=opcode))

    request.connection.put_bytes(
        _create_logical_frame(channel_id=1, message='Goodbye'))
    self.assertTrue(handler.wait_until_done(timeout=2))

    dropped = [
        block
        for block in request.connection.get_written_control_blocks()
        if block.opcode == mux._MUX_OPCODE_DROP_CHANNEL]
    self.assertEqual(3, len(dropped))
    for block in dropped:
        self.assertEqual(mux._DROP_CODE_BAD_FRAGMENTATION,
                         block.drop_code)
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# vi:sts=4 sw=4 et