diff --git a/can_ada-stubs/__init__.pyi b/can_ada-stubs/__init__.pyi index 389d4fa..e0ed606 100644 --- a/can_ada-stubs/__init__.pyi +++ b/can_ada-stubs/__init__.pyi @@ -1,4 +1,5 @@ from typing import Iterator, overload +from urllib.parse import ParseResult, ParseResultBytes __version__: str @@ -65,6 +66,10 @@ class URLSearchParamsValuesIter: def __next__(self) -> str | None: ... def can_parse(input: str, base_input: str | None = ...) -> bool: ... -def idna_decode(arg0: str) -> str: ... +def idna_decode(arg0: bytes) -> str: ... def idna_encode(arg0: str) -> bytes: ... def parse(arg0: str) -> URL: ... +@overload +def parse_compat(arg0: str) -> ParseResult: ... +@overload +def parse_compat(arg0: bytes) -> ParseResultBytes: ... diff --git a/src/binding.cpp b/src/binding.cpp index eac2a68..43ad3e1 100644 --- a/src/binding.cpp +++ b/src/binding.cpp @@ -9,6 +9,17 @@ namespace py = nanobind; +struct parse_impl_result { + std::string scheme; + std::string netloc; + std::string path; + std::string params; + std::string query; + std::string fragment; +}; + +static parse_impl_result parse_compat_impl(std::string_view input); + NB_MODULE(can_ada, m) { #ifdef VERSION_INFO m.attr("__version__") = Py_STRINGIFY(VERSION_INFO); @@ -153,4 +164,88 @@ NB_MODULE(can_ada, m) { return std::move(*url); }); + auto urllib = py::module_::import_("urllib.parse"); + + auto ParseResult = py::object(urllib.attr("ParseResult")); + auto ParseResultBytes = py::object(urllib.attr("ParseResultBytes")); + + m.def("parse_compat", [ParseResult](std::string_view input) { + auto [scheme, netloc, path, params, query, fragment] = + parse_compat_impl(input); + return ParseResult(scheme, netloc, path, params, query, fragment); + }); + + m.def("parse_compat", [ParseResultBytes](py::bytes input) { + auto [scheme, netloc, path, params, query, fragment] = + parse_compat_impl(std::string_view(input.c_str(), input.size())); + return ParseResultBytes(py::bytes(scheme.data(), scheme.size()), + py::bytes(netloc.data(), netloc.size()), + py::bytes(path.data(), path.size()), + py::bytes(params.data(), params.size()), + py::bytes(query.data(), query.size()), + py::bytes(fragment.data(), fragment.size())); + }); +} + +static parse_impl_result parse_compat_impl(std::string_view input) { + auto result = ada::parse(input); + if (!result) { + throw py::value_error("URL could not be parsed."); + } + + auto url = std::move(*result); + + auto scheme = url.get_protocol(); + if (!scheme.empty() && scheme.back() == ':') { + scheme.remove_suffix(1); + } + + std::string netloc; + { + if (url.has_non_empty_username()) { + netloc.append(url.get_username()); + if (url.has_password()) { + netloc.push_back(':'); + netloc.append(url.get_password()); + } + netloc.push_back('@'); + } + + netloc.append(url.get_host()); + + if (url.has_port()) { + netloc.push_back(':'); + netloc.append(url.get_port()); + } + } + + auto raw_path = url.get_pathname(); + auto path = raw_path; + std::string_view params{}; + + auto last_slash = raw_path.rfind('/'); + auto last_segment = (last_slash != std::string_view::npos) + ? raw_path.substr(last_slash + 1) + : raw_path; + + auto semi = last_segment.find(';'); + if (semi != std::string_view::npos) { + path = (last_slash != std::string_view::npos) + ? raw_path.substr(0, last_slash + 1 + semi) + : last_segment.substr(0, semi); + params = last_segment.substr(semi + 1); + } + + auto query = url.get_search(); + if (!query.empty() && query.front() == '?') { + query.remove_prefix(1); + } + + auto fragment = url.get_hash(); + if (!fragment.empty() && fragment.front() == '#') { + fragment.remove_prefix(1); + } + + return {std::string(scheme), std::move(netloc), std::string(path), + std::string(params), std::string(query), std::string(fragment)}; } diff --git a/tests/conftest.py b/tests/conftest.py index 0769639..d9b4add 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from pathlib import Path import pytest @@ -23,3 +24,17 @@ def pytest_collection_modifyitems(config, items): for item in items: if 'slow' in item.keywords: item.add_marker(skip_slow) + + +@pytest.fixture(scope="session") +def top100str() -> list[str]: + current_file_dir = Path(__file__).parent + with open(current_file_dir / "data" / "top100.txt", "r") as f: + return f.read().splitlines() + + +@pytest.fixture(scope="session") +def top100bytes() -> list[bytes]: + current_file_dir = Path(__file__).parent + with open(current_file_dir / "data" / "top100.txt", "rb") as f: + return f.read().splitlines() diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 639e0d3..b2f7bfc 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,6 +1,5 @@ -from functools import lru_cache -from pathlib import Path - +from typing import Callable, Any +from collections.abc import Iterable import pytest import urllib.parse @@ -9,63 +8,84 @@ yarl = pytest.importorskip("yarl") -@lru_cache -def data() -> list[str]: - current_file_dir = Path(__file__).parent - with open(current_file_dir / "data" / "top100.txt", "r") as f: - return f.readlines() - +@pytest.mark.slow +def test_urllib_parse(benchmark: Callable[[Any], Any], top100str: Iterable[str]): + def urllib_parse(): + for line in top100str: + urllib.parse.urlparse(line) -def urllib_parse(): - for line in data(): - urllib.parse.urlparse(line) + benchmark(urllib_parse) -def ada_python_parse(): - for line in data(): - try: - ada_url.URL(line) - except ValueError: - # There are a small number of URLs in the sample data that are - # not valid WHATWG URLs. - pass +@pytest.mark.slow +def test_ada_python_parse(benchmark: Callable[[Any], Any], top100str: Iterable[str]): + def ada_python_parse(): + for line in top100str: + try: + ada_url.URL(line) + except ValueError: + # There are a small number of URLs in the sample data that are + # not valid WHATWG URLs. + pass + benchmark(ada_python_parse) -def can_ada_parse(): - for line in data(): - try: - can_ada.parse(line) - except ValueError: - # There are a small number of URLs in the sample data that are - # not valid WHATWG URLs. - pass +@pytest.mark.slow +def test_can_ada_parse(benchmark: Callable[[Any], Any], top100str: Iterable[str]): + def can_ada_parse(): + for line in top100str: + try: + can_ada.parse(line) + except ValueError: + # There are a small number of URLs in the sample data that are + # not valid WHATWG URLs. + pass -def yarl_parse(): - for line in data(): - try: - yarl.URL(line) - except ValueError: - # There are a small number of URLs in the sample data that are - # not valid WHATWG URLs. - pass + benchmark(can_ada_parse) @pytest.mark.slow -def test_urllib_parse(benchmark): - benchmark(urllib_parse) - +def test_yarl_parse(benchmark: Callable[[Any], Any], top100str: Iterable[str]): + def yarl_parse(): + for line in top100str: + try: + yarl.URL(line) + except ValueError: + # There are a small number of URLs in the sample data that are + # not valid WHATWG URLs. + pass -@pytest.mark.slow -def test_ada_python_parse(benchmark): - benchmark(ada_python_parse) + benchmark(yarl_parse) @pytest.mark.slow -def test_can_ada_parse(benchmark): - benchmark(can_ada_parse) +def test_can_ada_parse_compat_str( + benchmark: Callable[[Any], Any], top100str: Iterable[str] +): + def can_ada_parse_compat(): + for line in top100str: + try: + can_ada.parse_compat(line) + except ValueError: + # There are a small number of URLs in the sample data that are + # not valid WHATWG URLs. + pass + + benchmark(can_ada_parse_compat) @pytest.mark.slow -def test_yarl_parse(benchmark): - benchmark(yarl_parse) +def test_can_ada_parse_compat_bytes( + benchmark: Callable[[Any], Any], top100bytes: Iterable[str] +): + def can_ada_parse_compat(): + for line in top100bytes: + try: + can_ada.parse_compat(line) + except (ValueError, UnicodeDecodeError): + # There are a small number of URLs in the sample data that are + # not valid WHATWG URLs. + pass + + benchmark(can_ada_parse_compat) diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000..4fea8a2 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,15 @@ +import pytest +import can_ada +import urllib.parse + + +@pytest.mark.xfail(reason="parse_compat is not 100% urllib-compatible yet") +def test_urllib_parse_str_matches(top100str: list[str]): + for line in top100str: + assert urllib.parse.urlparse(line) == can_ada.parse_compat(line) + + +@pytest.mark.xfail(reason="parse_compat is not 100% urllib-compatible yet") +def test_urllib_parse_bytes_matches(top100bytes: list[bytes]): + for line in top100bytes: + assert urllib.parse.urlparse(line) == can_ada.parse_compat(line)