# blob: 76334685b311225e3873b9b828ed031a7e907802
# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""This file provides classes and helper functions for multiplexing extension.
Specification:
http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-06
"""
import collections
import copy
import email
import email.parser
import logging
import math
import struct
import threading
import traceback
from mod_pywebsocket import common
from mod_pywebsocket import handshake
from mod_pywebsocket import util
from mod_pywebsocket._stream_base import BadOperationException
from mod_pywebsocket._stream_base import ConnectionTerminatedException
from mod_pywebsocket._stream_base import InvalidFrameException
from mod_pywebsocket._stream_hybi import Frame
from mod_pywebsocket._stream_hybi import Stream
from mod_pywebsocket._stream_hybi import StreamOptions
from mod_pywebsocket._stream_hybi import create_binary_frame
from mod_pywebsocket._stream_hybi import create_closing_handshake_body
from mod_pywebsocket._stream_hybi import create_header
from mod_pywebsocket._stream_hybi import create_length_header
from mod_pywebsocket._stream_hybi import parse_frame
from mod_pywebsocket.handshake import hybi
# Channel ids with a fixed meaning. Data written for the control channel
# carries multiplexing control blocks instead of inner frames.
_CONTROL_CHANNEL_ID = 0
_DEFAULT_CHANNEL_ID = 1
# Opcodes of multiplexing control blocks. The opcode occupies the three most
# significant bits of the first byte of a control block.
_MUX_OPCODE_ADD_CHANNEL_REQUEST = 0
_MUX_OPCODE_ADD_CHANNEL_RESPONSE = 1
_MUX_OPCODE_FLOW_CONTROL = 2
_MUX_OPCODE_DROP_CHANNEL = 3
_MUX_OPCODE_NEW_CHANNEL_SLOT = 4
# Largest channel id that _encode_channel_id() can represent.
_MAX_CHANNEL_ID = 2 ** 29 - 1
# Initial slot count and quota. NOTE(review): presumably the values announced
# to the client when multiplexing starts; confirm against the mux handler.
_INITIAL_NUMBER_OF_CHANNEL_SLOTS = 64
_INITIAL_QUOTA_FOR_CLIENT = 8 * 1024
# Values of the encoding field of AddChannelRequest / AddChannelResponse.
_HANDSHAKE_ENCODING_IDENTITY = 0
_HANDSHAKE_ENCODING_DELTA = 1
# Only these status codes are needed for now.
_HTTP_BAD_RESPONSE_MESSAGES = {
common.HTTP_STATUS_BAD_REQUEST: 'Bad Request',
}
# DropChannel reason codes.
# TODO(bashi): Define all reason codes defined in the -05 draft.
_DROP_CODE_NORMAL_CLOSURE = 1000
_DROP_CODE_INVALID_ENCAPSULATING_MESSAGE = 2001
_DROP_CODE_CHANNEL_ID_TRUNCATED = 2002
_DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED = 2003
_DROP_CODE_UNKNOWN_MUX_OPCODE = 2004
_DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005
_DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006
_DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007
_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 2010
_DROP_CODE_SEND_QUOTA_VIOLATION = 3005
_DROP_CODE_SEND_QUOTA_OVERFLOW = 3006
_DROP_CODE_ACKNOWLEDGED = 3008
_DROP_CODE_BAD_FRAGMENTATION = 3009
class MuxUnexpectedException(Exception):
    """Raised when handling of the multiplexing extension reaches an
    unexpected state.
    """
# Temporary
class MuxNotImplementedException(Exception):
    """Raised when control flow enters a code path that is not implemented
    yet.
    """
class LogicalConnectionClosedException(Exception):
    """Raised when a logical connection has been closed gracefully."""
class PhysicalConnectionError(Exception):
    """Raised when an error affects the physical connection.

    Attributes:
        drop_code: numeric drop code describing the failure.
        message: additional detail text; empty by default.
    """

    def __init__(self, drop_code, message=''):
        """Constructs an instance.

        Args:
            drop_code: numeric drop code describing the failure.
            message: additional detail text.
        """
        description = 'code=%d, message=%r' % (drop_code, message)
        super(PhysicalConnectionError, self).__init__(description)
        self.message = message
        self.drop_code = drop_code
class LogicalChannelError(Exception):
    """Raised when an error affects a single logical channel.

    Attributes:
        channel_id: id of the logical channel the error belongs to.
        drop_code: numeric drop code describing the failure.
        message: additional detail text; empty by default.
    """

    def __init__(self, channel_id, drop_code, message=''):
        """Constructs an instance.

        Args:
            channel_id: id of the logical channel the error belongs to.
            drop_code: numeric drop code describing the failure.
            message: additional detail text.
        """
        details = (channel_id, drop_code, message)
        super(LogicalChannelError, self).__init__(
            'channel_id=%d, code=%d, message=%r' % details)
        self.message = message
        self.drop_code = drop_code
        self.channel_id = channel_id
def _encode_channel_id(channel_id):
    """Encodes a channel id with the variable-length multiplexing encoding.

    The number of bytes depends on the magnitude of the id:
      - below 2 ** 7:  1 byte,  the id itself (most significant bit 0)
      - below 2 ** 14: 2 bytes, leading bits 10
      - below 2 ** 21: 3 bytes, leading bits 110
      - below 2 ** 29: 4 bytes, leading bits 111

    Every form is produced with struct.pack() so that the result is always
    a byte string; mixing chr() output with struct.pack() output only works
    where text and byte strings are the same type.

    Args:
        channel_id: non-negative integer smaller than 2 ** 29.

    Returns:
        The encoded channel id as a byte string.

    Raises:
        ValueError: when channel_id is negative or does not fit in 29 bits.
    """
    if channel_id < 0:
        raise ValueError('Channel id %d must not be negative' % channel_id)
    if channel_id < 2 ** 7:
        return struct.pack('!B', channel_id)
    if channel_id < 2 ** 14:
        return struct.pack('!H', 0x8000 + channel_id)
    if channel_id < 2 ** 21:
        # '!' selects standard sizes without alignment, so this is exactly
        # one marker byte followed by the low 16 bits.
        return struct.pack('!BH', 0xc0 + (channel_id >> 16),
                           channel_id & 0xffff)
    if channel_id < 2 ** 29:
        return struct.pack('!L', 0xe0000000 + channel_id)
    raise ValueError('Channel id %d is too large' % channel_id)
def _encode_number(number):
    """Encodes a number field of a multiplexing control block.

    The work is delegated to create_length_header(); False is passed as its
    second argument.
    NOTE(review): the second argument is presumably the mask flag of the
    WebSocket length header; confirm against _stream_hybi.

    Args:
        number: the number to encode.

    Returns:
        Whatever create_length_header() returns for the number.
    """
    encoded = create_length_header(number, False)
    return encoded
def _create_add_channel_response(channel_id, encoded_handshake,
                                 encoding=0, rejected=False):
    """Builds an AddChannelResponse control block.

    Layout: one byte holding the opcode (top three bits), the rejected flag
    (bit 4) and the encoding (low bits), then the encoded channel id, the
    encoded size of the handshake, and the handshake itself.

    Args:
        channel_id: id of the channel the response is for.
        encoded_handshake: the encoded handshake data.
        encoding: handshake encoding; must be 0 or 1.
        rejected: whether the rejected flag is set.

    Returns:
        The control block.

    Raises:
        ValueError: when encoding is neither 0 nor 1.
    """
    if encoding not in (0, 1):
        raise ValueError('Invalid encoding %d' % encoding)
    opcode_bits = _MUX_OPCODE_ADD_CHANNEL_RESPONSE << 5
    rejected_bit = rejected << 4
    header = chr(opcode_bits | rejected_bit | encoding)
    header = header + _encode_channel_id(channel_id)
    header = header + _encode_number(len(encoded_handshake))
    return header + encoded_handshake
def _create_drop_channel(channel_id, code=None, message=''):
    """Builds a DropChannel control block.

    Layout: the opcode byte, the encoded channel id, the encoded size of
    the reason, and the reason. The reason is empty when no code is given;
    otherwise it is the code as a 16-bit big-endian integer followed by the
    message.

    Args:
        channel_id: id of the channel to drop.
        code: drop reason code, or None for no reason.
        message: reason text; only allowed together with a code.

    Returns:
        The control block.

    Raises:
        ValueError: when a message is given without a code.
    """
    if len(message) > 0 and code is None:
        raise ValueError('Code must be specified if message is specified')
    prefix = chr(_MUX_OPCODE_DROP_CHANNEL << 5)
    prefix = prefix + _encode_channel_id(channel_id)
    if code is None:
        # Reason size of zero: no reason follows.
        return prefix + _encode_number(0)
    reason = struct.pack('!H', code) + message
    return prefix + (_encode_number(len(reason)) + reason)
def _create_flow_control(channel_id, replenished_quota):
    """Builds a FlowControl control block.

    Layout: the opcode byte, the encoded channel id, and the encoded
    replenished quota.

    Args:
        channel_id: id of the channel whose quota is replenished.
        replenished_quota: amount of quota to add.

    Returns:
        The control block.
    """
    opcode_byte = chr(_MUX_OPCODE_FLOW_CONTROL << 5)
    encoded_channel = _encode_channel_id(channel_id)
    encoded_quota = _encode_number(replenished_quota)
    return opcode_byte + encoded_channel + encoded_quota
def _create_new_channel_slot(slots, send_quota):
    """Builds a NewChannelSlot control block.

    Layout: the opcode byte, the encoded number of slots, and the encoded
    send quota.

    Args:
        slots: number of slots; must not be negative.
        send_quota: send quota; must not be negative.

    Returns:
        The control block.

    Raises:
        ValueError: when slots or send_quota is negative.
    """
    if slots < 0 or send_quota < 0:
        raise ValueError('slots and send_quota must be non-negative.')
    opcode_byte = chr(_MUX_OPCODE_NEW_CHANNEL_SLOT << 5)
    encoded_slots = _encode_number(slots)
    encoded_quota = _encode_number(send_quota)
    return opcode_byte + encoded_slots + encoded_quota
def _create_fallback_new_channel_slot():
    """Builds a NewChannelSlot control block with the F flag set.

    Both the number of slots and the send quota are encoded as zero.

    Returns:
        The control block.
    """
    fallback_flag = 1  # The F flag is the least significant bit.
    opcode_byte = chr((_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | fallback_flag)
    no_slots = _encode_number(0)
    no_quota = _encode_number(0)
    return opcode_byte + no_slots + no_quota
def _parse_request_text(request_text):
    """Parses the text of an HTTP request into its components.

    Args:
        request_text: the Request-Line, a CRLF, and then the header lines.

    Returns:
        A (command, path, version, headers) tuple. headers is the message
        object produced by email.parser.Parser for the header lines.

    Raises:
        ValueError: when the Request-Line is not terminated by CRLF, does
            not consist of exactly three space-separated words, or names a
            version other than HTTP/1.1.
    """
    # partition() instead of a two-way split so that text without any CRLF
    # is reported with an explicit message rather than as a tuple-unpacking
    # failure.
    request_line, separator, header_lines = request_text.partition('\r\n')
    if not separator:
        raise ValueError(
            'Request-Line is not terminated by CRLF: %r' % request_line)
    words = request_line.split(' ')
    if len(words) != 3:
        raise ValueError('Bad Request-Line syntax %r' % request_line)
    command, path, version = words
    if version != 'HTTP/1.1':
        raise ValueError('Bad request version %r' % version)
    # email.parser.Parser() parses RFC 2822 (RFC 822) style headers.
    # RFC 6455 refers RFC 2616 for handshake parsing, and RFC 2616 refers
    # RFC 822.
    headers = email.parser.Parser().parsestr(header_lines)
    return command, path, version, headers
class _ControlBlock(object):
    """Holds the result of parsing one multiplexing control block.

    Only the opcode is set here. _MuxFramePayloadParser attaches the
    attributes that are specific to the kind of control block (for example
    encoded_handshake for AddChannelRequest and AddChannelResponse).
    """

    def __init__(self, opcode):
        """Constructs an instance.

        Args:
            opcode: the opcode of the control block.
        """
        self.opcode = opcode
class _MuxFramePayloadParser(object):
    """A class that parses multiplexed frame payload.

    The parser keeps a read position into the payload. Every read method
    consumes data starting at that position and advances it.
    """

    def __init__(self, payload):
        """Constructs an instance.

        Args:
            payload: byte string holding the payload to parse.
        """
        self._data = payload
        self._read_position = 0
        self._logger = util.get_class_logger(self)

    def _peek_byte(self):
        """Returns the byte at the read position as an integer.

        The read position is not advanced. The caller must have checked
        that at least one byte is available. struct is used instead of
        ord() on an indexed element so that the result does not depend on
        what indexing a byte string yields.
        """
        return struct.unpack_from('!B', self._data, self._read_position)[0]

    def read_channel_id(self):
        """Reads channel id.

        The leading bits of the first byte select the length of the id:
        0 -> 1 byte, 10 -> 2 bytes, 110 -> 3 bytes, 111 -> 4 bytes.

        Returns:
            The channel id as an integer.

        Raises:
            ValueError: when the payload doesn't contain
                valid channel id.
        """
        remaining_length = len(self._data) - self._read_position
        pos = self._read_position
        if remaining_length == 0:
            raise ValueError('Invalid channel id format')

        channel_id = self._peek_byte()
        channel_id_length = 1
        if channel_id & 0xe0 == 0xe0:
            if remaining_length < 4:
                raise ValueError('Invalid channel id format')
            channel_id = struct.unpack(
                '!L', self._data[pos:pos + 4])[0] & 0x1fffffff
            channel_id_length = 4
        elif channel_id & 0xc0 == 0xc0:
            if remaining_length < 3:
                raise ValueError('Invalid channel id format')
            channel_id = (
                ((channel_id & 0x1f) << 16) +
                struct.unpack('!H', self._data[pos + 1:pos + 3])[0])
            channel_id_length = 3
        elif channel_id & 0x80 == 0x80:
            if remaining_length < 2:
                raise ValueError('Invalid channel id format')
            channel_id = struct.unpack(
                '!H', self._data[pos:pos + 2])[0] & 0x3fff
            channel_id_length = 2
        self._read_position += channel_id_length

        return channel_id

    def read_inner_frame(self):
        """Reads an inner frame.

        The first byte carries the fin/rsv1/rsv2/rsv3 flags and the opcode;
        everything after it is the payload of the inner frame.

        Returns:
            A (fin, rsv1, rsv2, rsv3, opcode, payload) tuple.

        Raises:
            PhysicalConnectionError: when the inner frame is invalid.
        """
        if len(self._data) == self._read_position:
            raise PhysicalConnectionError(
                _DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED)

        bits = self._peek_byte()
        self._read_position += 1
        fin = (bits & 0x80) == 0x80
        rsv1 = (bits & 0x40) == 0x40
        rsv2 = (bits & 0x20) == 0x20
        rsv3 = (bits & 0x10) == 0x10
        opcode = bits & 0xf
        payload = self.remaining_data()
        # Consume rest of the message which is payload data of the original
        # frame.
        self._read_position = len(self._data)
        return fin, rsv1, rsv2, rsv3, opcode, payload

    def _read_number(self):
        """Reads a number field.

        The first byte holds the number itself when it is at most 125. The
        value 126 announces a 2-byte number and 127 an 8-byte number, both
        big-endian. Numbers must use the shortest possible form.

        Returns:
            The decoded number.

        Raises:
            ValueError: when the field is truncated, has its most
                significant bit set, or does not use the shortest form.
        """
        if self._read_position + 1 > len(self._data):
            raise ValueError(
                'Cannot read the first byte of number field')

        number = self._peek_byte()
        if number & 0x80 == 0x80:
            raise ValueError(
                'The most significant bit of the first byte of number should '
                'be unset')
        self._read_position += 1
        pos = self._read_position
        if number == 127:
            if pos + 8 > len(self._data):
                raise ValueError('Invalid number field')
            self._read_position += 8
            number = struct.unpack('!Q', self._data[pos:pos + 8])[0]
            if number > 0x7FFFFFFFFFFFFFFF:
                raise ValueError('Encoded number(%d) >= 2^63' % number)
            if number <= 0xFFFF:
                raise ValueError(
                    '%d should not be encoded by 9 bytes encoding' % number)
            return number
        if number == 126:
            if pos + 2 > len(self._data):
                raise ValueError('Invalid number field')
            self._read_position += 2
            number = struct.unpack('!H', self._data[pos:pos + 2])[0]
            if number <= 125:
                raise ValueError(
                    '%d should not be encoded by 3 bytes encoding' % number)
        return number

    def _read_size_and_contents(self):
        """Reads data that consists of followings:
            - the size of the contents encoded the same way as payload length
              of the WebSocket Protocol with 1 bit padding at the head.
            - the contents.

        Returns:
            The contents.

        Raises:
            PhysicalConnectionError: when the size is malformed or the
                payload holds fewer bytes than the size announces.
        """
        try:
            size = self._read_number()
        except ValueError as e:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                                          str(e))
        pos = self._read_position
        if pos + size > len(self._data):
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Cannot read %d bytes data' % size)

        self._read_position += size
        return self._data[pos:pos + size]

    def _read_add_channel_request(self, first_byte, control_block):
        """Fills control_block from an AddChannelRequest.

        Sets channel_id, encoding and encoded_handshake on control_block.

        Raises:
            PhysicalConnectionError: when the block is malformed.
        """
        reserved = (first_byte >> 2) & 0x7
        if reserved != 0:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Reserved bits must be unset')

        # Invalid encoding will be handled by MuxHandler.
        encoding = first_byte & 0x3
        try:
            control_block.channel_id = self.read_channel_id()
        except ValueError:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
        control_block.encoding = encoding
        encoded_handshake = self._read_size_and_contents()
        control_block.encoded_handshake = encoded_handshake
        return control_block

    def _read_add_channel_response(self, first_byte, control_block):
        """Fills control_block from an AddChannelResponse.

        Sets accepted, encoding, channel_id and encoded_handshake on
        control_block.

        Raises:
            PhysicalConnectionError: when the block is malformed.
        """
        reserved = (first_byte >> 2) & 0x3
        if reserved != 0:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Reserved bits must be unset')

        # NOTE(review): _create_add_channel_response() stores its
        # "rejected" flag in this same bit; confirm that "accepted" has the
        # intended polarity.
        control_block.accepted = (first_byte >> 4) & 1
        control_block.encoding = first_byte & 0x3
        try:
            control_block.channel_id = self.read_channel_id()
        except ValueError:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
        control_block.encoded_handshake = self._read_size_and_contents()
        return control_block

    def _read_flow_control(self, first_byte, control_block):
        """Fills control_block from a FlowControl block.

        Sets channel_id and send_quota on control_block.

        Raises:
            PhysicalConnectionError: when the block is malformed.
        """
        reserved = first_byte & 0x1f
        if reserved != 0:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Reserved bits must be unset')

        try:
            control_block.channel_id = self.read_channel_id()
            control_block.send_quota = self._read_number()
        except ValueError as e:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                                          str(e))

        return control_block

    def _read_drop_channel(self, first_byte, control_block):
        """Fills control_block from a DropChannel block.

        Sets channel_id, drop_code and drop_message on control_block. An
        empty reason yields drop_code None and an empty drop_message.

        Raises:
            PhysicalConnectionError: when the block is malformed.
        """
        reserved = first_byte & 0x1f
        if reserved != 0:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Reserved bits must be unset')

        try:
            control_block.channel_id = self.read_channel_id()
        except ValueError:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
        reason = self._read_size_and_contents()
        if len(reason) == 0:
            control_block.drop_code = None
            control_block.drop_message = ''
        elif len(reason) >= 2:
            control_block.drop_code = struct.unpack('!H', reason[:2])[0]
            control_block.drop_message = reason[2:]
        else:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Received DropChannel that conains only 1-byte reason')
        return control_block

    def _read_new_channel_slot(self, first_byte, control_block):
        """Fills control_block from a NewChannelSlot block.

        Sets fallback, slots and send_quota on control_block.

        Raises:
            PhysicalConnectionError: when the block is malformed.
        """
        reserved = first_byte & 0x1e
        if reserved != 0:
            raise PhysicalConnectionError(
                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                'Reserved bits must be unset')
        control_block.fallback = first_byte & 1
        try:
            control_block.slots = self._read_number()
            control_block.send_quota = self._read_number()
        except ValueError as e:
            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
                                          str(e))
        return control_block

    def read_control_blocks(self):
        """Reads control block(s).

        This is a generator; it ends when no control blocks are left.

        Yields:
            One _ControlBlock per control block found in the payload.

        Raises:
            PhysicalConnectionError: when the payload contains invalid control
                block(s).
        """
        while self._read_position < len(self._data):
            first_byte = self._peek_byte()
            self._read_position += 1
            opcode = (first_byte >> 5) & 0x7
            control_block = _ControlBlock(opcode=opcode)
            if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST:
                yield self._read_add_channel_request(first_byte, control_block)
            elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE:
                yield self._read_add_channel_response(
                    first_byte, control_block)
            elif opcode == _MUX_OPCODE_FLOW_CONTROL:
                yield self._read_flow_control(first_byte, control_block)
            elif opcode == _MUX_OPCODE_DROP_CHANNEL:
                yield self._read_drop_channel(first_byte, control_block)
            elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
                yield self._read_new_channel_slot(first_byte, control_block)
            else:
                raise PhysicalConnectionError(
                    _DROP_CODE_UNKNOWN_MUX_OPCODE,
                    'Invalid opcode %d' % opcode)

        assert self._read_position == len(self._data)
        # Falling off the end finishes the generator. An explicit
        # "raise StopIteration" must not be used here: under PEP 479 it is
        # converted into RuntimeError.

    def remaining_data(self):
        """Returns remaining data."""
        return self._data[self._read_position:]
class _LogicalRequest(object):
    """Mimics mod_python request for one logical channel."""

    def __init__(self, channel_id, command, path, protocol, headers,
                 connection):
        """Constructs an instance.

        Args:
            channel_id: the channel id of the logical channel.
            command: HTTP request command; exposed as method.
            path: HTTP request path; exposed as uri.
            protocol: protocol of the request; exposed as protocol.
            headers: HTTP headers; exposed as headers_in.
            connection: _LogicalConnection instance.
        """
        # Neither side has terminated the connection at construction time.
        self.client_terminated = False
        self.server_terminated = False

        self.connection = connection
        self.headers_in = headers
        self.protocol = protocol
        self.uri = path
        self.method = command
        self.channel_id = channel_id

    def is_https(self):
        """Mimics request.is_https(). Always returns False because this
        method is used only by old protocols (hixie and hybi00).
        """
        return False
class _LogicalConnection(object):
    """Mimics mod_python mp_conn.

    Outgoing data is handed to the mux handler, and the writing thread is
    suspended until the handler reports completion. Incoming data is
    appended by the mux handler and consumed with read().
    """

    # For details, see the comment of set_read_state().
    STATE_ACTIVE = 1
    STATE_GRACEFULLY_CLOSED = 2
    STATE_TERMINATED = 3

    def __init__(self, mux_handler, channel_id):
        """Constructs an instance.

        Args:
            mux_handler: _MuxHandler instance.
            channel_id: channel id of this connection.
        """
        self._mux_handler = mux_handler
        self._channel_id = channel_id
        self._incoming_data = ''

        # - Protects _waiting_write_completion
        # - Signals the thread waiting for completion of write by mux handler
        self._write_condition = threading.Condition()
        self._waiting_write_completion = False

        # - Protects _incoming_data and _read_state
        # - Signals the thread waiting in read()
        self._read_condition = threading.Condition()
        self._read_state = self.STATE_ACTIVE

    def get_local_addr(self):
        """Getter to mimic mp_conn.local_addr."""
        return self._mux_handler.physical_connection.get_local_addr()
    local_addr = property(get_local_addr)

    def get_remote_addr(self):
        """Getter to mimic mp_conn.remote_addr."""
        return self._mux_handler.physical_connection.get_remote_addr()
    remote_addr = property(get_remote_addr)

    def get_memorized_lines(self):
        """Gets memorized lines. Not supported."""
        raise MuxUnexpectedException('_LogicalConnection does not support '
                                     'get_memorized_lines')

    def write(self, data):
        """Writes data. mux_handler sends data asynchronously. The caller will
        be suspended until write done.

        Args:
            data: data to be written.

        Raises:
            MuxUnexpectedException: when called before finishing the previous
                write.
        """
        self._write_condition.acquire()
        try:
            if self._waiting_write_completion:
                raise MuxUnexpectedException(
                    'Logical connection %d is already waiting the completion '
                    'of write' % self._channel_id)

            self._waiting_write_completion = True
            handed_over = False
            try:
                self._mux_handler.send_data(self._channel_id, data)
                handed_over = True
            finally:
                if not handed_over:
                    # Assumes nothing was queued when send_data() raises, so
                    # no completion will ever be reported. Without this reset
                    # every later write() would fail as "already waiting".
                    self._waiting_write_completion = False
            # Both on_write_data_done() and on_writer_done() clear the flag
            # before notifying, so waiting on the flag also guards against
            # spurious wakeups.
            while self._waiting_write_completion:
                self._write_condition.wait()
            # TODO(tyoshino): Raise an exception if woke up by on_writer_done.
        finally:
            self._write_condition.release()

    def write_control_data(self, data):
        """Writes data via the control channel. Don't wait finishing write
        because this method can be called by mux dispatcher.

        Args:
            data: data to be written.
        """
        self._mux_handler.send_control_data(data)

    def on_write_data_done(self):
        """Called when sending data is completed.

        Raises:
            MuxUnexpectedException: when no write is waiting for completion.
        """
        self._write_condition.acquire()
        try:
            if not self._waiting_write_completion:
                raise MuxUnexpectedException(
                    'Invalid call of on_write_data_done for logical '
                    'connection %d' % self._channel_id)
            self._waiting_write_completion = False
            self._write_condition.notify()
        finally:
            self._write_condition.release()

    def on_writer_done(self):
        """Called by the mux handler when the writer thread has finished."""
        self._write_condition.acquire()
        try:
            self._waiting_write_completion = False
            self._write_condition.notify()
        finally:
            self._write_condition.release()

    def append_frame_data(self, frame_data):
        """Appends incoming frame data. Called when mux_handler dispatches
        frame data to the corresponding application.

        Args:
            frame_data: incoming frame data.
        """
        self._read_condition.acquire()
        try:
            self._incoming_data += frame_data
            self._read_condition.notify()
        finally:
            self._read_condition.release()

    def read(self, length):
        """Reads data. Blocks until enough data has arrived via physical
        connection.

        Args:
            length: length of data to be read.

        Returns:
            The first length bytes of the buffered incoming data.

        Raises:
            LogicalConnectionClosedException: when closing handshake for this
                logical channel has been received.
            ConnectionTerminatedException: when the physical connection has
                closed, or an error is caused on the reader thread.
        """
        self._read_condition.acquire()
        try:
            # The wait happens inside the try block so that the condition
            # is released even when wait() itself is interrupted.
            while (self._read_state == self.STATE_ACTIVE and
                   len(self._incoming_data) < length):
                self._read_condition.wait()

            if self._read_state == self.STATE_GRACEFULLY_CLOSED:
                raise LogicalConnectionClosedException(
                    'Logical channel %d has closed.' % self._channel_id)
            elif self._read_state == self.STATE_TERMINATED:
                raise ConnectionTerminatedException(
                    'Receiving %d byte failed. Logical channel (%d) closed' %
                    (length, self._channel_id))

            value = self._incoming_data[:length]
            self._incoming_data = self._incoming_data[length:]
        finally:
            self._read_condition.release()

        return value

    def set_read_state(self, new_state):
        """Sets the state of this connection. Called when an event for this
        connection has occurred.

        Args:
            new_state: state to be set. new_state must be one of followings:
            - STATE_GRACEFULLY_CLOSED: when closing handshake for this
                connection has been received.
            - STATE_TERMINATED: when the physical connection has closed or
                DropChannel of this connection has received.
        """
        self._read_condition.acquire()
        try:
            self._read_state = new_state
            self._read_condition.notify()
        finally:
            self._read_condition.release()
class _InnerMessage(object):
    """Holds the result of _InnerMessageBuilder.build().

    Attributes:
        opcode: opcode of the assembled message.
        payload: payload of the assembled message.
    """

    def __init__(self, opcode, payload):
        """Constructs an instance from an opcode and a payload."""
        self.payload = payload
        self.opcode = opcode
class _InnerMessageBuilder(object):
    """Keeps the fragmentation context of inner messages and assembles a
    message from fragmented inner frame(s).

    The builder is a small state machine: _frame_handler always refers to
    the method that must handle the next frame. A control message may be
    interleaved while a fragmented data message is pending, which is why
    separate opcode/fragment state is kept for control and data messages.
    """

    def __init__(self):
        # The first frame is expected to start a new message.
        self._frame_handler = self._handle_first
        # State of the data message being assembled.
        self._message_opcode = None
        self._pending_message_fragments = []
        # State of the control message being assembled.
        self._control_opcode = None
        self._pending_control_fragments = []

    def _handle_first(self, frame):
        """Handles a frame that must start a new message."""
        opcode = frame.opcode
        if opcode == common.OPCODE_CONTINUATION:
            raise InvalidFrameException('Sending invalid continuation opcode')

        if common.is_control_opcode(opcode):
            return self._process_first_fragmented_control(frame)
        return self._process_first_fragmented_message(frame)

    def _process_first_fragmented_control(self, frame):
        """Starts a control message; finishes it at once when fin is set."""
        self._control_opcode = frame.opcode
        self._pending_control_fragments.append(frame.payload)
        if frame.fin:
            return self._reassemble_fragmented_control()
        self._frame_handler = self._handle_fragmented_control
        return None

    def _process_first_fragmented_message(self, frame):
        """Starts a data message; finishes it at once when fin is set."""
        self._message_opcode = frame.opcode
        self._pending_message_fragments.append(frame.payload)
        if frame.fin:
            return self._reassemble_fragmented_message()
        self._frame_handler = self._handle_fragmented_message
        return None

    def _handle_fragmented_control(self, frame):
        """Handles a frame that must continue a pending control message."""
        if frame.opcode != common.OPCODE_CONTINUATION:
            raise InvalidFrameException(
                'Sending invalid opcode %d while sending fragmented control '
                'message' % frame.opcode)
        self._pending_control_fragments.append(frame.payload)
        if frame.fin:
            return self._reassemble_fragmented_control()
        return None

    def _reassemble_fragmented_control(self):
        """Joins the pending control fragments and resets control state."""
        assembled = _InnerMessage(
            self._control_opcode,
            ''.join(self._pending_control_fragments))
        self._control_opcode = None
        self._pending_control_fragments = []
        # Go back to the interrupted data message when there is one.
        if self._message_opcode is None:
            self._frame_handler = self._handle_first
        else:
            self._frame_handler = self._handle_fragmented_message
        return assembled

    def _handle_fragmented_message(self, frame):
        """Handles a frame received while a data message is pending."""
        # Sender can interleave a control message while sending fragmented
        # messages.
        if common.is_control_opcode(frame.opcode):
            if self._control_opcode is not None:
                raise MuxUnexpectedException(
                    'Should not reach here(Bug in builder)')
            return self._process_first_fragmented_control(frame)

        if frame.opcode != common.OPCODE_CONTINUATION:
            raise InvalidFrameException(
                'Sending invalid opcode %d while sending fragmented message' %
                frame.opcode)
        self._pending_message_fragments.append(frame.payload)
        if frame.fin:
            return self._reassemble_fragmented_message()
        return None

    def _reassemble_fragmented_message(self):
        """Joins the pending data fragments and resets data state."""
        assembled = _InnerMessage(
            self._message_opcode,
            ''.join(self._pending_message_fragments))
        self._message_opcode = None
        self._pending_message_fragments = []
        self._frame_handler = self._handle_first
        return assembled

    def build(self, frame):
        """Build an inner message. Returns an _InnerMessage instance when
        the given frame is the last fragmented frame. Returns None otherwise.

        Args:
            frame: an inner frame.

        Raises:
            InvalidFrameException: when received invalid opcode. (e.g.
                receiving non continuation data opcode but the fin flag of
                the previous inner frame was not set.)
        """
        handler = self._frame_handler
        return handler(frame)
class _LogicalStream(Stream):
    """Mimics the Stream class. This class interprets multiplexed WebSocket
    frames.

    Outgoing messages are written as inner frames through the logical
    connection of the request, limited by a send quota that the peer
    replenishes. Incoming frames replenish the receive quota, which is
    reported back over the control channel.
    """

    def __init__(self, request, stream_options, send_quota, receive_quota):
        """Constructs an instance.

        Args:
            request: _LogicalRequest instance.
            stream_options: StreamOptions instance.
            send_quota: Initial send quota.
            receive_quota: Initial receive quota.
        """
        # Physical stream is responsible for masking.
        stream_options.unmask_receive = False
        Stream.__init__(self, request, stream_options)

        self._send_closed = False
        self._send_quota = send_quota
        # - Protects _send_closed and _send_quota
        # - Signals the thread waiting for send quota replenished
        self._send_condition = threading.Condition()

        # The opcode of the first frame in messages.
        self._message_opcode = common.OPCODE_TEXT
        # True when the last message was fragmented.
        self._last_message_was_fragmented = False

        self._receive_quota = receive_quota
        self._write_inner_frame_semaphore = threading.Semaphore()

        self._inner_message_builder = _InnerMessageBuilder()

    def _create_inner_frame(self, opcode, payload, end=True):
        """Builds the bytes of one inner frame.

        The outgoing frame filters are applied first. The result is a byte
        holding the fin/rsv flags and the opcode, followed by the payload.

        Raises:
            MuxUnexpectedException: when a frame filter changed the length
                of the payload.
        """
        frame = Frame(fin=end, opcode=opcode, payload=payload)
        for frame_filter in self._options.outgoing_frame_filters:
            frame_filter.filter(frame)

        if len(payload) != len(frame.payload):
            raise MuxUnexpectedException(
                'Mux extension must not be used after extensions which change '
                ' frame boundary')

        first_byte = ((frame.fin << 7) | (frame.rsv1 << 6) |
                      (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode)
        return chr(first_byte) + frame.payload

    def _write_inner_frame(self, opcode, payload, end=True):
        """Writes the payload as one or more inner frames.

        One octet of send quota is consumed for a frame that starts a
        message, and one octet per payload byte. When the quota is too
        small the payload is split into fragments, and the caller blocks
        until the quota is replenished.

        Args:
            opcode: opcode of the (first) inner frame.
            payload: payload to send.
            end: whether the last inner frame carries the fin flag.

        Raises:
            BadOperationException: when sending has been stopped, or when
                a ValueError occurs while writing.
        """
        payload_length = len(payload)
        write_position = 0

        # An inner frame will be fragmented if there is no enough send
        # quota. This semaphore ensures that fragmented inner frames are
        # sent in order on the logical channel.
        # Note that frames that come from other logical channels or
        # multiplexing control blocks can be inserted between fragmented
        # inner frames on the physical channel.
        # It is acquired before entering the try block so that the finally
        # clause never releases a semaphore that was not acquired.
        self._write_inner_frame_semaphore.acquire()
        try:
            # Consume an octet quota when this is the first fragmented frame.
            if opcode != common.OPCODE_CONTINUATION:
                self._send_condition.acquire()
                try:
                    while (not self._send_closed) and self._send_quota == 0:
                        self._send_condition.wait()

                    if self._send_closed:
                        raise BadOperationException(
                            'Logical connection %d is closed' %
                            self._request.channel_id)
                    self._send_quota -= 1
                finally:
                    self._send_condition.release()

            # NOTE(review): for an empty payload this loop body never runs,
            # so no inner frame is written although the octet above has
            # been consumed. Confirm whether empty messages and empty
            # control frames are meant to be sent.
            while write_position < payload_length:
                self._send_condition.acquire()
                try:
                    while (not self._send_closed) and self._send_quota == 0:
                        self._logger.debug(
                            'No quota. Waiting FlowControl message for %d.' %
                            self._request.channel_id)
                        self._send_condition.wait()

                    if self._send_closed:
                        # The channel id lives on the request object
                        # (self._request.channel_id), as everywhere else in
                        # this class.
                        raise BadOperationException(
                            'Logical connection %d is closed' %
                            self._request.channel_id)

                    remaining = payload_length - write_position
                    write_length = min(self._send_quota, remaining)
                    inner_frame_end = (
                        end and
                        (write_position + write_length == payload_length))

                    inner_frame = self._create_inner_frame(
                        opcode,
                        payload[write_position:write_position + write_length],
                        inner_frame_end)
                    self._send_quota -= write_length
                    self._logger.debug('Consumed quota=%d, remaining=%d' %
                                       (write_length, self._send_quota))
                finally:
                    self._send_condition.release()

                # Writing data will block the worker so we need to release
                # _send_condition before writing.
                self._logger.debug('Sending inner frame: %r' % inner_frame)
                self._request.connection.write(inner_frame)
                write_position += write_length

                opcode = common.OPCODE_CONTINUATION

        except ValueError as e:
            raise BadOperationException(e)
        finally:
            self._write_inner_frame_semaphore.release()

    def replenish_send_quota(self, send_quota):
        """Replenish send quota.

        A waiting sender is notified whether or not the replenishment
        succeeds.

        Args:
            send_quota: amount to add to the send quota.

        Raises:
            LogicalChannelError: when the quota would exceed 2^63 - 1. The
                send quota is reset to 0 in that case.
        """
        self._send_condition.acquire()
        try:
            if self._send_quota + send_quota > 0x7FFFFFFFFFFFFFFF:
                self._send_quota = 0
                raise LogicalChannelError(
                    self._request.channel_id, _DROP_CODE_SEND_QUOTA_OVERFLOW)
            self._send_quota += send_quota
            self._logger.debug('Replenished send quota for channel id %d: %d' %
                               (self._request.channel_id, self._send_quota))
        finally:
            self._send_condition.notify()
            self._send_condition.release()

    def consume_receive_quota(self, amount):
        """Consumes receive quota. Returns False on failure."""
        if self._receive_quota < amount:
            self._logger.debug('Violate quota on channel id %d: %d < %d' %
                               (self._request.channel_id,
                                self._receive_quota, amount))
            return False
        self._receive_quota -= amount
        return True

    def send_message(self, message, end=True, binary=False):
        """Override Stream.send_message.

        Args:
            message: message to send; text messages are encoded as UTF-8.
            end: False when more fragments of the same message follow.
            binary: whether the message is sent as a binary message.

        Raises:
            BadOperationException: when called after the closing handshake
                has been sent, when a unicode object is given for a binary
                message, or when the message type differs from the type of
                the preceding fragments.
        """
        if self._request.server_terminated:
            raise BadOperationException(
                'Requested send_message after sending out a closing handshake')

        if binary and isinstance(message, unicode):
            raise BadOperationException(
                'Message for binary frame must be instance of str')

        if binary:
            opcode = common.OPCODE_BINARY
        else:
            opcode = common.OPCODE_TEXT
            message = message.encode('utf-8')

        for message_filter in self._options.outgoing_message_filters:
            message = message_filter.filter(message, end, binary)

        if self._last_message_was_fragmented:
            if opcode != self._message_opcode:
                raise BadOperationException('Message types are different in '
                                            'frames for the same message')
            opcode = common.OPCODE_CONTINUATION
        else:
            self._message_opcode = opcode

        self._write_inner_frame(opcode, message, end)
        self._last_message_was_fragmented = not end

    def _receive_frame(self):
        """Overrides Stream._receive_frame.

        In addition to call Stream._receive_frame, this method adds the amount
        of payload to receiving quota and sends FlowControl to the client.
        We need to do it here because Stream.receive_message() handles
        control frames internally.
        """
        opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self)
        amount = len(payload)
        # Replenish extra one octet when receiving the first fragmented frame.
        if opcode != common.OPCODE_CONTINUATION:
            amount += 1
        self._receive_quota += amount
        frame_data = _create_flow_control(self._request.channel_id,
                                          amount)
        self._logger.debug('Sending flow control for %d, replenished=%d' %
                           (self._request.channel_id, amount))
        self._request.connection.write_control_data(frame_data)
        return opcode, payload, fin, rsv1, rsv2, rsv3

    def _get_message_from_frame(self, frame):
        """Overrides Stream._get_message_from_frame.

        Returns:
            The payload of the assembled message, or None when the frame
            does not complete a message.

        Raises:
            LogicalChannelError: when the frame violates the fragmentation
                rules.
        """
        try:
            inner_message = self._inner_message_builder.build(frame)
        except InvalidFrameException:
            raise LogicalChannelError(
                self._request.channel_id, _DROP_CODE_BAD_FRAGMENTATION)

        if inner_message is None:
            return None
        self._original_opcode = inner_message.opcode
        return inner_message.payload

    def receive_message(self):
        """Overrides Stream.receive_message."""
        # Just call Stream.receive_message(), but catch
        # LogicalConnectionClosedException, which is raised when the logical
        # connection has closed gracefully.
        try:
            return Stream.receive_message(self)
        except LogicalConnectionClosedException as e:
            self._logger.debug('%s', e)
            return None

    def _send_closing_handshake(self, code, reason):
        """Overrides Stream._send_closing_handshake."""
        body = create_closing_handshake_body(code, reason)
        self._logger.debug('Sending closing handshake for %d: (%r, %r)' %
                           (self._request.channel_id, code, reason))
        self._write_inner_frame(common.OPCODE_CLOSE, body, end=True)

        self._request.server_terminated = True

    def send_ping(self, body=''):
        """Overrides Stream.send_ping"""
        self._logger.debug('Sending ping on logical channel %d: %r' %
                           (self._request.channel_id, body))
        self._write_inner_frame(common.OPCODE_PING, body, end=True)

        self._ping_queue.append(body)

    def _send_pong(self, body):
        """Overrides Stream._send_pong"""
        self._logger.debug('Sending pong on logical channel %d: %r' %
                           (self._request.channel_id, body))
        self._write_inner_frame(common.OPCODE_PONG, body, end=True)

    def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''):
        """Overrides Stream.close_connection."""
        # TODO(bashi): Implement
        self._logger.debug('Closing logical connection %d' %
                           self._request.channel_id)
        self._request.server_terminated = True

    def stop_sending(self):
        """Stops accepting new send operation (_write_inner_frame)."""
        self._send_condition.acquire()
        try:
            self._send_closed = True
            self._send_condition.notify()
        finally:
            self._send_condition.release()
class _OutgoingData(object):
    """Pairs data to be sent via the physical connection with the channel
    the data originates from.

    Attributes:
        channel_id: id of the originating channel.
        data: the data to send.
    """

    def __init__(self, channel_id, data):
        """Constructs an instance from a channel id and data."""
        self.data = data
        self.channel_id = channel_id
class _PhysicalConnectionWriter(threading.Thread):
"""A thread that is responsible for writing data to physical connection.
TODO(bashi): Make sure there is no thread-safety problem when the reader
thread reads data from the same socket at a time.
"""
def __init__(self, mux_handler):
"""Constructs an instance.
Args:
mux_handler: _MuxHandler instance.
"""
threading.Thread.__init__(self)
self._logger = util.get_class_logger(self)
self._mux_handler = mux_handler
self.setDaemon(True)
# When set, make this thread stop accepting new data, flush pending
# data and exit.
self._stop_requested = False
# The close code of the physical connection.
self._close_code = common.STATUS_NORMAL_CLOSURE
# Deque for passing write data. It's protected by _deque_condition
# until _stop_requested is set.
self._deque = collections.deque()
# - Protects _deque, _stop_requested and _close_code
# - Signals threads waiting for them to be available
self._deque_condition = threading.Condition()
def put_outgoing_data(self, data):
"""Puts outgoing data.
Args:
data: _OutgoingData instance.
Raises:
BadOperationException: when the thread has been requested to
terminate.
"""
try:
self._deque_condition.acquire()
if self._stop_requested:
raise BadOperationException('Cannot write data anymore')
self._deque.append(data)
self._deque_condition.notify()
finally:
self._deque_condition.release()
def _write_data(self, outgoing_data):
message = (_encode_channel_id(outgoing_data.channel_id) +
outgoing_data.data)
try:
self._mux_handler.physical_stream.send_message(
message=message, end=True, binary=True)
except Exception, e:
util.prepend_message_to_exception(
'Failed to send message to %r: ' %
(self._mux_handler.physical_connection.remote_addr,), e)
raise
# TODO(bashi): It would be better to block the thread that sends
# control data as well.
if outgoing_data.channel_id != _CONTROL_CHANNEL_ID:
self._mux_handler.notify_write_data_done(outgoing_data.channel_id)
def run(self):
try:
self._deque_condition.acquire()
while not self._stop_requested:
if len(self._deque) == 0:
self._deque_condition.wait()
continue
outgoing_data = self._deque.popleft()
self._deque_condition.release()
self._write_data(outgoing_data)
self._deque_condition.acquire()
# Flush deque.
#
# At this point, self._deque_condition is always acquired.
try:
while len(self._deque) > 0:
outgoing_data = self._deque.popleft()
self._write_data(outgoing_data)
finally:
self._deque_condition.release()
# Close physical connection.
try:
# Don't wait the response here. The response will be read
# by the reader thread.
self._mux_handler.physical_stream.close_connection(
self._close_code, wait_response=False)
except Exception, e:
util.prepend_message_to_exception(
'Failed to close the physical connection: %r' % e)
raise
finally:
self._mux_handler.notify_writer_done()
def stop(self, close_code=common.STATUS_NORMAL_CLOSURE):
"""Stops the writer thread."""
self._deque_condition.acquire()
self._stop_requested = True
self._close_code = close_code
self._deque_condition.notify()
self._deque_condition.release()
class _PhysicalConnectionReader(threading.Thread):
"""A thread that is responsible for reading data from physical connection.
"""
def __init__(self, mux_handler):
"""Constructs an instance.
Args:
mux_handler: _MuxHandler instance.
"""
threading.Thread.__init__(self)
self._logger = util.get_class_logger(self)
self._mux_handler = mux_handler
self.setDaemon(True)
def run(self):
while True:
try:
physical_stream = self._mux_handler.physical_stream
message = physical_stream.receive_message()
if message is None:
break
# Below happens only when a data message is received.
opcode = physical_stream.get_last_received_opcode()
if opcode != common.OPCODE_BINARY:
self._mux_handler.fail_physical_connection(
_DROP_CODE_INVALID_ENCAPSULATING_MESSAGE,
'Received a text message on physical connection')
break
except ConnectionTerminatedException, e:
self._logger.debug('%s', e)
break
try:
self._mux_handler.dispatch_message(message)
except PhysicalConnectionError, e:
self._mux_handler.fail_physical_connection(
e.drop_code, e.message)
break
except LogicalChannelError, e:
self._mux_handler.fail_logical_channel(
e.channel_id, e.drop_code, e.message)
except Exception, e:
self._logger.debug(traceback.format_exc())
break
self._mux_handler.notify_reader_done()
class _Worker(threading.Thread):
"""A thread that is responsible for running the corresponding application
handler.
"""
def __init__(self, mux_handler, request):
"""Constructs an instance.
Args:
mux_handler: _MuxHandler instance.
request: _LogicalRequest instance.
"""
threading.Thread.__init__(self)
self._logger = util.get_class_logger(self)
self._mux_handler = mux_handler
self._request = request
self.setDaemon(True)
def run(self):
self._logger.debug('Logical channel worker started. (id=%d)' %
self._request.channel_id)
try:
# Non-critical exceptions will be handled by dispatcher.
self._mux_handler.dispatcher.transfer_data(self._request)
except LogicalChannelError, e:
self._mux_handler.fail_logical_channel(
e.channel_id, e.drop_code, e.message)
finally:
self._mux_handler.notify_worker_done(self._request.channel_id)
class _MuxHandshaker(hybi.Handshaker):
    """Opening handshake processor for a logical channel of the
    multiplexing extension.
    """

    _DUMMY_WEBSOCKET_KEY = 'dGhlIHNhbXBsZSBub25jZQ=='

    def __init__(self, request, dispatcher, send_quota, receive_quota):
        """Constructs an instance.

        Args:
            request: _LogicalRequest instance.
            dispatcher: Dispatcher instance (dispatch.Dispatcher).
            send_quota: Initial send quota.
            receive_quota: Initial receive quota.
        """

        hybi.Handshaker.__init__(self, request, dispatcher)
        self._send_quota = send_quota
        self._receive_quota = receive_quota

        # Append headers which should not be included in handshake field of
        # AddChannelRequest.
        # TODO(bashi): Make sure whether we should raise exception when
        #     these headers are included already.
        injected_headers = (
            (common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE),
            (common.SEC_WEBSOCKET_VERSION_HEADER,
             str(common.VERSION_HYBI_LATEST)),
            (common.SEC_WEBSOCKET_KEY_HEADER, self._DUMMY_WEBSOCKET_KEY),
        )
        for header_name, header_value in injected_headers:
            request.headers_in[header_name] = header_value

    def _create_stream(self, stream_options):
        """Override hybi.Handshaker._create_stream.

        Returns a _LogicalStream built from the request, the given stream
        options and the quotas passed to the constructor.
        """

        channel_id = self._request.channel_id
        self._logger.debug('Creating logical stream for %d' % channel_id)
        return _LogicalStream(
            self._request, stream_options, self._send_quota,
            self._receive_quota)

    def _create_handshake_response(self, accept):
        """Override hybi._create_handshake_response.

        Args:
            accept: not used; Upgrade and Sec-WebSocket-Accept are excluded
                from the response built here.

        Returns:
            The response text: status line, Connection header, optional
            subprotocol and extensions headers, and a terminating blank
            line.
        """

        header_line = '%s: %s\r\n'
        lines = ['HTTP/1.1 101 Switching Protocols\r\n']

        # Upgrade and Sec-WebSocket-Accept should be excluded.
        lines.append(header_line % (
            common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))

        protocol = self._request.ws_protocol
        if protocol is not None:
            lines.append(header_line % (
                common.SEC_WEBSOCKET_PROTOCOL_HEADER, protocol))

        extensions = self._request.ws_extensions
        if extensions is not None and len(extensions) != 0:
            lines.append(header_line % (
                common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                common.format_extensions(extensions)))

        lines.append('\r\n')
        return ''.join(lines)

    def _send_handshake(self, accept):
        """Override hybi.Handshaker._send_handshake.

        Wraps the handshake response in an AddChannelResponse and writes it
        as control data on the request's connection. Nothing is sent for
        the default channel.
        """

        channel_id = self._request.channel_id

        # Don't send handshake response for the default channel
        if channel_id == _DEFAULT_CHANNEL_ID:
            return

        response_text = self._create_handshake_response(accept)
        frame_data = _create_add_channel_response(channel_id, response_text)
        self._logger.debug('Sending handshake response for %d: %r' %
                           (channel_id, frame_data))
        self._request.connection.write_control_data(frame_data)
class _LogicalChannelData(object):
    """Bookkeeping record for one logical channel: its request, its worker
    and the drop code/message to use for it.
    """

    def __init__(self, request, worker):
        """Constructs an instance.

        Args:
            request: request of the logical channel.
            worker: worker associated with the logical channel.
        """

        self.request, self.worker = request, worker
        # Start from a normal closure with an empty message.
        self.drop_code, self.drop_message = _DROP_CODE_NORMAL_CLOSURE, ''
class _HandshakeDeltaBase(object):
    """A class that holds information for delta-encoded handshake."""

    def __init__(self, headers):
        """Constructs an instance.

        Args:
            headers: mapping of header names to values used as the delta
                base. It is kept by reference; create_headers() works on a
                shallow copy of it.
        """

        self._headers = headers

    def create_headers(self, delta=None):
        """Creates request headers for an AddChannelRequest that has
        delta-encoded handshake.

        Args:
            delta: headers should be overridden.

        Returns:
            A shallow copy of the base headers with delta applied. The base
            headers themselves are not modified.
        """

        headers = copy.copy(self._headers)
        if delta:
            for key, value in delta.items():
                # The spec requires that a header with an empty value is
                # removed from the delta base.
                #
                # Use the "in" operator rather than has_key(), which is
                # deprecated and no longer exists on Python 3 dictionaries.
                if len(value) == 0 and key in headers:
                    del headers[key]
                else:
                    headers[key] = value
        return headers
class _MuxHandler(object):
"""Multiplexing handler. When a handler starts, it launches three
threads; the reader thread, the writer thread, and a worker thread.
The reader thread reads data from the physical stream, i.e., the
ws_stream object of the underlying websocket connection. The reader
thread interprets multiplexed frames and dispatches them to logical
channels. Methods of this class are mostly called by the reader thread.
The writer thread sends multiplexed frames which are created by
logical channels via the physical connection.
The worker thread launched at the starting point handles the
"Implicitly Opened Connection". If multiplexing handler receives
an AddChannelRequest and accepts it, the handler will launch a new worker
thread and dispatch the request to it.
"""
def __init__(self, request, dispatcher):
"""Constructs an instance.
Args:
request: mod_python request of the physical connection.
dispatcher: Dispatcher instance (dispatch.Dispatcher).
"""
self.original_request = request
self.dispatcher = dispatcher
self.physical_connection = request.connection
self.physical_stream = request.ws_stream
self._logger = util.get_class_logger(self)
self._logical_channels = {}
self._logical_channels_condition = threading.Condition()
# Holds client's initial quota
self._channel_slots = collections.deque()
self._handshake_base = None
self._worker_done_notify_received = False
self._reader = None
self._writer = None
def start(self):
"""Starts the handler.
Raises:
MuxUnexpectedException: when the handler already started, or when
opening handshake of the default channel fails.
"""
if self._reader or self._writer:
raise MuxUnexpectedException('MuxHandler already started')
self._reader = _PhysicalConnectionReader(self)
self._writer = _PhysicalConnectionWriter(self)
self._reader.start()
self._writer.start()
# Create "Implicitly Opened Connection".
logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID)
headers = copy.copy(self.original_request.headers_in)
# Add extensions for logical channel.
headers[common.SEC_WEBSOCKET_EXTENSIONS_HEADER] = (
common.format_extensions(
self.original_request.mux_processor.extensions()))
self._handshake_base = _HandshakeDeltaBase(headers)
logical_request = _LogicalRequest(
_DEFAULT_CHANNEL_ID,
self.original_request.method,
self.original_request.uri,
self.original_request.protocol,
self._handshake_base.create_headers(),
logical_connection)
# Client's send quota for the implicitly opened connection is zero,
# but we will send FlowControl later so set the initial quota to
# _INITIAL_QUOTA_FOR_CLIENT.
self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT)
send_quota = self.original_request.mux_processor.quota()
if not self._do_handshake_for_logical_request(
logical_request, send_quota=send_quota):
raise MuxUnexpectedException(
'Failed handshake on the default channel id')
self._add_logical_channel(logical_request)
# Send FlowControl for the implicitly opened connection.
frame_data = _create_flow_control(_DEFAULT_CHANNEL_ID,
_INITIAL_QUOTA_FOR_CLIENT)
logical_request.connection.write_control_data(frame_data)
def add_channel_slots(self, slots, send_quota):
"""Adds channel slots.
Args:
slots: number of slots to be added.
send_quota: initial send quota for slots.
"""
self._channel_slots.extend([send_quota] * slots)
# Send NewChannelSlot to client.
frame_data = _create_new_channel_slot(slots, send_quota)
self.send_control_data(frame_data)
def wait_until_done(self, timeout=None):
"""Waits until all workers are done. Returns False when timeout has
occurred. Returns True on success.
Args:
timeout: timeout in sec.
"""
self._logical_channels_condition.acquire()
try:
while len(self._logical_channels) > 0:
self._logger.debug('Waiting workers(%d)...' %
len(self._logical_channels))
self._worker_done_notify_received = False
self._logical_channels_condition.wait(timeout)
if not self._worker_done_notify_received:
self._logger.debug('Waiting worker(s) timed out')
return False
finally:
self._logical_channels_condition.release()
# Flush pending outgoing data
self._writer.stop()
self._writer.join()
return True
def notify_write_data_done(self, channel_id):
"""Called by the writer thread when a write operation has done.
Args:
channel_id: objective channel id.
"""
try:
self._logical_channels_condition.acquire()
if channel_id in self._logical_channels:
channel_data = self._logical_channels[channel_id]
channel_data.request.connection.on_write_data_done()
else:
self._logger.debug('Seems that logical channel for %d has gone'
% channel_id)
finally:
self._logical_channels_condition.release()
def send_control_data(self, data):
"""Sends data via the control channel.
Args:
data: data to be sent.
"""
self._writer.put_outgoing_data(_OutgoingData(
channel_id=_CONTROL_CHANNEL_ID, data=data))
def send_data(self, channel_id, data):
"""Sends data via given logical channel. This method is called by
worker threads.
Args:
channel_id: objective channel id.
data: data to be sent.
"""
self._writer.put_outgoing_data(_OutgoingData(
channel_id=channel_id, data=data))
def _send_drop_channel(self, channel_id, code=None, message=''):
frame_data = _create_drop_channel(channel_id, code, message)
self._logger.debug(
'Sending drop channel for channel id %d' % channel_id)
self.send_control_data(frame_data)
def _send_error_add_channel_response(self, channel_id, status=None):
if status is None:
status = common.HTTP_STATUS_BAD_REQUEST
if status in _HTTP_BAD_RESPONSE_MESSAGES:
message = _HTTP_BAD_RESPONSE_MESSAGES[status]
else:
self._logger.debug('Response message for %d is not found' % status)
message = '???'
response = 'HTTP/1.1 %d %s\r\n\r\n' % (status, message)
frame_data = _create_add_channel_response(channel_id,
encoded_handshake=response,
encoding=0, rejected=True)
self.send_control_data(frame_data)
def _create_logical_request(self, block):
if block.channel_id == _CONTROL_CHANNEL_ID:
# TODO(bashi): Raise PhysicalConnectionError with code 2006
# instead of MuxUnexpectedException.
raise MuxUnexpectedException(
'Received the control channel id (0) as objective channel '
'id for AddChannel')
if block.encoding > _HANDSHAKE_ENCODING_DELTA:
raise PhysicalConnectionError(
_DROP_CODE_UNKNOWN_REQUEST_ENCODING)
method, path, version, headers = _parse_request_text(
block.encoded_handshake)
if block.encoding == _HANDSHAKE_ENCODING_DELTA:
headers = self._handshake_base.create_headers(headers)
connection = _LogicalConnection(self, block.channel_id)
request = _LogicalRequest(block.channel_id, method, path, version,
headers, connection)
return request
def _do_handshake_for_logical_request(self, request, send_quota=0):
try:
receive_quota = self._channel_slots.popleft()
except IndexError:
raise LogicalChannelError(
request.channel_id, _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION)
handshaker = _MuxHandshaker(request, self.dispatcher,
send_quota, receive_quota)
try:
handshaker.do_handshake()
except handshake.VersionException, e:
self._logger.info('%s', e)
self._send_error_add_channel_response(
request.channel_id, status=common.HTTP_STATUS_BAD_REQUEST)
return False
except handshake.HandshakeException, e:
# TODO(bashi): Should we _Fail the Logical Channel_ with 3001
# instead?
self._logger.info('%s', e)
self._send_error_add_channel_response(request.channel_id,
status=e.status)
return False
except handshake.AbortedByUserException, e:
self._logger.info('%s', e)
self._send_error_add_channel_response(request.channel_id)
return False
return True
def _add_logical_channel(self, logical_request):
try:
self._logical_channels_condition.acquire()
if logical_request.channel_id in self._logical_channels:
self._logger.debug('Channel id %d already exists' %
logical_request.channel_id)
raise PhysicalConnectionError(
_DROP_CODE_CHANNEL_ALREADY_EXISTS,
'Channel id %d already exists' %
logical_request.channel_id)
worker = _Worker(self, logical_request)
channel_data = _LogicalChannelData(logical_request, worker)
self._logical_channels[logical_request.channel_id] = channel_data
worker.start()
finally:
self._logical_channels_condition.release()
def _process_add_channel_request(self, block):
try:
logical_request = self._create_logical_request(block)
except ValueError, e:
self._logger.debug('Failed to create logical request: %r' % e)
self._send_error_add_channel_response(
block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST)
return
if self._do_handshake_for_logical_request(logical_request):
if block.encoding == _HANDSHAKE_ENCODING_IDENTITY:
# Update handshake base.
# TODO(bashi): Make sure this is the right place to update
# handshake base.
self._handshake_base = _HandshakeDeltaBase(
logical_request.headers_in)
self._add_logical_channel(logical_request)
else:
self._send_error_add_channel_response(
block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST)
def _process_flow_control(self, block):
try:
self._logical_channels_condition.acquire()
if not block.channel_id in self._logical_channels:
return
channel_data = self._logical_channels[block.channel_id]
channel_data.request.ws_stream.replenish_send_quota(
block.send_quota)
finally:
self._logical_channels_condition.release()
def _process_drop_channel(self, block):
self._logger.debug(
'DropChannel received for %d: code=%r, reason=%r' %
(block.channel_id, block.drop_code, block.drop_message))
try:
self._logical_channels_condition.acquire()
if not block.channel_id in self._logical_channels:
return
channel_data = self._logical_channels[block.channel_id]
channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED
# Close the logical channel
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
channel_data.request.ws_stream.stop_sending()
finally:
self._logical_channels_condition.release()
def _process_control_blocks(self, parser):
for control_block in parser.read_control_blocks():
opcode = control_block.opcode
self._logger.debug('control block received, opcode: %d' % opcode)
if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST:
self._process_add_channel_request(control_block)
elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE:
raise PhysicalConnectionError(
_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
'Received AddChannelResponse')
elif opcode == _MUX_OPCODE_FLOW_CONTROL:
self._process_flow_control(control_block)
elif opcode == _MUX_OPCODE_DROP_CHANNEL:
self._process_drop_channel(control_block)
elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
raise PhysicalConnectionError(
_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
'Received NewChannelSlot')
else:
raise MuxUnexpectedException(
'Unexpected opcode %r' % opcode)
def _process_logical_frame(self, channel_id, parser):
self._logger.debug('Received a frame. channel id=%d' % channel_id)
try:
self._logical_channels_condition.acquire()
if not channel_id in self._logical_channels:
# We must ignore the message for an inactive channel.
return
channel_data = self._logical_channels[channel_id]
fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame()
consuming_byte = len(payload)
if opcode != common.OPCODE_CONTINUATION:
consuming_byte += 1
if not channel_data.request.ws_stream.consume_receive_quota(
consuming_byte):
# The client violates quota. Close logical channel.
raise LogicalChannelError(
channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION)
header = create_header(opcode, len(payload), fin, rsv1, rsv2, rsv3,
mask=False)
frame_data = header + payload
channel_data.request.connection.append_frame_data(frame_data)
finally:
self._logical_channels_condition.release()
def dispatch_message(self, message):
"""Dispatches message. The reader thread calls this method.
Args:
message: a message that contains encapsulated frame.
Raises:
PhysicalConnectionError: if the message contains physical
connection level errors.
LogicalChannelError: if the message contains logical channel
level errors.
"""
parser = _MuxFramePayloadParser(message)
try:
channel_id = parser.read_channel_id()
except ValueError, e:
raise PhysicalConnectionError(_DROP_CODE_CHANNEL_ID_TRUNCATED)
if channel_id == _CONTROL_CHANNEL_ID:
self._process_control_blocks(parser)
else:
self._process_logical_frame(channel_id, parser)
def notify_worker_done(self, channel_id):
"""Called when a worker has finished.
Args:
channel_id: channel id corresponded with the worker.
"""
self._logger.debug('Worker for channel id %d terminated' % channel_id)
try:
self._logical_channels_condition.acquire()
if not channel_id in self._logical_channels:
raise MuxUnexpectedException(
'Channel id %d not found' % channel_id)
channel_data = self._logical_channels.pop(channel_id)
finally:
self._worker_done_notify_received = True
self._logical_channels_condition.notify()
self._logical_channels_condition.release()
if not channel_data.request.server_terminated:
self._send_drop_channel(
channel_id, code=channel_data.drop_code,
message=channel_data.drop_message)
def notify_reader_done(self):
"""This method is called by the reader thread when the reader has
finished.
"""
self._logger.debug(
'Termiating all logical connections waiting for incoming data '
'...')
self._logical_channels_condition.acquire()
for channel_data in self._logical_channels.values():
try:
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
except Exception:
self._logger.debug(traceback.format_exc())
self._logical_channels_condition.release()
def notify_writer_done(self):
"""This method is called by the writer thread when the writer has
finished.
"""
self._logger.debug(
'Termiating all logical connections waiting for write '
'completion ...')
self._logical_channels_condition.acquire()
for channel_data in self._logical_channels.values():
try:
channel_data.request.connection.on_writer_done()
except Exception:
self._logger.debug(traceback.format_exc())
self._logical_channels_condition.release()
def fail_physical_connection(self, code, message):
"""Fail the physical connection.
Args:
code: drop reason code.
message: drop message.
"""
self._logger.debug('Failing the physical connection...')
self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message)
self._writer.stop(common.STATUS_INTERNAL_ENDPOINT_ERROR)
def fail_logical_channel(self, channel_id, code, message):
"""Fail a logical channel.
Args:
channel_id: channel id.
code: drop reason code.
message: drop message.
"""
self._logger.debug('Failing logical channel %d...' % channel_id)
try:
self._logical_channels_condition.acquire()
if channel_id in self._logical_channels:
channel_data = self._logical_channels[channel_id]
# Close the logical channel. notify_worker_done() will be
# called later and it will send DropChannel.
channel_data.drop_code = code
channel_data.drop_message = message
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
channel_data.request.ws_stream.stop_sending()
else:
self._send_drop_channel(channel_id, code, message)
finally:
self._logical_channels_condition.release()
def use_mux(request):
    """Tells whether multiplexing should be used for the request.

    Args:
        request: request object to examine.

    Returns:
        False when the request has no mux_processor attribute, otherwise
        the result of request.mux_processor.is_active().
    """

    if not hasattr(request, 'mux_processor'):
        return False
    return request.mux_processor.is_active()
def start(request, dispatcher):
    """Runs a multiplexing handler for the request until it is done.

    Creates a _MuxHandler, starts it, adds the initial channel slots and
    then waits, without a timeout, until the handler is done.

    Args:
        request: request of the physical connection.
        dispatcher: Dispatcher instance handed to the handler.
    """

    handler = _MuxHandler(request, dispatcher)
    handler.start()

    handler.add_channel_slots(_INITIAL_NUMBER_OF_CHANNEL_SLOTS,
                              _INITIAL_QUOTA_FOR_CLIENT)

    handler.wait_until_done()
# vi:sts=4 sw=4 et