Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions Lib/test/test_socket_reentrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Tests for re-entrant mutation vulnerabilities in socket module.

These tests verify that mutating sequences during argument parsing
via __buffer__ protocol does not cause crashes.
"""

import socket
import unittest


class ReentrantMutationTests(unittest.TestCase):
"""Tests for re-entrant mutation vulnerabilities in sendmsg/recvmsg_into."""

@unittest.skipUnless(hasattr(socket.socket, "sendmsg"),
"sendmsg not supported")
def test_sendmsg_reentrant_data_mutation(self):
# Test that sendmsg() handles re-entrant mutation of data buffers
# via __buffer__ protocol.
seq = []

class MutBuffer:
def __init__(self, data):
self._data = bytes(data)
self.tripped = False

def __buffer__(self, flags):
if not self.tripped:
self.tripped = True
seq.clear()
return memoryview(self._data)

seq[:] = [
MutBuffer(b'Hello'),
b'World',
b'Test',
]

left, right = socket.socketpair()
try:
# Should not crash
try:
left.sendmsg(seq)
except (TypeError, OSError):
pass # Expected - the important thing is no crash
finally:
left.close()
right.close()

@unittest.skipUnless(hasattr(socket.socket, "recvmsg_into"),
"recvmsg_into not supported")
def test_recvmsg_into_reentrant_buffer_mutation(self):
# Test that recvmsg_into() handles re-entrant mutation of buffers
# via __buffer__ protocol.
seq = []

class MutBuffer:
def __init__(self, data):
self._data = bytearray(data)
self.tripped = False

def __buffer__(self, flags):
if not self.tripped:
self.tripped = True
seq.clear()
return memoryview(self._data)

seq[:] = [
MutBuffer(b'x' * 100),
bytearray(100),
bytearray(100),
]

left, right = socket.socketpair()
try:
left.send(b'Hello World!')
# Should not crash
try:
right.recvmsg_into(seq)
except (TypeError, OSError):
pass # Expected - the important thing is no crash
finally:
left.close()
right.close()


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed crashes in :meth:`socket.socket.sendmsg` and :meth:`socket.socket.recvmsg_into`
that could occur if buffer sequences are mutated re-entrantly during argument parsing
via ``__buffer__`` protocol callbacks.
22 changes: 12 additions & 10 deletions Modules/socketmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4527,11 +4527,13 @@ sock_recvmsg_into(PyObject *self, PyObject *args)
&buffers_arg, &ancbufsize, &flags))
return NULL;

if ((fast = PySequence_Fast(buffers_arg,
"recvmsg_into() argument 1 must be an "
"iterable")) == NULL)
fast = PySequence_Tuple(buffers_arg);
if (fast == NULL) {
PyErr_SetString(PyExc_TypeError,
"recvmsg_into() argument 1 must be an iterable");
return NULL;
nitems = PySequence_Fast_GET_SIZE(fast);
}
nitems = PyTuple_GET_SIZE(fast);
if (nitems > INT_MAX) {
PyErr_SetString(PyExc_OSError, "recvmsg_into() argument 1 is too long");
goto finally;
Expand All @@ -4545,7 +4547,7 @@ sock_recvmsg_into(PyObject *self, PyObject *args)
goto finally;
}
for (; nbufs < nitems; nbufs++) {
if (!PyArg_Parse(PySequence_Fast_GET_ITEM(fast, nbufs),
if (!PyArg_Parse(PyTuple_GET_ITEM(fast, nbufs),
"w*;recvmsg_into() argument 1 must be an iterable "
"of single-segment read-write buffers",
&bufs[nbufs]))
Expand Down Expand Up @@ -4854,14 +4856,14 @@ sock_sendmsg_iovec(PySocketSockObject *s, PyObject *data_arg,

/* Fill in an iovec for each message part, and save the Py_buffer
structs to release afterwards. */
data_fast = PySequence_Fast(data_arg,
"sendmsg() argument 1 must be an "
"iterable");
data_fast = PySequence_Tuple(data_arg);
if (data_fast == NULL) {
PyErr_SetString(PyExc_TypeError,
"sendmsg() argument 1 must be an iterable");
goto finally;
}

ndataparts = PySequence_Fast_GET_SIZE(data_fast);
ndataparts = PyTuple_GET_SIZE(data_fast);
if (ndataparts > INT_MAX) {
PyErr_SetString(PyExc_OSError, "sendmsg() argument 1 is too long");
goto finally;
Expand All @@ -4883,7 +4885,7 @@ sock_sendmsg_iovec(PySocketSockObject *s, PyObject *data_arg,
}
}
for (; ndatabufs < ndataparts; ndatabufs++) {
if (PyObject_GetBuffer(PySequence_Fast_GET_ITEM(data_fast, ndatabufs),
if (PyObject_GetBuffer(PyTuple_GET_ITEM(data_fast, ndatabufs),
&databufs[ndatabufs], PyBUF_SIMPLE) < 0)
goto finally;
iovs[ndatabufs].iov_base = databufs[ndatabufs].buf;
Expand Down
Loading