diff --git a/aiohttp/_cparser.pxd b/aiohttp/_cparser.pxd index 1b3be6d4efb..cc7ef58d664 100644 --- a/aiohttp/_cparser.pxd +++ b/aiohttp/_cparser.pxd @@ -145,6 +145,7 @@ cdef extern from "llhttp.h": int llhttp_should_keep_alive(const llhttp_t* parser) + void llhttp_resume(llhttp_t* parser) void llhttp_resume_after_upgrade(llhttp_t* parser) llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 9a444be66fc..8300b35c247 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -46,7 +46,8 @@ include "_headers.pxi" from aiohttp cimport _find_header -ALLOWED_UPGRADES = frozenset({"websocket"}) + +cdef frozenset ALLOWED_UPGRADES = frozenset({"websocket"}) DEF DEFAULT_FREELIST_SIZE = 250 cdef extern from "Python.h": @@ -69,7 +70,7 @@ cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD cdef object StreamReader = _StreamReader cdef object DeflateBuffer = _DeflateBuffer -cdef bytes EMPTY_BYTES = b"" +cdef tuple EMPTY_FEED_DATA_RESULT = ((), False, b"") # RFC 9110 singleton headers — duplicates are rejected in strict mode. # In lax mode (response parser default), the check is skipped entirely @@ -298,7 +299,7 @@ cdef class HttpParser: bint _has_value int _header_name_size - object _protocol + readonly object protocol object _loop object _timer @@ -309,6 +310,7 @@ cdef class HttpParser: bint _read_until_eof bint _lax + bytes _tail bint _started object _url bytearray _buf @@ -319,6 +321,8 @@ cdef class HttpParser: list _raw_headers bint _upgraded list _messages + bint _more_data_available + bint _paused object _payload bint _payload_error object _payload_exception @@ -359,18 +363,21 @@ cdef class HttpParser: self._cparser.data = self self._cparser.content_length = 0 - self._protocol = protocol + self.protocol = protocol self._loop = loop self._timer = timer self._buf = bytearray() + self._more_data_available = False + self._paused = False self._payload = None self._payload_error = 0 self._payload_exception = payload_exception self._messages = [] - self._raw_name = EMPTY_BYTES - self._raw_value = EMPTY_BYTES + self._raw_name = b"" + self._raw_value = b"" + self._tail = b"" self._has_value = False self._header_name_size = 0 @@ -401,7 +408,7 @@ cdef class HttpParser: cdef _process_header(self): cdef str value - if self._raw_name is not EMPTY_BYTES: + if self._raw_name is not b"": name = find_header(self._raw_name) value = self._raw_value.decode('utf-8', 'surrogateescape') @@ -426,20 +433,20 @@ cdef class HttpParser: self._has_value = False self._header_name_size = 0 self._raw_headers.append((self._raw_name, self._raw_value)) - self._raw_name = EMPTY_BYTES - self._raw_value = EMPTY_BYTES + self._raw_name = b"" + self._raw_value = b"" cdef _on_header_field(self, char* at, size_t length): if self._has_value: self._process_header() - if self._raw_name is EMPTY_BYTES: + if self._raw_name is b"": self._raw_name = at[:length] else: self._raw_name += at[:length] cdef _on_header_value(self, char* at, size_t length): - if self._raw_value is EMPTY_BYTES: + if self._raw_value is b"": self._raw_value = at[:length] else: self._raw_value += at[:length] @@ -495,7 +502,7 @@ cdef class HttpParser: self._read_until_eof) ): payload = StreamReader( - self._protocol, timer=self._timer, loop=self._loop, + self.protocol, timer=self._timer, loop=self._loop, limit=self._limit) else: payload = EMPTY_PAYLOAD @@ -535,6 +542,10 @@ cdef class HttpParser: ### Public API ### + def pause_reading(self): + assert self._payload is not None + self._paused = True + def feed_eof(self): cdef bytes desc @@ -562,6 +573,21 @@ cdef class HttpParser: char* base cdef cparser.llhttp_errno_t errno + # Proactor loop sends bytearray. + if type(data) is not bytes: + data = bytes(data) + + if self._tail: + data, self._tail = self._tail + data, b"" + + had_more_data = self._more_data_available + if self._more_data_available: + result = cb_on_body(self._cparser, b"", 0) + if result is cparser.HPE_PAUSED: + self._tail = data + return EMPTY_FEED_DATA_RESULT + # TODO: Do we need to handle error case (-1)? + PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE) # Cache buffer pointer before PyBuffer_Release to avoid use-after-release. base = self.py_buf.buf @@ -574,12 +600,15 @@ cdef class HttpParser: if errno is cparser.HPE_PAUSED_UPGRADE: cparser.llhttp_resume_after_upgrade(self._cparser) - nb = cparser.llhttp_get_error_pos(self._cparser) - base + elif errno is cparser.HPE_PAUSED: + cparser.llhttp_resume(self._cparser) + pos = cparser.llhttp_get_error_pos(self._cparser) - base + self._tail = data[pos:] PyBuffer_Release(&self.py_buf) - if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE): + if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED, cparser.HPE_PAUSED_UPGRADE): if self._payload_error == 0: if self._last_error is not None: ex = self._last_error @@ -603,8 +632,9 @@ cdef class HttpParser: if self._upgraded: return messages, True, data[nb:] - else: - return messages, False, b"" + if not messages: # Shortcut to reduce Python overhead + return EMPTY_FEED_DATA_RESULT + return messages, False, b"" def set_upgraded(self, val): self._upgraded = val @@ -799,19 +829,26 @@ cdef int cb_on_body(cparser.llhttp_t* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data cdef bytes body = at[:length] - try: - pyparser._payload.feed_data(body) - except BaseException as underlying_exc: - reraised_exc = underlying_exc - if pyparser._payload_exception is not None: - reraised_exc = pyparser._payload_exception(str(underlying_exc)) - - set_exception(pyparser._payload, reraised_exc, underlying_exc) - - pyparser._payload_error = 1 - return -1 - else: - return 0 + while body or pyparser._more_data_available: + try: + pyparser._more_data_available = pyparser._payload.feed_data(body) + except BaseException as underlying_exc: + reraised_exc = underlying_exc + if pyparser._payload_exception is not None: + reraised_exc = pyparser._payload_exception(str(underlying_exc)) + + set_exception(pyparser._payload, reraised_exc, underlying_exc) + + pyparser._payload_error = 1 + pyparser._paused = False + return -1 + body = b"" + + if pyparser._paused: + pyparser._paused = False + return cparser.HPE_PAUSED + pyparser._paused = False + return 0 cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 7f01830f4e9..179391cd3a0 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,26 +1,35 @@ import asyncio -from typing import cast +from typing import TYPE_CHECKING, Any, cast from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay +if TYPE_CHECKING: + from .http_parser import HttpParser + class BaseProtocol(asyncio.Protocol): __slots__ = ( "_loop", "_paused", + "_parser", "_drain_waiter", "_connection_lost", "_reading_paused", + "_upgraded", "transport", ) - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, loop: asyncio.AbstractEventLoop, parser: "HttpParser[Any] | None" = None + ) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: asyncio.Future[None] | None = None self._reading_paused = False + self._parser = parser + self._upgraded = False self.transport: asyncio.Transport | None = None @@ -48,15 +57,27 @@ def resume_writing(self) -> None: waiter.set_result(None) def pause_reading(self) -> None: - if not self._reading_paused and self.transport is not None: + self._reading_paused = True + # Parser shouldn't be paused on websockets. + if not self._upgraded: + assert self._parser is not None + self._parser.pause_reading() + if self.transport is not None: try: self.transport.pause_reading() except (AttributeError, NotImplementedError, RuntimeError): pass - self._reading_paused = True - def resume_reading(self) -> None: - if self._reading_paused and self.transport is not None: + def resume_reading(self, resume_parser: bool = True) -> None: + self._reading_paused = False + + # This will resume parsing any unprocessed data from the last pause. + if not self._upgraded and resume_parser: + self.data_received(b"") + + # Reading may have been paused again in the above call if there was a lot of + # compressed data still pending. + if not self._reading_paused and self.transport is not None: try: self.transport.resume_reading() except (AttributeError, NotImplementedError, RuntimeError): diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index bb088b6a99c..c6e78e57b42 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -32,7 +32,7 @@ class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamRe """Helper class to adapt between Protocol and StreamReader.""" def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - BaseProtocol.__init__(self, loop=loop) + BaseProtocol.__init__(self, loop=loop, parser=None) DataQueue.__init__(self, loop) self._should_close = False @@ -43,10 +43,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._data_received_cb: Callable[[], None] | None = None self._timer = None - self._tail = b"" - self._upgraded = False - self._parser: HttpResponseParser | None = None self._read_timeout: float | None = None self._read_timeout_handle: asyncio.TimerHandle | None = None @@ -197,8 +194,8 @@ def pause_reading(self) -> None: super().pause_reading() self._drop_timeout() - def resume_reading(self) -> None: - super().resume_reading() + def resume_reading(self, resume_parser: bool = True) -> None: + super().resume_reading(resume_parser) self._reschedule_timeout() def set_exception( @@ -299,10 +296,10 @@ def _on_read_timeout(self) -> None: set_exception(self._payload, exc) def data_received(self, data: bytes) -> None: - self._reschedule_timeout() - - if not data: - return + # If no data, then we are resuming decompression. We haven't received + # data from the socket, so we can avoid the reschedule overhead. + if data: + self._reschedule_timeout() # custom payload parser - currently always WebSocketReader if self._payload_parser is not None: diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 2a8818c4220..12d4e9d3a8a 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -34,7 +34,9 @@ MAX_SYNC_CHUNK_SIZE = 4096 -DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB +# Matches the max size we receive from sockets: +# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766 +DEFAULT_MAX_DECOMPRESS_SIZE = 256 * 1024 # Unlimited decompression constants - different libraries use different conventions ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited @@ -53,6 +55,9 @@ def flush(self, length: int = ..., /) -> bytes: ... @property def eof(self) -> bool: ... + @property + def unconsumed_tail(self) -> bytes: ... + class ZLibBackendProtocol(Protocol): MAX_WBITS: int @@ -179,6 +184,11 @@ async def decompress( ) return self.decompress_sync(data, max_length) + @property + @abstractmethod + def data_available(self) -> bool: + """Return True if more output is available by passing b"".""" + class ZLibCompressor: def __init__( @@ -271,7 +281,9 @@ def __init__( def decompress_sync( self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED ) -> bytes: - return self._decompressor.decompress(data, max_length) + return self._decompressor.decompress( + self._decompressor.unconsumed_tail + data, max_length + ) def flush(self, length: int = 0) -> bytes: return ( @@ -280,6 +292,10 @@ def flush(self, length: int = 0) -> bytes: else self._decompressor.flush() ) + @property + def data_available(self) -> bool: + return bool(self._decompressor.unconsumed_tail) + @property def eof(self) -> bool: return self._decompressor.eof @@ -301,6 +317,7 @@ def __init__( "Please install `Brotli` module" ) self._obj = brotli.Decompressor() + self._last_empty = False super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) def decompress_sync( @@ -308,8 +325,12 @@ def decompress_sync( ) -> bytes: """Decompress the given data.""" if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data, max_length)) - return cast(bytes, self._obj.process(data, max_length)) + result = cast(bytes, self._obj.decompress(data, max_length)) + else: + result = cast(bytes, self._obj.process(data, max_length)) + # Only way to know that brotli has no further data is checking we get no output + self._last_empty = result == b"" + return result def flush(self) -> bytes: """Flush the decompressor.""" @@ -317,6 +338,10 @@ def flush(self) -> bytes: return cast(bytes, self._obj.flush()) return b"" + @property + def data_available(self) -> bool: + return not self._obj.is_finished() and not self._last_empty + class ZSTDDecompressor(DecompressionBaseHandler): def __init__( @@ -373,3 +398,9 @@ def decompress_sync( def flush(self) -> bytes: return b"" + + @property + def data_available(self) -> bool: + return ( + not self._obj.needs_input and not self._obj.eof + ) or self._pending_unused_data is not None diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index cf3c05434c5..95d0d6373ae 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -73,10 +73,6 @@ class ContentLengthError(PayloadEncodingError): """Not enough data to satisfy content length header.""" -class DecompressSizeError(PayloadEncodingError): - """Decompressed size exceeds the configured limit.""" - - class LineTooLong(BadHttpMessage): def __init__(self, line: bytes, limit: int) -> None: super().__init__(f"Got more than {limit} bytes when reading: {line!r}.") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 207cf8da39e..bf87690fef3 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -35,7 +35,6 @@ BadStatusLine, ContentEncodingError, ContentLengthError, - DecompressSizeError, InvalidHeader, InvalidURLError, LineTooLong, @@ -124,6 +123,12 @@ class RawResponseMessage(NamedTuple): _MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) +class PayloadState(IntEnum): + PAYLOAD_COMPLETE = 0 + PAYLOAD_NEEDS_INPUT = 1 + PAYLOAD_HAS_PENDING_INPUT = 2 + + class ParseState(IntEnum): PARSE_NONE = 0 PARSE_LENGTH = 1 @@ -265,6 +270,7 @@ def __init__( self._upgraded = False self._payload = None self._payload_parser: HttpPayloadParser | None = None + self._payload_has_more_data = False self._auto_decompress = auto_decompress self._limit = limit self._headers_parser = HeadersParser(max_field_size, self.lax) @@ -275,6 +281,10 @@ def parse_message(self, lines: list[bytes]) -> _MsgT: ... @abc.abstractmethod def _is_chunked_te(self, te: str) -> bool: ... + def pause_reading(self) -> None: + assert self._payload_parser is not None + self._payload_parser.pause_reading() + def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() @@ -311,7 +321,7 @@ def feed_data( max_line_length = self.max_line_size should_close = False - while start_pos < data_len: + while start_pos < data_len or self._payload_has_more_data: # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: @@ -470,11 +480,13 @@ def get_content_length() -> int | None: break # feed payload - elif data and start_pos < data_len: + elif self._payload_has_more_data or (data and start_pos < data_len): assert not self._lines assert self._payload_parser is not None try: - eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) + payload_state, data = self._payload_parser.feed_data( + data[start_pos:], SEP + ) except Exception as underlying_exc: reraised_exc: BaseException = underlying_exc if self.payload_exception is not None: @@ -486,18 +498,25 @@ def get_content_length() -> int | None: underlying_exc, ) - eof = True + payload_state = PayloadState.PAYLOAD_COMPLETE data = b"" if isinstance( underlying_exc, (InvalidHeader, TransferEncodingError) ): raise - if eof: - start_pos = 0 - data_len = len(data) - self._payload_parser = None - continue + self._payload_has_more_data = ( + payload_state == PayloadState.PAYLOAD_HAS_PENDING_INPUT + ) + + if payload_state is not PayloadState.PAYLOAD_COMPLETE: + # We've either consumed all available data, or we're pausing + # until the reader buffer is freed up. + break + + start_pos = 0 + data_len = len(data) + self._payload_parser = None else: break @@ -777,6 +796,7 @@ def __init__( max_trailers: int = 128, ) -> None: self._length = 0 + self._paused = False self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 @@ -787,6 +807,7 @@ def __init__( self._max_line_size = max_line_size self._max_field_size = max_field_size self._max_trailers = max_trailers + self._more_data_available = False self._trailer_lines: list[bytes] = [] self.done = False @@ -815,6 +836,9 @@ def __init__( self.payload = real_payload + def pause_reading(self) -> None: + self._paused = True + def feed_eof(self) -> None: if self._type == ParseState.PARSE_UNTIL_EOF: self.payload.feed_eof() @@ -829,32 +853,52 @@ def feed_eof(self) -> None: def feed_data( self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" - ) -> tuple[bool, bytes]: + ) -> tuple[PayloadState, bytes]: + """Receive a chunk of data to process. + + Return: + PayloadState - The current state of payload processing. + This function may be called with empty bytes after returning + PAYLOAD_HAS_PENDING_INPUT to continue processing after a pause. + bytes - If payload is complete, this is the unconsumed bytes intended for the + next message/payload, b"" otherwise. + """ # Read specified amount of bytes if self._type == ParseState.PARSE_LENGTH: + if self._chunk_tail: + chunk = self._chunk_tail + chunk + self._chunk_tail = b"" + required = self._length self._length = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + while self._more_data_available: + if self._paused: + self._paused = False + self._chunk_tail = chunk[required:] + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + self._more_data_available = self.payload.feed_data(b"") + if self._length == 0: self.payload.feed_eof() - return True, chunk[required:] - + return PayloadState.PAYLOAD_COMPLETE, chunk[required:] # Chunked transfer encoding parser elif self._type == ParseState.PARSE_CHUNKED: if self._chunk_tail: - # We should never have a tail if we're inside the payload body. - assert self._chunk != ChunkState.PARSE_CHUNKED_CHUNK - # We should check the length is sane. - max_line_length = self._max_line_size - if self._chunk == ChunkState.PARSE_TRAILERS: - max_line_length = self._max_field_size - if len(self._chunk_tail) > max_line_length: - raise LineTooLong(self._chunk_tail[:100] + b"...", max_line_length) + # We should check the length is sane when not processing payload body. + if self._chunk != ChunkState.PARSE_CHUNKED_CHUNK: + max_line_length = self._max_line_size + if self._chunk == ChunkState.PARSE_TRAILERS: + max_line_length = self._max_field_size + if len(self._chunk_tail) > max_line_length: + raise LineTooLong( + self._chunk_tail[:100] + b"...", max_line_length + ) chunk = self._chunk_tail + chunk self._chunk_tail = b"" - while chunk: + while chunk or self._more_data_available: # read next chunk size if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) @@ -894,17 +938,26 @@ def feed_data( self.payload.begin_http_chunk_receiving() else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" # read chunk and feed buffer if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: + if self._paused: + self._paused = False + self._chunk_tail = chunk + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + required = self._chunk_size self._chunk_size = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + chunk = chunk[required:] + + if self._more_data_available: + continue if self._chunk_size: - return False, b"" - chunk = chunk[required:] + self._paused = False + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF self.payload.end_http_chunk_receiving() @@ -923,13 +976,13 @@ def feed_data( raise exc else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" if self._chunk == ChunkState.PARSE_TRAILERS: pos = chunk.find(SEP) if pos < 0: # No line found self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" line = chunk[:pos] chunk = chunk[pos + len(SEP) :] @@ -955,13 +1008,18 @@ def feed_data( finally: self._trailer_lines.clear() self.payload.feed_eof() - return True, chunk + return PayloadState.PAYLOAD_COMPLETE, chunk # Read all bytes until eof elif self._type == ParseState.PARSE_UNTIL_EOF: - self.payload.feed_data(chunk) + self._more_data_available = self.payload.feed_data(chunk) + while self._more_data_available: + if self._paused: + self._paused = False + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + self._more_data_available = self.payload.feed_data(b"") - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" class DeflateBuffer: @@ -1006,10 +1064,8 @@ def set_exception( ) -> None: set_exception(self.out, exc, exc_cause) - def feed_data(self, chunk: bytes) -> None: - if not chunk: - return - + def feed_data(self, chunk: bytes) -> bool: + """Return True if more data is available and this method should be called again with b"".""" self.size += len(chunk) self.out.total_compressed_bytes = self.size @@ -1028,9 +1084,8 @@ def feed_data(self, chunk: bytes) -> None: ) try: - # Decompress with limit + 1 so we can detect if output exceeds limit chunk = self.decompressor.decompress_sync( - chunk, max_length=self._max_decompress_size + 1 + chunk, max_length=self._max_decompress_size ) except Exception: raise ContentEncodingError( @@ -1039,15 +1094,9 @@ def feed_data(self, chunk: bytes) -> None: self._started_decoding = True - # Check if decompression limit was exceeded - if len(chunk) > self._max_decompress_size: - raise DecompressSizeError( - "Decompressed data exceeds the configured limit of %d bytes" - % self._max_decompress_size - ) - if chunk: self.out.feed_data(chunk) + return self.decompressor.data_available def feed_eof(self) -> None: chunk = self.decompressor.flush() diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index c44219e92b4..3d6c1057c75 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,6 +1,7 @@ import base64 import binascii import json +import math import re import sys import uuid @@ -268,6 +269,8 @@ def __init__( subtype: str = "mixed", default_charset: str | None = None, max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + client_max_size: int = math.inf, # type: ignore[assignment] + max_size_error_cls: type[Exception] = ValueError, ) -> None: self.headers = headers self._boundary = boundary @@ -285,6 +288,8 @@ def __init__( self._content_eof = 0 self._cache: dict[str, Any] = {} self._max_decompress_size = max_decompress_size + self._client_max_size = client_max_size + self._max_size_error_cls = max_size_error_cls def __aiter__(self) -> Self: return self @@ -313,11 +318,19 @@ async def read(self, *, decode: bool = False) -> bytes: data = bytearray() while not self._at_eof: data.extend(await self.read_chunk(self.chunk_size)) + if len(data) > self._client_max_size: + raise self._max_size_error_cls( + max_size=self._client_max_size, actual_size=len(data) + ) # https://github.com/python/mypy/issues/17537 if decode: # type: ignore[unreachable] decoded_data = bytearray() async for d in self.decode_iter(data): decoded_data.extend(d) + if len(decoded_data) > self._client_max_size: + raise self._max_size_error_cls( + max_size=self._client_max_size, actual_size=len(decoded_data) + ) return decoded_data return data @@ -559,6 +572,8 @@ async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]: suppress_deflate_header=True, ) yield await d.decompress(data, max_length=self._max_decompress_size) + while d.data_available: + yield await d.decompress(b"", max_length=self._max_decompress_size) else: raise RuntimeError(f"unknown content encoding: {encoding}") @@ -652,8 +667,10 @@ def __init__( headers: Mapping[str, str], content: StreamReader, *, + client_max_size: int = math.inf, # type: ignore[assignment] max_field_size: int = 8190, max_headers: int = 128, + max_size_error_cls: type[Exception] = ValueError, ) -> None: self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) assert self._mimetype.type == "multipart", "multipart/* content type expected" @@ -664,11 +681,13 @@ def __init__( self.headers = headers self._boundary = ("--" + self._get_boundary()).encode() + self._client_max_size = client_max_size self._content = content self._default_charset: str | None = None self._last_part: MultipartReader | BodyPartReader | None = None self._max_field_size = max_field_size self._max_headers = max_headers + self._max_size_error_cls = max_size_error_cls self._at_eof = False self._at_bof = True self._unread: list[bytes] = [] @@ -768,12 +787,21 @@ def _get_part_reader( if mimetype.type == "multipart": if self.multipart_reader_cls is None: - return type(self)(headers, self._content) + return type(self)( + headers, + self._content, + client_max_size=self._client_max_size, + max_field_size=self._max_field_size, + max_headers=self._max_headers, + max_size_error_cls=self._max_size_error_cls, + ) return self.multipart_reader_cls( headers, self._content, + client_max_size=self._client_max_size, max_field_size=self._max_field_size, max_headers=self._max_headers, + max_size_error_cls=self._max_size_error_cls, ) else: return self.part_reader_cls( @@ -782,6 +810,8 @@ def _get_part_reader( self._content, subtype=self._mimetype.subtype, default_charset=self._default_charset, + client_max_size=self._client_max_size, + max_size_error_cls=self._max_size_error_cls, ) def _get_boundary(self) -> str: diff --git a/aiohttp/streams.py b/aiohttp/streams.py index bacb810958b..afefc9f0216 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -219,8 +219,8 @@ def feed_eof(self) -> None: self._eof_waiter = None set_result(waiter, None) - if self._protocol._reading_paused: - self._protocol.resume_reading() + # At EOF the parser is done, there won't be unprocessed data. + self._protocol.resume_reading(resume_parser=False) for cb in self._eof_callbacks: try: @@ -274,11 +274,11 @@ def unread_data(self, data: bytes) -> None: self._buffer.appendleft(data) self._eof_counter = 0 - def feed_data(self, data: bytes) -> None: + def feed_data(self, data: bytes) -> bool: assert not self._eof, "feed_data after feed_eof" if not data: - return + return False data_len = len(data) self._size += data_len @@ -290,8 +290,9 @@ def feed_data(self, data: bytes) -> None: self._waiter = None set_result(waiter, None) - if self._size > self._high_water and not self._protocol._reading_paused: + if self._size > self._high_water: self._protocol.pause_reading() + return False def begin_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: @@ -328,10 +329,7 @@ def end_http_chunk_receiving(self) -> None: # If we get too many small chunks before self._high_water is reached, then any # .read() call becomes computationally expensive, and could block the event loop # for too long, hence an additional self._high_water_chunks here. - if ( - len(self._http_chunk_splits) > self._high_water_chunks - and not self._protocol._reading_paused - ): + if len(self._http_chunk_splits) > self._high_water_chunks: self._protocol.pause_reading() # wake up readchunk when end of http chunk received @@ -531,13 +529,9 @@ def _read_nowait_chunk(self, n: int) -> bytes: while chunk_splits and chunk_splits[0] < self._cursor: chunk_splits.popleft() - if ( - self._protocol._reading_paused - and self._size < self._low_water - and ( - self._http_chunk_splits is None - or len(self._http_chunk_splits) < self._low_water_chunks - ) + if self._size < self._low_water and ( + self._http_chunk_splits is None + or len(self._http_chunk_splits) < self._low_water_chunks ): self._protocol.resume_reading() return data @@ -597,8 +591,8 @@ def at_eof(self) -> bool: async def wait_eof(self) -> None: return - def feed_data(self, data: bytes) -> None: - pass + def feed_data(self, data: bytes) -> bool: + return False async def readline(self, *, max_line_length: int | None = None) -> bytes: return b"" diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 20d76408d4f..abb5f86d81c 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -22,6 +22,7 @@ HttpVersion10, RawRequestMessage, StreamWriter, + WebSocketReader, ) from .http_exceptions import BadHttpMethod from .log import access_logger, server_logger @@ -171,10 +172,8 @@ class RequestHandler(BaseProtocol, Generic[_Request]): "_handler_waiter", "_waiter", "_task_handler", - "_upgrade", "_payload_parser", "_data_received_cb", - "_request_parser", "logger", "access_log", "access_logger", @@ -207,7 +206,17 @@ def __init__( auto_decompress: bool = True, timeout_ceil_threshold: float = 5, ): - super().__init__(loop) + parser = HttpRequestParser( + self, + loop, + read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + max_headers=max_headers, + payload_exception=RequestPayloadError, + auto_decompress=auto_decompress, + ) + super().__init__(loop, parser) # _request_count is the number of requests processed with the same connection. self._request_count = 0 @@ -239,19 +248,7 @@ def __init__( self._waiter: asyncio.Future[None] | None = None self._handler_waiter: asyncio.Future[None] | None = None self._task_handler: asyncio.Task[None] | None = None - - self._upgrade = False self._payload_parser: Any = None - self._request_parser: HttpRequestParser | None = HttpRequestParser( - self, - loop, - read_bufsize, - max_line_size=max_line_size, - max_field_size=max_field_size, - max_headers=max_headers, - payload_exception=RequestPayloadError, - auto_decompress=auto_decompress, - ) self._timeout_ceil_threshold: float = 5 try: @@ -392,7 +389,7 @@ def connection_lost(self, exc: BaseException | None) -> None: self._manager = None self._request_factory = None self._request_handler = None - self._request_parser = None + self._parser = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() @@ -412,9 +409,10 @@ def connection_lost(self, exc: BaseException | None) -> None: self._payload_parser = None def set_parser( - self, parser: Any, data_received_cb: Callable[[], None] | None = None + self, + parser: WebSocketReader, + data_received_cb: Callable[[], None] | None = None, ) -> None: - # Actual type is WebReader assert self._payload_parser is None self._payload_parser = parser @@ -432,10 +430,10 @@ def data_received(self, data: bytes) -> None: return # parse http messages messages: Sequence[_MsgType] - if self._payload_parser is None and not self._upgrade: - assert self._request_parser is not None + if self._payload_parser is None and not self._upgraded: + assert self._parser is not None try: - messages, upgraded, tail = self._request_parser.feed_data(data) + messages, upgraded, tail = self._parser.feed_data(data) except HttpProcessingError as exc: messages = [ (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD) @@ -452,12 +450,12 @@ def data_received(self, data: bytes) -> None: # don't set result twice waiter.set_result(None) - self._upgrade = upgraded + self._upgraded = upgraded if upgraded and tail: self._message_tail = tail # no parser, just store - elif self._payload_parser is None and self._upgrade and data: + elif self._payload_parser is None and self._upgraded and data: self._message_tail += data # feed payload @@ -719,11 +717,11 @@ async def finish_response( prematurely. """ request._finish() - if self._request_parser is not None: - self._request_parser.set_upgraded(False) - self._upgrade = False + if self._parser is not None: + self._parser.set_upgraded(False) + self._upgraded = False if self._message_tail: - self._request_parser.feed_data(self._message_tail) + self._parser.feed_data(self._message_tail) self._message_tail = b"" try: prepare_meth = resp.prepare diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 42be85e2e74..d0106e47b94 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -675,8 +675,10 @@ async def multipart(self) -> MultipartReader: return MultipartReader( self._headers, self._payload, + client_max_size=self._client_max_size, max_field_size=self._protocol.max_field_size, max_headers=self._protocol.max_headers, + max_size_error_cls=HTTPRequestEntityTooLarge, ) async def post(self) -> "MultiDictProxy[str | bytes | FileField]": diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index 713dba2d0c2..234e9927c02 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -5,6 +5,7 @@ import pytest from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser async def test_loop() -> None: @@ -26,33 +27,28 @@ async def test_pause_writing() -> None: async def test_pause_reading_no_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - assert not pr._reading_paused + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) pr.pause_reading() - assert not pr._reading_paused + parser.pause_reading.assert_called_once() async def test_pause_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr assert not pr._reading_paused pr.pause_reading() assert pr._reading_paused - - -async def test_resume_reading_no_transport() -> None: - loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - pr._reading_paused = True - pr.resume_reading() - assert pr._reading_paused + parser.pause_reading.assert_called_once() # type: ignore[unreachable] async def test_resume_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr pr._reading_paused = True diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 8ee45330bb5..80e95c29512 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -53,7 +53,6 @@ ) from aiohttp.client_reqrep import ClientRequest from aiohttp.compression_utils import DEFAULT_MAX_DECOMPRESS_SIZE -from aiohttp.http_exceptions import DecompressSizeError from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -2407,10 +2406,9 @@ async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> N When a compressed payload expands beyond the configured limit, we raise DecompressSizeError. """ - # Create a highly compressible payload that exceeds the decompression limit. - # 64MiB of repeated bytes compresses to ~32KB but expands beyond the - # 32MiB per-call limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. + payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = zlib.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2427,11 +2425,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(brotli is None, reason="brotli is not installed") @@ -2440,8 +2438,9 @@ async def test_payload_decompress_size_limit_brotli( ) -> None: """Test that brotli decompression size limit triggers DecompressSizeError.""" assert brotli is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload + payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = brotli.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2457,11 +2456,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(ZstdCompressor is None, reason="backports.zstd is not installed") @@ -2470,8 +2469,9 @@ async def test_payload_decompress_size_limit_zstd( ) -> None: """Test that zstd decompression size limit triggers DecompressSizeError.""" assert ZstdCompressor is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. + payload_size = 64 * 2**20 + original = b"A" * payload_size compressor = ZstdCompressor() compressed = compressor.compress(original) + compressor.flush() assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2488,11 +2488,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size async def test_bad_payload_chunked_encoding(aiohttp_client: AiohttpClient) -> None: diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 49a81c8dbb3..0a26a211453 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -10,7 +10,7 @@ from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientResponse from aiohttp.helpers import TimerNoop -from aiohttp.http_parser import RawResponseMessage +from aiohttp.http_parser import HttpParser, RawResponseMessage async def test_force_close(loop: asyncio.AbstractEventLoop) -> None: @@ -35,7 +35,9 @@ async def test_oserror(loop: asyncio.AbstractEventLoop) -> None: async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) proto = ResponseHandler(loop=loop) + proto._parser = parser transport = mock.Mock() proto.connection_made(transport) diff --git a/tests/test_flowcontrol_streams.py b/tests/test_flowcontrol_streams.py index 9e21f786610..3654ba4aad2 100644 --- a/tests/test_flowcontrol_streams.py +++ b/tests/test_flowcontrol_streams.py @@ -5,6 +5,7 @@ from aiohttp import streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser @pytest.fixture @@ -38,7 +39,6 @@ async def test_readline(self, stream: streams.StreamReader) -> None: stream.feed_data(b"d\n") res = await stream.readline() assert res == b"d\n" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readline_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -51,7 +51,6 @@ async def test_readany(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res = await stream.readany() assert res == b"data" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readany_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -65,7 +64,6 @@ async def test_readchunk(self, stream: streams.StreamReader) -> None: res, end_of_http_chunk = await stream.readchunk() assert res == b"data" assert not end_of_http_chunk - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readchunk_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -120,7 +118,8 @@ async def test_resumed_on_eof(self, stream: streams.StreamReader) -> None: async def test_stream_reader_eof_when_full() -> None: loop = asyncio.get_event_loop() - protocol = BaseProtocol(loop=loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + protocol = BaseProtocol(loop=loop, parser=parser) protocol.transport = asyncio.Transport() stream = streams.StreamReader(protocol, 1024, loop=loop) diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 2c593a7589c..f5044c0572c 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1,10 +1,11 @@ # Tests for aiohttp/protocol.py import asyncio +import platform import re import sys import zlib -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from contextlib import suppress from typing import Any from unittest import mock @@ -17,6 +18,7 @@ import aiohttp from aiohttp import http_exceptions, streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.client_proto import ResponseHandler from aiohttp.helpers import NO_EXTENSIONS from aiohttp.http_parser import ( DeflateBuffer, @@ -27,8 +29,12 @@ HttpRequestParserPy, HttpResponseParser, HttpResponseParserPy, + PayloadState, ) from aiohttp.http_writer import HttpVersion +from aiohttp.web_protocol import RequestHandler +from aiohttp.web_request import Request +from aiohttp.web_server import Server try: try: @@ -56,9 +62,23 @@ RESPONSE_PARSERS.append(HttpResponseParserC) +@pytest.fixture +def server() -> Any: + return mock.create_autospec( + Server, + request_factory=mock.Mock(), + request_handler=mock.AsyncMock(), + instance=True, + ) + + @pytest.fixture def protocol() -> Any: - return mock.create_autospec(BaseProtocol, spec_set=True, instance=True) + return mock.create_autospec( + BaseProtocol, + spec_set=True, + instance=True, + ) def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @@ -71,11 +91,13 @@ def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) def parser( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, -) -> HttpRequestParser: +) -> Iterator[HttpRequestParser]: + protocol = RequestHandler(server, loop=loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, 2**16, @@ -83,6 +105,10 @@ def parser( max_headers=128, max_field_size=8190, ) + protocol._force_close = False + protocol._parser = parser + with mock.patch.object(protocol, "transport", True): + yield parser @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) @@ -94,11 +120,12 @@ def request_cls(request: pytest.FixtureRequest) -> type[HttpRequestParser]: @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) def response( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> HttpResponseParser: + protocol = ResponseHandler(loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, 2**16, @@ -107,6 +134,8 @@ def response( max_field_size=8190, read_until_eof=True, ) + protocol._parser = parser + return parser # type: ignore[no-any-return] @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) @@ -154,9 +183,11 @@ def test_reject_obsolete_line_folding(parser: HttpRequestParser) -> None: @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_character( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -164,6 +195,7 @@ def test_invalid_character( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"POST / HTTP/1.1\r\nHost: localhost:8080\r\nSet-Cookie: abc\x01def\r\n\r\n" error_detail = re.escape(r""": @@ -176,9 +208,11 @@ def test_invalid_character( @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_linebreak( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -186,6 +220,7 @@ def test_invalid_linebreak( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"GET /world HTTP/1.1\r\nHost: 127.0.0.1\n\r\n" error_detail = re.escape(r""": @@ -250,8 +285,10 @@ def test_ctl_host_header_bad_characters(parser: HttpRequestParser) -> None: def test_unpaired_surrogate_in_header_py( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserPy( protocol, loop, @@ -259,6 +296,7 @@ def test_unpaired_surrogate_in_header_py( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"POST / HTTP/1.1\r\n\xff\r\n\r\n" message = None try: @@ -1013,6 +1051,113 @@ def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: assert msg.url == URL("/test") +async def test_chunk_splits_after_pause(parser: HttpRequestParser) -> None: + text = ( + b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + + b"1\r\nb\r\n" * 50000 + + b"0\r\n\r\n" + ) + + messages, upgrade, tail = parser.feed_data(text) + payload = messages[0][-1] + # Payload should have paused reading and stopped receiving new chunks after 16k. + assert payload._http_chunk_splits is not None + assert len(payload._http_chunk_splits) == 16385 + # We should still get the full result after read(), as it will continue processing. + result = await payload.read() + assert len(result) == 50000 # Compare len first, as it's easier to debug in diff. + assert result == b"b" * 50000 + + +async def test_compressed_with_tail(response: HttpResponseParser) -> None: + """Test compressed content-length body followed by a second response. + + With 2 responses arriving in one call and the first compressed, this should + trigger decompression pausing with the second response being saved as the tail. + Verify that the second response is resumed from the tail. + """ + # Must be large enough to exceed high water mark. + original = b"x" * 1024 * 1024 + compressed = zlib.compress(original) + resp1 = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: " + str(len(compressed)).encode() + b"\r\n" + b"Content-Encoding: deflate\r\n" + b"\r\n" + ) + compressed + resp2 = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok" + + msgs, upgrade, tail = response.feed_data(resp1 + resp2) + payload = msgs[0][-1] + result = await payload.read() + assert len(result) == len(original) + assert result == original + + payload = response.protocol._buffer[0][-1] + result = await payload.read() + assert result == b"ok" + + +async def test_compressed_chunked_with_pending(response: HttpResponseParser) -> None: + """Test chunked + compressed where the decompressor needs to resume from pause. + + We need to verify that chunked messages continue parsing correctly after + a pause and resume in the decompression. + """ + # Must be large enough to exceed high water mark. + original = b"A" * 1024 * 1024 + compressed = zlib.compress(original) + chunk_data = hex(len(compressed))[2:].encode() + b"\r\n" + compressed + b"\r\n" + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Encoding: deflate\r\n" + b"\r\n" + ) + data = headers + chunk_data + b"0\r\n\r\n" + + msgs, upgrade, tail = response.feed_data(data) + payload = msgs[0][-1] + result = await payload.read() + assert len(result) == len(original) + assert result == original + + +async def test_compressed_until_eof_with_pending(response: HttpResponseParser) -> None: + """Test read-until-eof + compressed with pause.""" + + # Must be large enough to exceed high water mark. + original = b"B" * 1024 * 1024 + compressed = zlib.compress(original) + # No Content-Length or Transfer-Encoding means the parser must parse until EOF. + headers = b"HTTP/1.1 200 OK\r\n" b"Content-Encoding: deflate\r\n" b"\r\n" + + msgs, upgrade, tail = response.feed_data(headers + compressed) + response.feed_eof() + payload = msgs[0][-1] + result = await payload.read() + assert len(result) == len(original) + assert result == original + + +async def test_compressed_256kb(response: HttpResponseParser) -> None: + original = b"x" * 256 * 1024 + compressed = zlib.compress(original) + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: " + str(len(compressed)).encode() + b"\r\n" + b"Content-Encoding: deflate\r\n" + b"\r\n" + ) + + messages, upgrade, tail = response.feed_data(headers + compressed) + assert len(messages) == 1 + payload = messages[0][-1] + result = await payload.read() + assert len(result) == len(original) + assert result == original + + @pytest.mark.parametrize("size", [40965, 8191]) def test_max_header_value_size_continuation( response: HttpResponseParser, size: int @@ -1447,8 +1592,10 @@ async def test_http_response_parser_bad_chunked_lax( @pytest.mark.dev_mode async def test_http_response_parser_bad_chunked_strict_py( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, ) -> None: + protocol = ResponseHandler(loop) + response = HttpResponseParserPy( protocol, loop, @@ -1456,6 +1603,7 @@ async def test_http_response_parser_bad_chunked_strict_py( max_line_size=8190, max_field_size=8190, ) + protocol._parser = response text = ( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n" ) @@ -1469,8 +1617,10 @@ async def test_http_response_parser_bad_chunked_strict_py( reason="C based HTTP parser not available", ) async def test_http_response_parser_bad_chunked_strict_c( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, ) -> None: + protocol = ResponseHandler(loop) + response = HttpResponseParserC( protocol, loop, @@ -1478,6 +1628,7 @@ async def test_http_response_parser_bad_chunked_strict_c( max_line_size=8190, max_field_size=8190, ) + protocol._parser = response text = ( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n" ) @@ -1628,10 +1779,12 @@ async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> def test_parse_no_length_or_te_on_post( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request_cls: type[HttpRequestParser], ) -> None: + protocol = RequestHandler(server, loop=loop) parser = request_cls(protocol, loop, limit=2**16) + protocol._parser = parser text = b"POST /test HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] @@ -1640,10 +1793,11 @@ def test_parse_no_length_or_te_on_post( def test_parse_payload_response_without_body( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, response_cls: type[HttpResponseParser], ) -> None: + protocol = ResponseHandler(loop) parser = response_cls(protocol, loop, 2**16, response_with_body=False) + protocol._parser = parser text = b"HTTP/1.1 200 Ok\r\ncontent-length: 10\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] @@ -1904,8 +2058,10 @@ def test_parse_uri_utf8_percent_encoded(parser: HttpRequestParser) -> None: reason="C based HTTP parser not available", ) def test_parse_bad_method_for_c_parser_raises( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + payload = b"GET1 /test HTTP/1.1\r\n\r\n" parser = HttpRequestParserC( protocol, @@ -1915,6 +2071,7 @@ def test_parse_bad_method_for_c_parser_raises( max_headers=128, max_field_size=8190, ) + protocol._parser = parser with pytest.raises(aiohttp.http_exceptions.BadStatusLine): messages, upgrade, tail = parser.feed_data(payload) @@ -2049,8 +2206,8 @@ async def test_parse_chunked_payload_split_end_trailers4( async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser()) - eof, tail = p.feed_data(b"1245") - assert eof + state, tail = p.feed_data(b"1245") + assert state is PayloadState.PAYLOAD_COMPLETE assert b"12" == out._buffer[0] assert b"45" == tail @@ -2326,6 +2483,7 @@ async def test_empty_body(self, protocol: BaseProtocol) -> None: assert buf.at_eof() + @pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="Broken") @pytest.mark.parametrize( "chunk_size", [1024, 2**14, 2**16], # 1KB, 16KB, 64KB diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 52e97a993a3..98eddc82368 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -354,12 +354,17 @@ async def test_read_with_content_encoding_gzip(self) -> None: result = await obj.read(decode=True) assert b"Time to Relax!" == result + @pytest.mark.skipif(sys.version_info < (3, 11), reason="wbits not available") async def test_read_with_content_encoding_deflate(self) -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. + compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) - with Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--") as stream: + with Stream(compressed + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) - assert b"Time to Relax!" == result + assert len(result) == len(content) # Simplifies diff on failure + assert result == content async def test_read_with_content_encoding_identity(self) -> None: thing = ( @@ -1721,6 +1726,35 @@ async def test_body_part_reader_payload_as_bytes() -> None: payload.decode() +@pytest.mark.skipif(sys.version_info < (3, 11), reason="No wbits parameter") +async def test_body_part_reader_payload_write() -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. + compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + output = b"" + + async def write(inp: bytes) -> None: + nonlocal output + output += inp + + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) + if sys.version_info >= (3, 12): + writer = mock.create_autospec( + AbstractStreamWriter, write=write, spec_set=True, instance=True + ) + else: + writer = mock.create_autospec( + AbstractStreamWriter, spec_set=True, instance=True + ) + writer.write.side_effect = write + with Stream(compressed + b"\r\n--:--") as stream: + body_part = aiohttp.BodyPartReader(BOUNDARY, h, stream) + payload = BodyPartReaderPayload(body_part) + await payload.write(writer) + + assert len(output) == len(content) # Simplifies diff on failure + assert output == content + + async def test_multipart_writer_close_with_exceptions() -> None: """Test that MultipartWriter.close() continues closing all parts even if one raises.""" writer = aiohttp.MultipartWriter() diff --git a/tests/test_streams.py b/tests/test_streams.py index 93e0caaac9b..52d926f6baa 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1110,7 +1110,7 @@ async def test_empty_stream_reader() -> None: assert s.set_exception(ValueError()) is None # type: ignore[func-returns-value] assert s.exception() is None assert s.feed_eof() is None # type: ignore[func-returns-value] - assert s.feed_data(b"data") is None # type: ignore[func-returns-value] + assert s.feed_data(b"data") is False assert s.at_eof() await s.wait_eof() assert await s.read() == b"" diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 730d662ced4..7784a41ae7c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -328,6 +328,27 @@ async def handler(request: web.Request) -> web.Response: resp.release() +async def test_multipart_client_max_size(aiohttp_client: AiohttpClient) -> None: + with multipart.MultipartWriter() as writer: + writer.append("A" * 1020) + + async def handler(request: web.Request) -> web.Response: + reader = await request.multipart() + assert isinstance(reader, multipart.MultipartReader) + + part = await reader.next() + assert isinstance(part, multipart.BodyPartReader) + await part.text() # Should raise HttpRequestEntityTooLarge + assert False + + app = web.Application(client_max_size=1000) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + async with client.post("/", data=writer) as resp: + assert resp.status == 413 + + async def test_multipart_empty(aiohttp_client: AiohttpClient) -> None: with multipart.MultipartWriter() as writer: pass diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 5ae1e5dd756..9acad2f2101 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -1,48 +1,51 @@ import asyncio -from typing import Any, cast from unittest import mock -from aiohttp.web_protocol import RequestHandler +import pytest +from aiohttp.http import WebSocketReader +from aiohttp.web_protocol import RequestHandler +from aiohttp.web_request import BaseRequest +from aiohttp.web_server import Server -class _DummyManager: - def __init__(self) -> None: - self.request_handler = mock.Mock() - self.request_factory = mock.Mock() +@pytest.fixture +def dummy_manager() -> Server[BaseRequest]: + return mock.create_autospec(Server[BaseRequest], request_handler=mock.Mock(), request_factory=mock.Mock(), instance=True) # type: ignore[no-any-return] -class _DummyParser: - def __init__(self) -> None: - self.received: list[bytes] = [] - def feed_data(self, data: bytes) -> tuple[bool, bytes]: - self.received.append(data) - return False, b"" +@pytest.fixture +def dummy_reader() -> tuple[WebSocketReader, mock.Mock]: + m = mock.create_autospec(WebSocketReader, spec_set=True, instance=True) + m.feed_data.return_value = False, b"" + return m, m def test_set_parser_does_not_call_data_received_cb_for_tail( loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], + dummy_reader: tuple[WebSocketReader, mock.Mock], ) -> None: - handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) + handler = RequestHandler(dummy_manager, loop=loop) handler._message_tail = b"tail" cb = mock.Mock() - parser = _DummyParser() - handler.set_parser(parser, data_received_cb=cb) + handler.set_parser(dummy_reader[0], data_received_cb=cb) cb.assert_not_called() - assert parser.received == [b"tail"] + dummy_reader[1].feed_data.assert_called_once_with(b"tail") def test_data_received_calls_data_received_cb( loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], + dummy_reader: tuple[WebSocketReader, mock.Mock], ) -> None: - handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) + handler = RequestHandler(dummy_manager, loop=loop) cb = mock.Mock() - parser = _DummyParser() - handler.set_parser(parser, data_received_cb=cb) + handler.set_parser(dummy_reader[0], data_received_cb=cb) handler.data_received(b"x") - assert cb.call_count == 1 - assert parser.received == [b"x"] + cb.assert_called_once() + dummy_reader[1].feed_data.assert_called_once_with(b"x") diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 26d1a275327..27dbae6630a 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -19,7 +19,7 @@ from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend, ZLibBackendWrapper -from aiohttp.http import WebSocketError, WSCloseCode, WSMsgType +from aiohttp.http import HttpParser, WebSocketError, WSCloseCode, WSMsgType from aiohttp.http_websocket import ( WebSocketReader, WSMessageBinary, @@ -113,8 +113,9 @@ def build_close_frame( @pytest.fixture() def protocol(loop: asyncio.AbstractEventLoop) -> BaseProtocol: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) transport = mock.Mock(spec_set=asyncio.Transport) - protocol = BaseProtocol(loop) + protocol = BaseProtocol(loop, parser=parser) protocol.connection_made(transport) return protocol