blob: 6c87138230b303bf38dc4cf84c1290a88f7bd8d2 [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2011, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for the mod_pywebsocket.handshake.hybi opening handshake."""
import unittest
import set_sys_path # Update sys.path to locate mod_pywebsocket module.
from mod_pywebsocket import common
from mod_pywebsocket.handshake._base import AbortedByUserException
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.handshake.hybi import Handshaker
import mock
class RequestDefinition(object):
    """Plain record describing an opening handshake request used in tests.

    Holds the HTTP method, the request URI and a dict of request headers.
    The headers dict is stored by reference (not copied); tests mutate it
    directly to derive variant or malformed requests.
    """

    def __init__(self, method, uri, headers):
        self.method, self.uri, self.headers = method, uri, headers
def _create_good_request_def():
    """Return a RequestDefinition for a valid version-13 opening handshake.

    The key is the sample nonce from RFC 6455. A brand-new headers dict is
    built on every call, so callers may freely add, change or delete
    entries to derive other requests.
    """
    headers = {
        'Host': 'server.example.com',
        'Upgrade': 'websocket',
        'Connection': 'Upgrade',
        'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
        'Sec-WebSocket-Version': '13',
        'Origin': 'http://example.com',
    }
    return RequestDefinition(method='GET', uri='/demo', headers=headers)
def _create_request(request_def):
    """Build a mock.MockRequest from a RequestDefinition.

    The request is attached to a fresh mock.MockConn created with empty
    input; tests read the handshake response back through
    request.connection.written_data().
    """
    return mock.MockRequest(
        method=request_def.method,
        uri=request_def.uri,
        headers_in=request_def.headers,
        connection=mock.MockConn(''))
def _create_handshaker(request):
    """Return a hybi Handshaker for request using a mock.MockDispatcher."""
    return Handshaker(request, mock.MockDispatcher())
class SubprotocolChoosingDispatcher(object):
    """Test dispatcher that selects a subprotocol in the extra handshake.

    When ``index`` is non-negative, ``conn_context.ws_protocol`` is set to
    ``conn_context.ws_requested_protocols[index]``. Otherwise it is set to
    ``default_value`` and the requested list is not consulted at all.
    """

    def __init__(self, index, default_value=None):
        self.index = index
        self.default_value = default_value

    def do_extra_handshake(self, conn_context):
        chosen = (conn_context.ws_requested_protocols[self.index]
                  if self.index >= 0 else self.default_value)
        conn_context.ws_protocol = chosen

    def transfer_data(self, conn_context):
        # Intentionally a no-op: only the handshake is under test.
        pass
class HandshakeAbortedException(Exception):
    """Raised by AbortingDispatcher.do_extra_handshake to reject a request.

    Tests assert that it propagates out of Handshaker.do_handshake
    unchanged.
    """
class AbortingDispatcher(object):
    """Test dispatcher whose extra handshake always fails.

    do_extra_handshake raises HandshakeAbortedException, so tests can check
    that Handshaker.do_handshake does not swallow dispatcher errors.
    """

    def do_extra_handshake(self, conn_context):
        message = 'An exception to reject the request'
        raise HandshakeAbortedException(message)

    def transfer_data(self, conn_context):
        # Intentionally a no-op.
        pass
class AbortedByUserDispatcher(object):
    """Test dispatcher that rejects the request as a user abort.

    do_extra_handshake raises AbortedByUserException, so tests can check
    that Handshaker.do_handshake lets that exception propagate.
    """

    def do_extra_handshake(self, conn_context):
        message = 'An AbortedByUserException to reject the request'
        raise AbortedByUserException(message)

    def transfer_data(self, conn_context):
        # Intentionally a no-op.
        pass
_EXPECTED_RESPONSE = (
'HTTP/1.1 101 Switching Protocols\r\n'
'Upgrade: websocket\r\n'
'Connection: Upgrade\r\n'
'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')
class HandshakerTest(unittest.TestCase):
    """A unittest for draft-ietf-hybi-thewebsocketprotocol-06 and later
    handshake processor.

    Each test drives mod_pywebsocket.handshake.hybi.Handshaker with a mock
    request and checks the bytes written to the mock connection and/or the
    attributes the handshaker sets on the request.
    """

    def test_do_handshake(self):
        """A valid request is accepted and the request is populated."""
        request = _create_request(_create_good_request_def())
        dispatcher = mock.MockDispatcher()
        handshaker = Handshaker(request, dispatcher)
        handshaker.do_handshake()

        self.assertTrue(dispatcher.do_extra_handshake_called)

        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('/demo', request.ws_resource)
        self.assertEqual('http://example.com', request.ws_origin)
        self.assertEqual(None, request.ws_protocol)
        self.assertEqual(None, request.ws_extensions)
        self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)

    def test_do_handshake_with_extra_headers(self):
        """Headers unrelated to WebSocket do not alter the response."""
        request_def = _create_good_request_def()
        # Add headers not related to WebSocket opening handshake.
        request_def.headers['FooKey'] = 'BarValue'
        request_def.headers['EmptyKey'] = ''

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_capitalized_value(self):
        """Upgrade and Connection values are matched case-insensitively."""
        request_def = _create_good_request_def()
        # Overwrite the existing 'Upgrade' entry rather than adding a
        # separate lower-case 'upgrade' key. With two keys for the same
        # header, either the request never sees 'WEBSOCKET' (case-sensitive
        # lookup), or which value wins depends on dict iteration order
        # (case-insensitive lookup).
        request_def.headers['Upgrade'] = 'WEBSOCKET'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'UPGRADE'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_multiple_connection_values(self):
        """Connection may list several tokens, including empty ones."""
        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Upgrade, keep-alive, , '

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_aborting_handshake(self):
        """An exception from do_extra_handshake propagates unchanged."""
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortingDispatcher())
        # do_extra_handshake raises an exception. Check that it's not caught by
        # do_handshake.
        self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)

    def test_do_handshake_with_protocol(self):
        """The subprotocol chosen by the dispatcher is sent back."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
        handshaker.do_handshake()

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Protocol: chat\r\n\r\n')

        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('chat', request.ws_protocol)

    def test_do_handshake_protocol_not_in_request_but_in_response(self):
        """Selecting a subprotocol the client never requested is an error."""
        request_def = _create_good_request_def()
        request = _create_request(request_def)
        handshaker = Handshaker(
            request, SubprotocolChoosingDispatcher(-1, 'foobar'))
        # No request has been made but ws_protocol is set. HandshakeException
        # must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_protocol_no_protocol_selection(self):
        """Requested subprotocols that the dispatcher ignores are an error."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        # ws_protocol is not set. HandshakeException must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_extensions(self):
        """Only the supported extension is accepted; 'unknown' is dropped."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-compress; method=deflate, unknown')

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Extensions: permessage-compress; method=deflate\r\n'
            '\r\n')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual(1, len(request.ws_extensions))
        extension = request.ws_extensions[0]
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         extension.name())
        self.assertEqual(['method'], extension.get_parameter_names())
        self.assertEqual('deflate', extension.get_parameter_value('method'))
        self.assertEqual(1, len(request.ws_extension_processors))
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         request.ws_extension_processors[0].name())

    def test_do_handshake_with_permessage_compress(self):
        """permessage-compress with method=deflate is accepted."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-compress; method=deflate')
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(1, len(request.ws_extensions))
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         request.ws_extensions[0].name())
        self.assertEqual(1, len(request.ws_extension_processors))
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         request.ws_extension_processors[0].name())

    def test_do_handshake_with_quoted_extensions(self):
        """Empty list elements, spaces around '=' and quoted values parse."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-compress; method=deflate, , '
            'unknown; e = "mc^2"; ma="\r\n \\\rf "; pv=nrt')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()

        self.assertEqual(2, len(request.ws_requested_extensions))
        first_extension = request.ws_requested_extensions[0]
        self.assertEqual('permessage-compress', first_extension.name())
        self.assertEqual(['method'], first_extension.get_parameter_names())
        self.assertEqual('deflate',
                         first_extension.get_parameter_value('method'))
        second_extension = request.ws_requested_extensions[1]
        self.assertEqual('unknown', second_extension.name())
        self.assertEqual(
            ['e', 'ma', 'pv'], second_extension.get_parameter_names())
        self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
        self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
        self.assertEqual('nrt', second_extension.get_parameter_value('pv'))

    def test_do_handshake_with_optional_headers(self):
        """Non-WebSocket headers stay readable via request.headers_in."""
        request_def = _create_good_request_def()
        request_def.headers['EmptyValue'] = ''
        request_def.headers['AKey'] = 'AValue'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            'AValue', request.headers_in['AKey'])
        self.assertEqual(
            '', request.headers_in['EmptyValue'])

    def test_abort_extra_handshake(self):
        """AbortedByUserException from the dispatcher propagates."""
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortedByUserDispatcher())
        # do_extra_handshake raises an AbortedByUserException. Check that it's
        # not caught by do_handshake.
        self.assertRaises(AbortedByUserException, handshaker.do_handshake)

    def test_do_handshake_with_mux_and_deflate_frame(self):
        """With mux listed first, mux is rejected; deflate-frame is kept."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
            common.MUX_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        # mux should be rejected.
        self.assertEqual(1, len(request.ws_extensions))
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         request.ws_extensions[0].name())
        self.assertEqual(2, len(request.ws_extension_processors))
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extension_processors[0].name())
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         request.ws_extension_processors[1].name())
        self.assertFalse(hasattr(request, 'mux_processor'))

    def test_do_handshake_with_deflate_frame_and_mux(self):
        """With deflate-frame listed first, mux is rejected."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
            common.DEFLATE_FRAME_EXTENSION,
            common.MUX_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        # mux should be rejected.
        self.assertEqual(1, len(request.ws_extensions))
        first_extension = request.ws_extensions[0]
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         first_extension.name())
        self.assertEqual(2, len(request.ws_extension_processors))
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         request.ws_extension_processors[0].name())
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extension_processors[1].name())
        # Check the same attribute as the other mux tests; the previous
        # check for a 'mux' attribute could never fail.
        self.assertFalse(hasattr(request, 'mux_processor'))

    def test_do_handshake_with_permessage_compress_and_mux(self):
        """With permessage-compress listed first, mux is accepted.

        permessage-compress is then reported by
        request.mux_processor.extensions().
        """
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            '%s; method=deflate, %s' % (
                common.PERMESSAGE_COMPRESSION_EXTENSION,
                common.MUX_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()

        self.assertEqual(1, len(request.ws_extensions))
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extensions[0].name())
        self.assertEqual(2, len(request.ws_extension_processors))
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         request.ws_extension_processors[0].name())
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extension_processors[1].name())
        self.assertTrue(hasattr(request, 'mux_processor'))
        self.assertTrue(request.mux_processor.is_active())
        mux_extensions = request.mux_processor.extensions()
        self.assertEqual(1, len(mux_extensions))
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         mux_extensions[0].name())

    def test_do_handshake_with_mux_and_permessage_compress(self):
        """With mux listed first, mux is rejected; compression is kept."""
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            '%s, %s; method=deflate' % (
                common.MUX_EXTENSION,
                common.PERMESSAGE_COMPRESSION_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        # mux should be rejected.
        self.assertEqual(1, len(request.ws_extensions))
        first_extension = request.ws_extensions[0]
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         first_extension.name())
        self.assertEqual(2, len(request.ws_extension_processors))
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extension_processors[0].name())
        self.assertEqual(common.PERMESSAGE_COMPRESSION_EXTENSION,
                         request.ws_extension_processors[1].name())
        self.assertFalse(hasattr(request, 'mux_processor'))

    def test_bad_requests(self):
        """Each malformed request is rejected with the expected exception.

        bad_cases entries are (case_name, request_def, expected_status,
        expect_handshake_exception). When expect_handshake_exception is
        True, a HandshakeException whose status equals expected_status is
        required; otherwise a VersionException is required.
        """
        bad_cases = [
            ('HTTP request',
             RequestDefinition(
                 'GET', '/demo',
                 {'Host': 'www.google.com',
                  'User-Agent':
                      'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
                      ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
                      ' GTB6 GTBA',
                  'Accept':
                      'text/html,application/xhtml+xml,application/xml;q=0.9,'
                      '*/*;q=0.8',
                  'Accept-Language': 'en-us,en;q=0.5',
                  'Accept-Encoding': 'gzip,deflate',
                  'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
                  'Keep-Alive': '300',
                  'Connection': 'keep-alive'}), None, True)]

        request_def = _create_good_request_def()
        request_def.method = 'POST'
        bad_cases.append(('Wrong method', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Host']
        bad_cases.append(('Missing Host', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Upgrade']
        bad_cases.append(('Missing Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'nonwebsocket'
        bad_cases.append(('Wrong Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Connection']
        bad_cases.append(('Missing Connection', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Downgrade'
        bad_cases.append(('Wrong Connection', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Key']
        bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==garbage')
        bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        # The last character right before == must be any of A, Q, w and g.
        request_def.headers['Sec-WebSocket-Key'] = (
            'AQIDBAUGBwgJCgsMDQ4PEC==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (multiple values)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Version']
        bad_cases.append(('Missing Sec-WebSocket-Version', request_def, None,
                          True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '3'
        bad_cases.append(('Wrong Sec-WebSocket-Version', request_def, None,
                          False))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '13, 13'
        bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'illegal\x09protocol'
        bad_cases.append(('Illegal Sec-WebSocket-Protocol',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = ''
        bad_cases.append(('Empty Sec-WebSocket-Protocol',
                          request_def, 400, True))

        for (case_name, request_def, expected_status,
             expect_handshake_exception) in bad_cases:
            request = _create_request(request_def)
            handshaker = Handshaker(request, mock.MockDispatcher())
            # 'except ... as' works on Python 2.6+ and 3; the case name is
            # passed as the assertion message so failures identify the case.
            try:
                handshaker.do_handshake()
            except HandshakeException as e:
                self.assertTrue(expect_handshake_exception, case_name)
                self.assertEqual(expected_status, e.status, case_name)
            except VersionException:
                self.assertFalse(expect_handshake_exception, case_name)
            else:
                self.fail('No exception thrown for \'%s\' case' % case_name)
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed directly.
    unittest.main()
# vi:sts=4 sw=4 et