Skip to content

Commit

Permalink
Handle bytearray like bytes.
Browse files Browse the repository at this point in the history
Ref #478.
  • Loading branch information
aaugustin committed Nov 4, 2018
1 parent 771a5f2 commit e284fb2
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 22 deletions.
16 changes: 10 additions & 6 deletions src/websockets/framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,22 @@ def check(frame):

def encode_data(data):
"""
Helper that converts :class:`str` or :class:`bytes` to :class:`bytes`.
Converts a string or bytes-like object to bytes.
:class:`str` are encoded with UTF-8.
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
``data`` in UTF-8.
If ``data`` is a bytes-like object, return a :class:`bytes` object.
Raise :exc:`TypeError` for other inputs.
"""
# Expect str or bytes, return bytes.
if isinstance(data, str):
return data.encode('utf-8')
elif isinstance(data, bytes):
return data
elif isinstance(data, collections.abc.ByteString):
return bytes(data)
else:
raise TypeError("data must be bytes or str")
raise TypeError("data must be bytes-like or str")


def parse_close(data):
Expand Down
4 changes: 2 additions & 2 deletions src/websockets/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def send(self, data):
if isinstance(data, str):
yield from self.write_frame(True, OP_TEXT, data.encode('utf-8'))

elif isinstance(data, bytes):
elif isinstance(data, collections.abc.ByteString):
yield from self.write_frame(True, OP_BINARY, data)

# Fragmented message -- regular iterator.
Expand All @@ -483,7 +483,7 @@ def send(self, data):
if isinstance(data, str):
yield from self.write_frame(False, OP_TEXT, data.encode('utf-8'))
encode_data = True
elif isinstance(data, bytes):
elif isinstance(data, collections.abc.ByteString):
yield from self.write_frame(False, OP_BINARY, data)
encode_data = False
else:
Expand Down
61 changes: 57 additions & 4 deletions src/websockets/speedups.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,76 @@

static const Py_ssize_t MASK_LEN = 4;

/* Similar to PyBytes_AsStringAndSize, but accepts more types */

static int
_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length)
{
if (PyBytes_Check(obj))
{
*buffer = PyBytes_AS_STRING(obj);
*length = PyBytes_GET_SIZE(obj);
}
else if (PyByteArray_Check(obj))
{
*buffer = PyByteArray_AS_STRING(obj);
*length = PyByteArray_GET_SIZE(obj);
}
else
{
PyErr_Format(
PyExc_TypeError,
"expected a bytes-like object, %.200s found",
Py_TYPE(obj)->tp_name);
return -1;
}

return 0;
}

/* C implementation of websockets.utils.apply_mask */

static PyObject *
apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
{

// Inputs are treated as immutable, which causes an extra memory copy.
// In order to support bytes and bytearray, accept any Python object.

static char *kwlist[] = {"data", "mask", NULL};
const char *input;
PyObject *input_obj;
PyObject *mask_obj;

// A pointer to the underlying char * will be extracted from these inputs.

char *input;
Py_ssize_t input_len;
const char *mask;
char *mask;
Py_ssize_t mask_len;

// Initialize a PyBytesObject then get a pointer to the underlying char *
// in order to avoid an extra memory copy in PyBytes_FromStringAndSize.

PyObject *result;
char *output;

// Other variables.

Py_ssize_t i = 0;

// Parse inputs.

if (!PyArg_ParseTupleAndKeywords(
args, kwds, "y#y#", kwlist, &input, &input_len, &mask, &mask_len))
args, kwds, "OO", kwlist, &input_obj, &mask_obj))
{
return NULL;
}

if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1)
{
return NULL;
}

if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1)
{
return NULL;
}
Expand All @@ -41,6 +90,8 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}

// Create output.

