diff --git a/morgan/__init__.py b/morgan/__init__.py index 755fed9..2cf6a33 100644 --- a/morgan/__init__.py +++ b/morgan/__init__.py @@ -44,6 +44,7 @@ def __init__(self, args: argparse.Namespace): self.index_path = args.index_path self.index_url = args.index_url self.mirror_all_versions: bool = args.mirror_all_versions + self.package_type_regex: str = args.package_type_regex self.config = configparser.ConfigParser() self.config.read(args.config) self.envs = {} @@ -201,9 +202,10 @@ def _filter_files( files: Iterable[dict], ) -> Iterable[dict]: # remove files with unsupported extensions + pattern: str = rf"\.{self.package_type_regex}$" files = list( filter( - lambda file: re.search(r"\.(whl|zip|tar.gz)$", file["filename"]), files + lambda file: re.search(pattern, file["filename"]), files ) ) @@ -574,6 +576,13 @@ def my_url(arg): "Transitive dependencies still mirror only the latest matching release. " "(Default: only the latest matching release)" ), + ), + parser.add_argument( + "--package-type-regex", + dest="package_type_regex", + default=r"(whl|zip|tar\.gz)", + type=str, + help="Regular expression to filter which package file types are mirrored", ) server.add_arguments(parser) diff --git a/tests/test_init.py b/tests/test_init.py index 7a164c6..96d89b7 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -87,6 +87,7 @@ def test_mirrorer_initialization(self, temp_index_path): index_url="https://pypi.org/simple/", config=os.path.join(temp_index_path, "morgan.ini"), mirror_all_versions=False, + package_type_regex="(whl|zip|tar.gz)", ) mirrorer = Mirrorer(args) @@ -105,6 +106,7 @@ def test_server_file_copying(self, temp_index_path): index_url=PYPI_ADDRESS, config=os.path.join(temp_index_path, "morgan.ini"), mirror_all_versions=False, + package_type_regex="(whl|zip|tar.gz)", ) mirrorer = Mirrorer(args) @@ -128,6 +130,7 @@ def test_file_hashing(self, temp_index_path): index_url=PYPI_ADDRESS, config=os.path.join(temp_index_path, "morgan.ini"), mirror_all_versions=False, + package_type_regex="(whl|zip|tar.gz)", ) mirrorer = Mirrorer(args) @@ -176,6 +179,7 @@ def _make_mirrorer(mirror_all_versions): index_url="https://example.com/simple", config=os.path.join(temp_index_path, "morgan.ini"), mirror_all_versions=mirror_all_versions, + package_type_regex=r"(whl|zip|tar\.gz)" ) return Mirrorer(args) @@ -223,7 +227,9 @@ def test_filter_files_with_all_versions_mirrored( self, make_mirrorer, sample_files, version_spec, expected_versions ): """Test that file filtering correctly handles different version specifications.""" - mirrorer = make_mirrorer(mirror_all_versions=True) + mirrorer = make_mirrorer( + mirror_all_versions=True, + ) requirement = packaging.requirements.Requirement( f"sample_package{version_spec}" )