From 8d00f89545e055a24b2a638a241f1f0c3557a4bd Mon Sep 17 00:00:00 2001 From: eevelweezel Date: Fri, 15 May 2026 16:05:30 -0500 Subject: [PATCH] Fixed #32969 -- Improve pickling of HTTPResponse and subclasses. --- django/core/handlers/asgi.py | 20 ++++++++ django/core/handlers/base.py | 50 ++++++++++++++++++ django/core/handlers/wsgi.py | 71 +++++++++++--------------- django/core/servers/basehttp.py | 2 +- django/template/base.py | 18 +++++++ django/test/client.py | 12 ++++- django/urls/resolvers.py | 21 ++++++-- tests/requests_tests/tests.py | 3 +- tests/test_client/tests.py | 81 ++++++++++++++++++++++++++++++ tests/urlpatterns_reverse/tests.py | 18 +++++-- 10 files changed, 243 insertions(+), 53 deletions(-) diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 7ee52088c416..e6e2b9f0d214 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -5,6 +5,7 @@ import traceback from collections import defaultdict from contextlib import aclosing, closing +from io import BytesIO from asgiref.sync import ThreadSensitiveContext, sync_to_async @@ -146,6 +147,21 @@ def close(self): super().close() self._stream.close() + def __getstate__(self): + state = self.__dict__.copy() + state["stream"] = self._stream.read() + del state["_stream"] + return state + + def __setstate__(self, state): + stream = state.pop("stream") + self.__dict__.update(state) + try: + content_length = int(self.scope.get("CONTENT_LENGTH")) + except (ValueError, TypeError): + content_length = 0 + self._stream = base.LimitedStream(BytesIO(stream), content_length) + class ASGIHandler(base.BaseHandler): """Handler for ASGI requests.""" @@ -365,3 +381,7 @@ def chunk_bytes(cls, data): (position + cls.chunk_size) >= len(data), ) position += cls.chunk_size + + def __setstate__(self, state): + self.__dict__.update(state) + self.load_middleware(is_async=True) diff --git a/django/core/handlers/base.py b/django/core/handlers/base.py index d65a5edcb6a6..c0a67f31d026 100644 --- a/django/core/handlers/base.py +++ b/django/core/handlers/base.py @@ -2,6 +2,7 @@ import logging import types from inspect import iscoroutinefunction +from io import IOBase from asgiref.sync import async_to_sync, sync_to_async @@ -18,6 +19,47 @@ logger = logging.getLogger("django.request") +class LimitedStream(IOBase): + """ + Wrap another stream to disallow reading it past a number of bytes. + + Based on the implementation from werkzeug.wsgi.LimitedStream. See: + https://github.com/pallets/werkzeug/blob/dbf78f67/src/werkzeug/wsgi.py#L828 + """ + + def __init__(self, stream, limit): + self._read = stream.read + self._readline = stream.readline + self._pos = 0 + self.limit = limit + + def read(self, size=-1, /): + _pos = self._pos + limit = self.limit + if _pos >= limit: + return b"" + if size == -1 or size is None: + size = limit - _pos + else: + size = min(size, limit - _pos) + data = self._read(size) + self._pos += len(data) + return data + + def readline(self, size=-1, /): + _pos = self._pos + limit = self.limit + if _pos >= limit: + return b"" + if size == -1 or size is None: + size = limit - _pos + else: + size = min(size, limit - _pos) + line = self._readline(size) + self._pos += len(line) + return line + + class BaseHandler: _view_middleware = None _template_response_middleware = None @@ -366,6 +408,14 @@ def process_exception_by_middleware(self, exception, request): return response return None + def __getstate__(self): + state = self.__dict__.copy() + del state["_view_middleware"] + del state["_template_response_middleware"] + del state["_exception_middleware"] + del state["_middleware_chain"] + return state + def reset_urlconf(sender, **kwargs): """Reset the URLconf after each request is finished.""" diff --git a/django/core/handlers/wsgi.py b/django/core/handlers/wsgi.py index aab9fe0c4916..ec346ea72c5f 100644 --- a/django/core/handlers/wsgi.py +++ b/django/core/handlers/wsgi.py @@ -1,4 +1,4 @@ -from io import IOBase +from io import BytesIO from django.conf import settings from django.core import signals @@ -12,47 +12,6 @@ _slashes_re = _lazy_re_compile(rb"/+") -class LimitedStream(IOBase): - """ - Wrap another stream to disallow reading it past a number of bytes. - - Based on the implementation from werkzeug.wsgi.LimitedStream. See: - https://github.com/pallets/werkzeug/blob/dbf78f67/src/werkzeug/wsgi.py#L828 - """ - - def __init__(self, stream, limit): - self._read = stream.read - self._readline = stream.readline - self._pos = 0 - self.limit = limit - - def read(self, size=-1, /): - _pos = self._pos - limit = self.limit - if _pos >= limit: - return b"" - if size == -1 or size is None: - size = limit - _pos - else: - size = min(size, limit - _pos) - data = self._read(size) - self._pos += len(data) - return data - - def readline(self, size=-1, /): - _pos = self._pos - limit = self.limit - if _pos >= limit: - return b"" - if size == -1 or size is None: - size = limit - _pos - else: - size = min(size, limit - _pos) - line = self._readline(size) - self._pos += len(line) - return line - - class WSGIRequest(HttpRequest): def __init__(self, environ): script_name = get_script_name(environ) @@ -75,7 +34,7 @@ def __init__(self, environ): content_length = int(environ.get("CONTENT_LENGTH")) except (ValueError, TypeError): content_length = 0 - self._stream = LimitedStream(self.environ["wsgi.input"], content_length) + self._stream = base.LimitedStream(self.environ["wsgi.input"], content_length) self._read_started = False self.resolver_match = None @@ -109,6 +68,28 @@ def FILES(self): POST = property(_get_post, _set_post) + def __getstate__(self): + state = self.__dict__.copy() + errb = b"" + if self.environ.get("wsgi.errors"): + errb = self.environ["wsgi.errors"].getvalue() + state["environ"]["wsgi.input"] = self._stream.read() + state["environ"]["wsgi.errors"] = errb + del state["META"] + del state["_stream"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.environ["wsgi.input"] = BytesIO(self.environ["wsgi.input"]) + self.environ["wsgi.errors"] = BytesIO(self.environ["wsgi.errors"]) + self.META = self.environ + try: + content_length = int(self.environ.get("CONTENT_LENGTH")) + except (ValueError, TypeError): + content_length = 0 + self._stream = base.LimitedStream(self.environ["wsgi.input"], content_length) + class WSGIHandler(base.BaseHandler): request_class = WSGIRequest @@ -143,6 +124,10 @@ def __call__(self, environ, start_response): ) return response + def __setstate__(self, state): + self.__dict__.update(state) + self.load_middleware(is_async=False) + def get_path_info(environ): """Return the HTTP request's PATH_INFO as a string.""" diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py index d62b88d28651..4570ff62cfe9 100644 --- a/django/core/servers/basehttp.py +++ b/django/core/servers/basehttp.py @@ -15,7 +15,7 @@ from wsgiref import simple_server from django.core.exceptions import ImproperlyConfigured -from django.core.handlers.wsgi import LimitedStream +from django.core.handlers.base import LimitedStream from django.core.wsgi import get_wsgi_application from django.db import connections from django.utils.log import log_message diff --git a/django/template/base.py b/django/template/base.py index 8c6390de33fe..49c8c47f16a6 100644 --- a/django/template/base.py +++ b/django/template/base.py @@ -290,6 +290,24 @@ def get_exception_info(self, exception, token): "end": end, } + def __getstate__(self): + state = self.__dict__.copy() + state["engine"] = str(self.engine) + state["origin"] = self.origin.name + state["loader"] = str(self.origin.loader) + del state["nodelist"] + return state + + def __setstate(self, state): + from .engine import _engine_list + + self.__dict__.update(state) + self.engine = _engine_list(using=state["engine"]) + self.origin = Origin( + state["origin"], template_name=state["name"], loader=state["loader"] + ) + self.nodelist = self.compile_nodelist() + class PartialTemplate: """ diff --git a/django/test/client.py b/django/test/client.py index 0f986d5a6c81..5c5f1fdef84a 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -14,8 +14,8 @@ from django.conf import settings from django.core.handlers.asgi import ASGIRequest -from django.core.handlers.base import BaseHandler -from django.core.handlers.wsgi import LimitedStream, WSGIRequest +from django.core.handlers.base import BaseHandler, LimitedStream +from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder from django.core.signals import got_request_exception, request_finished, request_started from django.db import close_old_connections @@ -209,6 +209,10 @@ def __call__(self, environ): return response + def __setstate__(self, state): + self.__dict__.update(state) + self.load_middleware() + class AsyncClientHandler(BaseHandler): """An async version of ClientHandler.""" @@ -261,6 +265,10 @@ async def __call__(self, scope): request_finished.connect(close_old_connections) return response + def __setstate__(self, state): + self.__dict__.update(state) + self.load_middleware(is_async=True) + def store_rendered_templates(store, signal, sender, template, context, **kwargs): """ diff --git a/django/urls/resolvers.py b/django/urls/resolvers.py index 6c681f9d8d32..7de2cf98c2ed 100644 --- a/django/urls/resolvers.py +++ b/django/urls/resolvers.py @@ -11,7 +11,6 @@ import re import string from importlib import import_module -from pickle import PicklingError from urllib.parse import quote from asgiref.local import Local @@ -101,8 +100,14 @@ def __repr__(self): ) ) - def __reduce_ex__(self, protocol): - raise PicklingError(f"Cannot pickle {self.__class__.__qualname__}.") + def __getstate__(self): + state = self.__dict__.copy() + del state["func"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.func = get_callable(state["_func_path"]) def get_resolver(urlconf=None): @@ -499,6 +504,16 @@ def lookup_str(self): return callback.__module__ + "." + callback.__class__.__name__ return callback.__module__ + "." + callback.__qualname__ + def __getstate__(self): + state = self.__dict__.copy() + state["lookup_str"] = self.lookup_str + del state["callback"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.callback = get_callable(state["lookup_str"]) + class URLResolver: def __init__( diff --git a/tests/requests_tests/tests.py b/tests/requests_tests/tests.py index cf24ddd326d3..414f2ac95caa 100644 --- a/tests/requests_tests/tests.py +++ b/tests/requests_tests/tests.py @@ -7,7 +7,8 @@ from django.core.exceptions import BadRequest, DisallowedHost from django.core.files.uploadedfile import InMemoryUploadedFile from django.core.files.uploadhandler import MemoryFileUploadHandler -from django.core.handlers.wsgi import LimitedStream, WSGIRequest +from django.core.handlers.base import LimitedStream +from django.core.handlers.wsgi import WSGIRequest from django.http import ( HttpHeaders, HttpRequest, diff --git a/tests/test_client/tests.py b/tests/test_client/tests.py index 8e62c717eba0..88eb276eac3e 100644 --- a/tests/test_client/tests.py +++ b/tests/test_client/tests.py @@ -22,6 +22,7 @@ import copy import itertools +import pickle import tempfile from unittest import mock @@ -101,6 +102,86 @@ async def test_copy_response_async(self): self.assertIs(response_copy.resolver_match, response.resolver_match) self.assertIs(response_copy.asgi_request, response.asgi_request) + def test_pickle_response(self): + response = self.client.get("/cbv_view/") + pickled = pickle.dumps(response) + new_response = pickle.loads(pickled) + self.assertIsNot(new_response, response) + self.assertIsNot(new_response.resolver_match, response.resolver_match) + self.assertIsNot(new_response.wsgi_request, response.wsgi_request) + for key in [ + "headers", + "cookies", + "closed", + "_reason_phrase", + "_is_rendered", + "exc_info", + ]: + with self.subTest(key=key): + self.assertEqual(new_response.__dict__[key], response.__dict__[key]) + for key in [ + "path", + "method", + "content_type", + "content_params", + "_read_started", + "COOKIES", + ]: + with self.subTest(key=key): + self.assertEqual( + new_response.wsgi_request.__dict__[key], + response.wsgi_request.__dict__[key], + ) + for key in ["defaults", "cookies", "exc_info", "extra", "headers"]: + with self.subTest(key=key): + self.assertEqual( + new_response.client.__dict__[key], response.client.__dict__[key] + ) + self.assertEqual( + repr(response.resolver_match.__dict__["_wrapped"]), + repr(new_response.resolver_match), + ) + + async def test_pickle_response_async(self): + response = await self.async_client.get("/cbv_view/") + pickled = pickle.dumps(response) + new_response = pickle.loads(pickled) + self.assertIsNot(new_response, response) + self.assertIsNot(new_response.resolver_match, response.resolver_match) + self.assertIsNot(new_response.asgi_request, response.asgi_request) + for key in [ + "headers", + "cookies", + "closed", + "_reason_phrase", + "_is_rendered", + "exc_info", + ]: + with self.subTest(key=key): + self.assertEqual(new_response.__dict__[key], response.__dict__[key]) + for key in [ + "path", + "method", + "content_type", + "content_params", + "_read_started", + "COOKIES", + ]: + with self.subTest(key=key): + self.assertEqual( + new_response.asgi_request.__dict__[key], + response.asgi_request.__dict__[key], + ) + for key in ["defaults", "cookies", "exc_info", "extra", "headers"]: + with self.subTest(key=key): + self.assertEqual( + new_response.client.__dict__[key], response.client.__dict__[key] + ) + self.assertEqual( + repr(response.resolver_match.__dict__["_wrapped"]), + repr(new_response.resolver_match), + ) + def test_query_string_encoding(self): # WSGI requires latin-1 encoded strings. response = self.client.get("/get_view/?var=1\ufffd") diff --git a/tests/urlpatterns_reverse/tests.py b/tests/urlpatterns_reverse/tests.py index 58cd2601db5e..ed8ebac77ada 100644 --- a/tests/urlpatterns_reverse/tests.py +++ b/tests/urlpatterns_reverse/tests.py @@ -2,6 +2,7 @@ Unit tests for reverse URL lookups. """ +import copy import pickle import sys import threading @@ -1689,9 +1690,20 @@ def test_repr_functools_partial(self): @override_settings(ROOT_URLCONF="urlpatterns.path_urls") def test_pickling(self): - msg = "Cannot pickle ResolverMatch." - with self.assertRaisesMessage(pickle.PicklingError, msg): - pickle.dumps(resolve("/users/")) + resolver_match = resolve("/users/") + pickled = pickle.dumps(resolver_match) + new_match = pickle.loads(pickled) + self.assertEqual(resolver_match.func, new_match.func) + self.assertEqual(resolver_match.route, new_match.route) + self.assertEqual(len(resolver_match.tried), len(new_match.tried)) + + @override_settings(ROOT_URLCONF="urlpatterns.path_urls") + def test_copy(self): + resolver_match = resolve("/users/") + new_match = copy.copy(resolver_match) + self.assertEqual(resolver_match.__dict__, new_match.__dict__) + self.assertEqual(resolver_match.route, new_match.route) + self.assertEqual(len(resolver_match.tried), len(new_match.tried)) @override_settings(ROOT_URLCONF="urlpatterns_reverse.erroneous_urls")