blob: a0000fa1c94f64e700abfafbcb9634e1096dc219 [file] [log] [blame]
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1301 USA
"""
import array
import os
import struct
import six
from ._exceptions import *
from ._utils import validate_utf8
from threading import Lock
# numpy is an optional accelerator that is only considered on Python 3.
# Start from "unavailable" and upgrade to the real module if it imports.
numpy = None
if six.PY3:
    try:
        import numpy
    except ImportError:
        numpy = None
try:
    # When numpy is not in use, prefer wsaccel's compiled XOR routine.
    # With numpy present nothing is imported or defined here.
    if not numpy:
        from wsaccel.xormask import XorMaskerSimple

        def _mask(_m, _d):
            """XOR ``_d`` with the key ``_m`` using wsaccel's masker."""
            masker = XorMaskerSimple(_m)
            return masker.process(_d)
except ImportError:
    # wsaccel is missing: fall back to a pure Python implementation.
    def _mask(_m, _d):
        """XOR every item of ``_d`` in place with the 4-item key ``_m``.

        ``_d`` is modified in place and then serialised to a byte string.
        """
        for idx, value in enumerate(_d):
            _d[idx] = value ^ _m[idx % 4]

        return _d.tobytes() if six.PY3 else _d.tostring()
# Names exported by a star-import of this module: the three frame helper
# classes followed by the close-frame status code constants defined below.
# VALID_CLOSE_STATUS is not part of this list.
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
'STATUS_GOING_AWAY',
'STATUS_PROTOCOL_ERROR',
'STATUS_UNSUPPORTED_DATA_TYPE',
'STATUS_STATUS_NOT_AVAILABLE',
'STATUS_ABNORMAL_CLOSED',
'STATUS_INVALID_PAYLOAD',
'STATUS_POLICY_VIOLATION',
'STATUS_MESSAGE_TOO_BIG',
'STATUS_INVALID_EXTENSION',
'STATUS_UNEXPECTED_CONDITION',
'STATUS_BAD_GATEWAY',
'STATUS_TLS_HANDSHAKE_ERROR',
]
# Closing frame status codes: the 16-bit value carried in the first two
# payload bytes of a close frame.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015
# Status codes that ABNF._is_valid_close_status accepts in a close frame
# (in addition to the 3000-4999 range it checks separately).
# STATUS_STATUS_NOT_AVAILABLE, STATUS_ABNORMAL_CLOSED and
# STATUS_TLS_HANDSHAKE_ERROR are not listed here.
VALID_CLOSE_STATUS = (
STATUS_NORMAL,
STATUS_GOING_AWAY,
STATUS_PROTOCOL_ERROR,
STATUS_UNSUPPORTED_DATA_TYPE,
STATUS_INVALID_PAYLOAD,
STATUS_POLICY_VIOLATION,
STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION,
STATUS_BAD_GATEWAY,
)
class ABNF(object):
    """
    ABNF frame class.

    Holds the header flags, opcode and payload of a single WebSocket frame
    and knows how to validate it and serialise it for sending.

    see http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    # Payloads shorter than LENGTH_7 fit in the 7-bit length field, payloads
    # shorter than LENGTH_16 use the 2-byte extended length, everything else
    # (below LENGTH_63) uses the 8-byte extended length.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF.
        please check RFC for arguments.

        fin, rsv1, rsv2, rsv3: header flag bits (0 or 1).
        opcode: one of the OPCODE_XXX values.
        mask: 1 if the payload is to be masked when formatted, else 0.
        data: frame payload; None is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # Callable that receives a byte count and returns the masking key;
        # format() calls it with 4.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        validate the ABNF frame.

        skip_utf8_validation: skip utf8 validation of the close reason.

        Raises WebSocketProtocolException if a reserved bit is set, the
        opcode is unknown, a ping frame is not final, or a close frame has
        an invalid length, reason text or status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate the opcode into the message; wrapping it in a
            # tuple keeps the formatting correct even for tuple opcodes.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # An empty close frame carries no status code to check.
                return
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            if length > 2 and not skip_utf8_validation \
                    and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # The first two payload bytes are the big-endian status code.
            code = 256 * \
                six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        """Return True if code is in VALID_CLOSE_STATUS or in 3000..4999."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        create frame to send text, binary and other data.

        data: data to send. This is string value(byte array).
            if opcode is OPCODE_TEXT and this value is unicode,
            data value is encoded as utf-8, automatically.
        opcode: operation code. please see OPCODE_XXX.
        fin: fin flag. if set to 0, create continue fragmentation.

        Returns a new ABNF instance with the mask flag set.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        format this object to string(byte array) to send data to server.

        Returns the frame header followed by the payload; when the mask
        flag is set, the mask key and the masked payload follow the header.
        Raises ValueError if a flag is not 0 or 1, the opcode is unknown or
        the payload is too long.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        # First header byte: fin and rsv bits in the high nibble, opcode in
        # the low nibble.
        frame_header = chr(self.fin << 7
                           | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4
                           | self.opcode)
        # Second header byte: mask flag plus either the length itself or a
        # marker (0x7e / 0x7f) announcing a 2- or 8-byte extended length.
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length)
            frame_header = six.b(frame_header)
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask << 7 | 0x7e)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask << 7 | 0x7f)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            return frame_header + self.data
        else:
            mask_key = self.get_mask_key(4)
            return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """Return mask_key (as bytes) followed by the masked payload."""
        # Convert a text key once, with the same latin-1 encoding that
        # six.b() applies inside mask(), so the key sent on the wire is
        # exactly the key used for the XOR.
        if isinstance(mask_key, six.text_type):
            mask_key = mask_key.encode("latin-1")
        return mask_key + ABNF.mask(mask_key, self.data)

    @staticmethod
    def mask(mask_key, data):
        """
        mask or unmask data. Just do xor for each byte

        mask_key: 4 byte string(byte).
        data: data to mask/unmask. None is treated as empty data.

        Returns the masked data as a byte string; data itself is left
        untouched.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, six.text_type):
            mask_key = six.b(mask_key)

        if isinstance(data, six.text_type):
            data = six.b(data)

        if numpy:
            origlen = len(data)
            # NOTE(review): the key is packed with mask_key[0] in the lowest
            # byte and the buffer is read as native-order uint32, so this
            # presumes a little-endian host - confirm before relying on it
            # elsewhere.
            _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]

            # We need data to be a multiple of four...
            # Build a new padded object instead of using ``+=`` so that a
            # caller-supplied bytearray is never extended in place.
            padded = data + bytes(" " * (4 - (origlen % 4)), "us-ascii")
            a = numpy.frombuffer(padded, dtype="uint32")
            masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
            # Padding is always added, so always trim back to the input size.
            return masked.tobytes()[:origlen]
        else:
            _m = array.array("B", mask_key)
            _d = array.array("B", data)
            return _mask(_m, _d)
class frame_buffer(object):
    """
    Incrementally reads one WebSocket frame from a receive callable.

    Parsed pieces (header, length, mask) are kept on the instance until a
    whole frame has been assembled, and are reset by clear().
    """

    # Positions of the mask flag and the length bits in the header tuple.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        """
        recv_fn: callable taking a byte count and returning received data.
        skip_utf8_validation: passed on to ABNF.validate() for each frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Chunks handed over by recv_fn that have not been consumed yet.
        self.recv_buffer = []
        self.clear()
        self.lock = Lock()

    def clear(self):
        """Forget the header, length and mask parsed so far."""
        self.header = self.length = self.mask = None

    def has_received_header(self):
        # Note: True means the header has NOT been stored yet; recv_frame()
        # uses this as its "still needs reading" test.
        return self.header is None

    def recv_header(self):
        """Read the two fixed header bytes and store the decoded fields."""
        raw = self.recv_strict(2)
        first_byte, second_byte = raw[0], raw[1]
        if six.PY2:
            # Indexing a Python 2 byte string yields characters, not ints.
            first_byte, second_byte = ord(first_byte), ord(second_byte)

        # fin, rsv1, rsv2, rsv3 are the top four bits of the first byte.
        flag_bits = tuple((first_byte >> shift) & 1 for shift in (7, 6, 5, 4))
        opcode = first_byte & 0xf
        mask_flag = (second_byte >> 7) & 1
        length_bits = second_byte & 0x7f

        self.header = flag_bits + (opcode, mask_flag, length_bits)

    def has_mask(self):
        """Return the stored mask flag, or False when no header is stored."""
        return self.header[frame_buffer._HEADER_MASK_INDEX] if self.header else False

    def has_received_length(self):
        # Note: True means the length has NOT been stored yet.
        return self.length is None

    def recv_length(self):
        """Work out the payload length, reading extended length bytes if any."""
        length_bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] & 0x7f
        if length_bits == 0x7e:
            fmt, size = "!H", 2
        elif length_bits == 0x7f:
            fmt, size = "!Q", 8
        else:
            # Short payload: the 7-bit field is the length itself.
            self.length = length_bits
            return
        extended = self.recv_strict(size)
        self.length = struct.unpack(fmt, extended)[0]

    def has_received_mask(self):
        # Note: True means the mask has NOT been stored yet.
        return self.mask is None

    def recv_mask(self):
        """Read the 4-byte mask key, or store "" for an unmasked frame."""
        if self.has_mask():
            self.mask = self.recv_strict(4)
        else:
            self.mask = ""

    def recv_frame(self):
        """
        Read, unmask and validate one complete frame.

        Returns the validated ABNF instance. Pieces that are already stored
        on the instance are not read again.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            payload_length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask_key = self.mask

            # Payload
            payload = self.recv_strict(payload_length)
            if has_mask:
                payload = ABNF.mask(mask_key, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize):
        """Return exactly bufsize bytes, keeping any surplus buffered."""
        buffered = sum(len(chunk) for chunk in self.recv_buffer)
        shortage = bufsize - buffered
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            chunk = self.recv(min(16384, shortage))
            self.recv_buffer.append(chunk)
            shortage -= len(chunk)

        unified = six.b("").join(self.recv_buffer)

        if shortage == 0:
            # Exactly the requested amount is buffered: hand it all over.
            self.recv_buffer = []
            return unified

        # More than requested is buffered: keep the tail for the next call.
        self.recv_buffer = [unified[bufsize:]]
        return unified[:bufsize]
class continuous_frame(object):
    """
    Accumulates the payloads of fragmented frames into one message.
    """

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        """
        fire_cont_frame: when true, is_fire() reports true for every frame.
        skip_utf8_validation: when true, extract() does not validate text.
        """
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, payload] of the message being assembled, or None.
        self.cont_data = None
        # Opcode of the unfinished fragmented message, or None.
        self.recving_frames = None

    def validate(self, frame):
        """Raise WebSocketProtocolException if frame is out of sequence."""
        if self.recving_frames:
            # A fragmented message is open: a new text/binary frame may not
            # start before it is finished.
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                raise WebSocketProtocolException("Illegal frame")
        elif frame.opcode == ABNF.OPCODE_CONT:
            # A continuation frame needs an open message to continue.
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """Append frame's payload to the message being assembled."""
        if not self.cont_data:
            starts_message = frame.opcode in (ABNF.OPCODE_TEXT,
                                              ABNF.OPCODE_BINARY)
            if starts_message:
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """Return a true value when the accumulated data should be emitted."""
        return frame.fin or self.fire_cont_frame

    def extract(self, frame):
        """
        Move the accumulated payload into frame and reset the accumulator.

        Returns [opcode, frame], where opcode is the one stored with the
        accumulated data. Raises WebSocketPayloadException when the text
        payload fails utf8 validation (only checked when neither
        fire_cont_frame nor skip_utf8_validation is set).
        """
        pending, self.cont_data = self.cont_data, None
        opcode = pending[0]
        frame.data = pending[1]

        must_validate = (not self.fire_cont_frame
                         and opcode == ABNF.OPCODE_TEXT
                         and not self.skip_utf8_validation)
        if must_validate and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [opcode, frame]