diff --git a/tests/test_api.py b/tests/test_api.py index c5bc0e4..911b96c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -48,3 +48,39 @@ def test_retrieve(self, url: str, filename: str, is_tar: bool): file_count = len([f for f in pathlib.Path(tmpdir).iterdir() if f.is_file()]) self.assertGreater(file_count, 0) + + def _make_tarball_with_traversal(self, directory: str, filename: str, member_name: str) -> str: + """Create a tarball fixture whose sole member uses a path-traversal name.""" + tarball_path = str(pathlib.Path(directory) / filename) + inner_content = b"malicious content" + buf = io.BytesIO(inner_content) + buf.seek(0) + with tarfile.open(tarball_path, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size = len(inner_content) + tar.addfile(info, buf) + return tarball_path + + def test_retrieve_path_traversal_raises(self): + """A tar member that escapes the target directory must raise ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + fake_path = self._make_tarball_with_traversal( + tmpdir, "traversal.tar", "../malicious.txt" + ) + with patch( + "pyiri2016.api.update.wget.download", return_value=fake_path + ): + with self.assertRaises(ValueError): + update.retrieve("http://example.com", "traversal.tar", directory=tmpdir) + + def test_retrieve_absolute_path_in_tar_raises(self): + """A tar member with an absolute path that escapes the target directory must raise ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + fake_path = self._make_tarball_with_traversal( + tmpdir, "absolute.tar", "/etc/passwd" + ) + with patch( + "pyiri2016.api.update.wget.download", return_value=fake_path + ): + with self.assertRaises(ValueError): + update.retrieve("http://example.com", "absolute.tar", directory=tmpdir)