diff --git a/src/httpx2/httpx2/_models.py b/src/httpx2/httpx2/_models.py index b2720dfb..12c4240d 100644 --- a/src/httpx2/httpx2/_models.py +++ b/src/httpx2/httpx2/_models.py @@ -29,7 +29,12 @@ StreamConsumed, request_context, ) -from ._multipart import get_multipart_boundary_from_content_type +from ._multipart import ( + MultipartStream, + append_boundary_to_content_type, + get_multipart_boundary_from_content_type, + is_multipart_form_data_content_type, +) from ._status_codes import codes from ._types import ( AsyncByteStream, @@ -393,15 +398,27 @@ def __init__( if stream is None: content_type: str | None = self.headers.get("content-type") + boundary = get_multipart_boundary_from_content_type( + content_type=content_type.encode(self.headers.encoding) if content_type else None + ) headers, stream = encode_request( content=content, data=data, files=files, json=json, - boundary=get_multipart_boundary_from_content_type( - content_type=content_type.encode(self.headers.encoding) if content_type else None - ), + boundary=boundary, ) + # If the user supplied a `multipart/form-data` content-type without an + # explicit boundary, inject the generated boundary so that the header + # matches the boundary actually used in the request body. + if ( + boundary is None + and content_type is not None + and is_multipart_form_data_content_type(content_type) + and isinstance(stream, MultipartStream) + ): + generated = stream.boundary.decode("ascii") + self.headers["content-type"] = append_boundary_to_content_type(content_type, generated) self._prepare(headers) self.stream = stream # Load the request body, except for streaming content. diff --git a/src/httpx2/httpx2/_multipart.py b/src/httpx2/httpx2/_multipart.py index 1a9beffb..a407bd18 100644 --- a/src/httpx2/httpx2/_multipart.py +++ b/src/httpx2/httpx2/_multipart.py @@ -49,10 +49,44 @@ def _guess_content_type(filename: str | None) -> str | None: return None +def is_multipart_form_data_content_type(content_type: str | bytes | None) -> bool: + if not content_type: + return False + if isinstance(content_type, str): + return content_type.split(";", 1)[0].strip().lower() == "multipart/form-data" + return content_type.split(b";", 1)[0].strip().lower() == b"multipart/form-data" + + +# RFC 2045 tspecials. A parameter value containing any of these characters (or +# whitespace) is not a `token` and must be transmitted as a quoted-string. +_TSPECIALS = set('()<>@,;:\\"/[]?= \t') + + +def _format_boundary_parameter_value(boundary: str) -> str: + """ + Render a boundary as a `Content-Type` parameter value. + + Per RFC 2046 section 5.1.1 a boundary may contain characters (e.g. ``:``) + that are RFC 2045 `tspecials`. Such values are not `token`s and must be + enclosed in quotes. RFC 2046 `bcharsnospace` never includes ``"`` or ``\\``, + so no character escaping inside the quoted-string is required. + """ + if boundary and not any(char in _TSPECIALS for char in boundary): + return boundary + return f'"{boundary}"' + + +def append_boundary_to_content_type(content_type: str, boundary: str) -> str: + content_type = content_type.rstrip() + while content_type.endswith(";"): + content_type = content_type[:-1].rstrip() + return f"{content_type}; boundary={_format_boundary_parameter_value(boundary)}" + + def get_multipart_boundary_from_content_type( content_type: bytes | None, ) -> bytes | None: - if not content_type or not content_type.startswith(b"multipart/form-data"): + if content_type is None or not is_multipart_form_data_content_type(content_type): return None # parse boundary according to # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1 diff --git a/tests/httpx2/test_multipart.py b/tests/httpx2/test_multipart.py index 5bf5b746..37c5de9f 100644 --- a/tests/httpx2/test_multipart.py +++ b/tests/httpx2/test_multipart.py @@ -7,12 +7,61 @@ import pytest import httpx2 +from httpx2._multipart import append_boundary_to_content_type, is_multipart_form_data_content_type def echo_request_content(request: httpx2.Request) -> httpx2.Response: return httpx2.Response(200, content=request.content) +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + ("multipart/form-data", True), + ("Multipart/Form-Data", True), + ("multipart/form-data; charset=utf-8", True), + (b"multipart/form-data; charset=utf-8", True), + (b"MULTIPART/FORM-DATA", True), + ("application/json", False), + (None, False), + ], +) +def test_is_multipart_form_data_content_type(content_type: str | bytes | None, expected: bool) -> None: + assert is_multipart_form_data_content_type(content_type) is expected + + +@pytest.mark.parametrize( + ("content_type", "boundary", "expected"), + [ + ("multipart/form-data", "abc123", "multipart/form-data; boundary=abc123"), + ( + "multipart/form-data; charset=utf-8", + "abc123", + "multipart/form-data; charset=utf-8; boundary=abc123", + ), + ( + "multipart/form-data; charset=utf-8; ", + "abc123", + "multipart/form-data; charset=utf-8; boundary=abc123", + ), + # RFC 2046 section 5.1.1 permits boundaries containing tspecials such as + # ":", which RFC 2045 requires to be transmitted as a quoted-string. + ( + "multipart/form-data", + "gc0pJq0M:08jU534c0p", + 'multipart/form-data; boundary="gc0pJq0M:08jU534c0p"', + ), + ( + "multipart/form-data", + "with space", + 'multipart/form-data; boundary="with space"', + ), + ], +) +def test_append_boundary_to_content_type(content_type: str, boundary: str, expected: str) -> None: + assert append_boundary_to_content_type(content_type, boundary) == expected + + @pytest.mark.parametrize(("value,output"), (("abc", b"abc"), (b"abc", b"abc"))) def test_multipart(value: str | bytes, output: bytes) -> None: client = httpx2.Client(transport=httpx2.MockTransport(echo_request_content)) @@ -79,8 +128,10 @@ def test_multipart_explicit_boundary(header: str) -> None: @pytest.mark.parametrize( "header", [ + "multipart/form-data", "multipart/form-data; charset=utf-8", "multipart/form-data; charset=utf-8; ", + "Multipart/Form-Data; charset=utf-8", ], ) def test_multipart_header_without_boundary(header: str) -> None: @@ -91,7 +142,28 @@ def test_multipart_header_without_boundary(header: str) -> None: response = client.post("http://127.0.0.1:8000/", files=files, headers=headers) assert response.status_code == 200 - assert response.request.headers["Content-Type"] == header + # The user-supplied content-type has no boundary, so httpx generates one and + # injects it into the header. The boundary must match the one used in the body. + content_type = response.request.headers["Content-Type"] + expected_base = header.rstrip() + while expected_base.endswith(";"): + expected_base = expected_base[:-1].rstrip() + boundary = content_type.removeprefix(f"{expected_base}; boundary=") + assert boundary + assert len(boundary) == 32 + assert all(c in "0123456789abcdef" for c in boundary) + assert content_type == f"{expected_base}; boundary={boundary}" + boundary_bytes = boundary.encode("ascii") + assert response.content == b"".join( + [ + b"--" + boundary_bytes + b"\r\n", + b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', + b"Content-Type: application/octet-stream\r\n", + b"\r\n", + b"\r\n", + b"--" + boundary_bytes + b"--\r\n", + ] + ) @pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))