result = PyBytes_FromStringAndSize(NULL, input_len);
if (result == NULL)
{
Expand All @@ -50,6 +101,8 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
// Since we just created result, we don't need error checks.
output = PyBytes_AS_STRING(result);

// Perform the masking operation.

// Apparently GCC cannot figure out the following optimizations by itself.

// We need a new scope for MSVC 2010 (non C99 friendly)
Expand Down
7 changes: 6 additions & 1 deletion src/websockets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

def apply_mask(data, mask):
"""
Apply masking to websocket message.
Apply masking to the data of a WebSocket message.
``data`` and ``mask`` are bytes-like objects.
Return :class:`bytes`.
"""
if len(mask) != 4:
raise ValueError("mask must contain 4 bytes")

return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))
9 changes: 8 additions & 1 deletion tests/test_framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,14 @@ def test_encode_data_str(self):
def test_encode_data_bytes(self):
self.assertEqual(encode_data(b'tea'), b'tea')

def test_encode_data_other(self):
def test_encode_data_bytearray(self):
self.assertEqual(encode_data(bytearray(b'tea')), b'tea')

def test_encode_data_list(self):
with self.assertRaises(TypeError):
encode_data([])

def test_encode_data_none(self):
with self.assertRaises(TypeError):
encode_data(None)

Expand Down
20 changes: 20 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,10 @@ def test_send_binary(self):
self.loop.run_until_complete(self.protocol.send(b'tea'))
self.assertOneFrameSent(True, OP_BINARY, b'tea')

def test_send_binary_from_bytearray(self):
self.loop.run_until_complete(self.protocol.send(bytearray(b'tea')))
self.assertOneFrameSent(True, OP_BINARY, b'tea')

def test_send_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.send(42))
Expand All @@ -554,6 +558,14 @@ def test_send_iterable_binary(self):
(False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'')
)

def test_send_iterable_binary_from_bytearray(self):
self.loop.run_until_complete(
self.protocol.send([bytearray(b'te'), bytearray(b'a')])
)
self.assertFramesSent(
(False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'')
)

def test_send_empty_iterable(self):
self.loop.run_until_complete(self.protocol.send([]))
self.assertNoFrameSent()
Expand Down Expand Up @@ -616,6 +628,10 @@ def test_ping_binary(self):
self.loop.run_until_complete(self.protocol.ping(b'tea'))
self.assertOneFrameSent(True, OP_PING, b'tea')

def test_ping_binary_from_bytearray(self):
self.loop.run_until_complete(self.protocol.ping(bytearray(b'tea')))
self.assertOneFrameSent(True, OP_PING, b'tea')

def test_ping_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.ping(42))
Expand Down Expand Up @@ -661,6 +677,10 @@ def test_pong_binary(self):
self.loop.run_until_complete(self.protocol.pong(b'tea'))
self.assertOneFrameSent(True, OP_PONG, b'tea')

def test_pong_binary_from_bytearray(self):
self.loop.run_until_complete(self.protocol.pong(bytearray(b'tea')))
self.assertOneFrameSent(True, OP_PONG, b'tea')

def test_pong_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.pong(42))
Expand Down
19 changes: 11 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import unittest

from websockets.utils import apply_mask as py_apply_mask
Expand All @@ -9,14 +10,16 @@ def apply_mask(*args, **kwargs):
return py_apply_mask(*args, **kwargs)

def test_apply_mask(self):
for data_in, mask, data_out in [
(b'', b'1234', b''),
(b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'),
(b'abcdABCD', b'1234', b'PPPPpppp'),
(b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10),
]:
with self.subTest(data_in=data_in, mask=mask):
self.assertEqual(self.apply_mask(data_in, mask), data_out)
for data_type, mask_type in itertools.product([bytes, bytearray], repeat=2):
for data_in, mask, data_out in [
(b'', b'1234', b''),
(b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'),
(b'abcdABCD', b'1234', b'PPPPpppp'),
(b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10),
]:
data_in, mask = data_type(data_in), mask_type(mask)
with self.subTest(data_in=data_in, mask=mask):
self.assertEqual(self.apply_mask(data_in, mask), data_out)

def test_apply_mask_check_input_types(self):
for data_in, mask in [(None, None), (b'abcd', None), (None, b'abcd')]:
Expand Down

0 comments on commit e284fb2

Please sign in to comment.