diff --git a/dfetch/project/archivesubproject.py b/dfetch/project/archivesubproject.py index 91a684b9..86336214 100644 --- a/dfetch/project/archivesubproject.py +++ b/dfetch/project/archivesubproject.py @@ -41,15 +41,14 @@ from __future__ import annotations -import os import pathlib -import tempfile from dfetch.log import get_logger from dfetch.manifest.project import ProjectEntry from dfetch.manifest.version import Version from dfetch.project.metadata import Dependency from dfetch.project.subproject import SubProject +from dfetch.util.util import temp_file from dfetch.vcs.archive import ( ARCHIVE_EXTENSIONS, ArchiveLocalRepo, @@ -126,16 +125,9 @@ def _download_and_compute_hash( """ effective_url = url if url is not None else self.remote remote = ArchiveRemote(effective_url) if url is not None else self._remote_repo - fd, tmp_path = tempfile.mkstemp(suffix=_suffix_for_url(effective_url)) - os.close(fd) - try: + with temp_file(_suffix_for_url(effective_url)) as tmp_path: hex_digest = remote.download(tmp_path, algorithm=algorithm) return IntegrityHash(algorithm, hex_digest) - finally: - try: - os.remove(tmp_path) - except OSError: - pass def _does_revision_exist(self, revision: str) -> bool: # noqa: ARG002 """Check whether the archive URL is still reachable. @@ -182,9 +174,7 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]: pathlib.Path(self.local_path).mkdir(parents=True, exist_ok=True) - fd, tmp_path = tempfile.mkstemp(suffix=_suffix_for_url(self.remote)) - os.close(fd) - try: + with temp_file(_suffix_for_url(self.remote)) as tmp_path: expected = IntegrityHash.parse(revision) if expected: actual_hex = self._remote_repo.download( @@ -204,11 +194,6 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]: src=self.source, ignore=self.ignore, ) - finally: - try: - os.remove(tmp_path) - except OSError: - pass return version, [] diff --git a/dfetch/project/gitsubproject.py b/dfetch/project/gitsubproject.py index 757027b4..21ee85d3 100644 --- a/dfetch/project/gitsubproject.py +++ b/dfetch/project/gitsubproject.py @@ -9,7 +9,7 @@ from dfetch.project.metadata import Dependency from dfetch.project.subproject import SubProject from dfetch.util.util import LICENSE_GLOBS, safe_rm -from dfetch.vcs.git import GitLocalRepo, GitRemote, get_git_version +from dfetch.vcs.git import CheckoutOptions, GitLocalRepo, GitRemote, get_git_version logger = get_logger(__name__) @@ -70,11 +70,13 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]: local_repo = GitLocalRepo(self.local_path) fetched_sha, submodules = local_repo.checkout_version( - remote=self.remote, - version=rev_or_branch_or_tag, - src=self.source, - must_keeps=license_globs + [".gitmodules"], - ignore=self.ignore, + CheckoutOptions( + remote=self.remote, + version=rev_or_branch_or_tag, + src=self.source, + must_keeps=license_globs + [".gitmodules"], + ignore=self.ignore, + ) ) vcs_deps = [] diff --git a/dfetch/project/svnsubproject.py b/dfetch/project/svnsubproject.py index 6333856e..1ec167f8 100644 --- a/dfetch/project/svnsubproject.py +++ b/dfetch/project/svnsubproject.py @@ -136,15 +136,14 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]: if self.source: root_branch_path = "/".join([self.remote, branch_path]).strip("/") - - for file in SvnSubProject._license_files(root_branch_path): + license_files = SvnSubProject._license_files(root_branch_path) + if license_files: dest = ( self.local_path if os.path.isdir(self.local_path) else os.path.dirname(self.local_path) ) - SvnRepo.export(f"{root_branch_path}/{file}", rev_arg, dest) - break + SvnRepo.export(f"{root_branch_path}/{license_files[0]}", rev_arg, dest) if self.ignore: self._remove_ignored_files() diff --git a/dfetch/util/util.py b/dfetch/util/util.py index 01452655..b3a6dfdd 100644 --- a/dfetch/util/util.py +++ b/dfetch/util/util.py @@ -5,6 +5,7 @@ import os import shutil import stat +import tempfile from collections.abc import Generator, Iterator, Sequence from contextlib import contextmanager from pathlib import Path, PurePath @@ -346,3 +347,44 @@ def resolve_absolute_path(path: str | Path) -> Path: - Handles Windows drive-relative paths and expands '~'. """ return Path(os.path.realpath(Path(path).expanduser())) + + +@contextmanager +def temp_file(suffix: str = "") -> Generator[str, None, None]: + """Create a temporary file, yield its path, and always delete it on exit.""" + fd, tmp_path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + try: + yield tmp_path + finally: + try: + os.remove(tmp_path) + except OSError: + pass + + +def unique_parent_dirs(paths: list[str]) -> list[str]: + """Return the unique parent directories for a list of paths, preserving order. + + For each path that is a directory, the path itself is used. For a file, + its parent directory is used. Root-level files (no parent) are skipped. + Duplicates are removed while preserving the first-seen order. + """ + dirs: list[str] = [] + for path in paths: + if os.path.isdir(path): + dirs.append(path) + else: + parent = os.path.dirname(path) + if parent: + dirs.append(parent) + return list(dict.fromkeys(dirs)) + + +def move_directory_contents(src_dir: str, dest_dir: str) -> None: + """Move every entry in *src_dir* directly into *dest_dir*. + + Complements :func:`copy_directory_contents`. + """ + for entry in os.listdir(src_dir): + shutil.move(os.path.join(src_dir, entry), dest_dir) diff --git a/dfetch/vcs/archive.py b/dfetch/vcs/archive.py index a8caed9a..9cbe513b 100644 --- a/dfetch/vcs/archive.py +++ b/dfetch/vcs/archive.py @@ -200,6 +200,19 @@ def download(self, dest_path: str, algorithm: str | None = None) -> str | None: _MAX_REDIRECTS = 10 + @staticmethod + def _stream_response_to_file( + resp: http.client.HTTPResponse, + dest_path: str, + hasher: hashlib._Hash | None, + ) -> None: + """Write the response body to *dest_path*, updating *hasher* for each chunk.""" + with open(dest_path, "wb") as fh: + while chunk := resp.read(65536): + fh.write(chunk) + if hasher: + hasher.update(chunk) + def _http_download( self, parsed: urllib.parse.ParseResult, @@ -232,11 +245,7 @@ def _http_download( raise RuntimeError( f"HTTP {resp.status} when downloading '{self.url}'" ) - with open(dest_path, "wb") as fh: - while chunk := resp.read(65536): - fh.write(chunk) - if hasher: - hasher.update(chunk) + self._stream_response_to_file(resp, dest_path, hasher) return except (OSError, http.client.HTTPException) as exc: raise RuntimeError( diff --git a/dfetch/vcs/git.py b/dfetch/vcs/git.py index d803cc62..a9be2869 100644 --- a/dfetch/vcs/git.py +++ b/dfetch/vcs/git.py @@ -4,15 +4,21 @@ import glob import os import re -import shutil import tempfile from collections.abc import Generator, Sequence from dataclasses import dataclass -from pathlib import Path, PurePath +from pathlib import Path from dfetch.log import get_logger from dfetch.util.cmdline import SubprocessCommandError, run_on_cmdline -from dfetch.util.util import in_directory, is_license_file, safe_rm, strip_glob_prefix +from dfetch.util.util import ( + in_directory, + is_license_file, + move_directory_contents, + safe_rm, + strip_glob_prefix, + unique_parent_dirs, +) from dfetch.vcs.patch import Patch, PatchType logger = get_logger(__name__) @@ -31,6 +37,17 @@ class Submodule: tag: str +@dataclass +class CheckoutOptions: + """Options for checking out a specific version from a remote git repository.""" + + remote: str + version: str + src: str | None = None + must_keeps: list[str] | None = None + ignore: Sequence[str] | None = None + + def get_git_version() -> tuple[str, str]: """Get the name and version of git.""" result = run_on_cmdline(logger, ["git", "--version"]) @@ -270,35 +287,28 @@ def _configure_sparse_checkout( f.write("\n".join(map(str, patterns)) + "\n") - def checkout_version( # pylint: disable=too-many-arguments + def checkout_version( self, - *, - remote: str, - version: str, - src: str | None = None, - must_keeps: list[str] | None = None, - ignore: Sequence[str] | None = None, + options: CheckoutOptions, ) -> tuple[str, list[Submodule]]: """Checkout a specific version from a given remote. Args: - remote (str): Url or path to a remote git repository - version (str): A target to checkout, can be branch, tag or sha - src (Optional[str]): Optional path to subdirectory or file in repo - must_keeps (Optional[List[str]]): Optional list of glob patterns to keep - ignore (Optional[Sequence[str]]): Optional sequence of glob patterns to ignore (relative to src) + options: A :class:`CheckoutOptions` instance describing what to fetch. """ with in_directory(self._path): run_on_cmdline(logger, ["git", "init"]) - run_on_cmdline(logger, ["git", "remote", "add", "origin", remote]) + run_on_cmdline(logger, ["git", "remote", "add", "origin", options.remote]) run_on_cmdline(logger, ["git", "checkout", "-b", "dfetch-local-branch"]) - if src or ignore: - self._configure_sparse_checkout(src, must_keeps or [], ignore) + if options.src or options.ignore: + self._configure_sparse_checkout( + options.src, options.must_keeps or [], options.ignore + ) run_on_cmdline( logger, - ["git", "fetch", "--depth", "1", "origin", version], + ["git", "fetch", "--depth", "1", "origin", options.version], env=_extend_env_for_non_interactive_mode(), ) run_on_cmdline(logger, ["git", "reset", "--hard", "FETCH_HEAD"]) @@ -317,7 +327,9 @@ def checkout_version( # pylint: disable=too-many-arguments .strip() ) - submodules = self._apply_src_and_ignore(remote, src, ignore, submodules) + submodules = self._apply_src_and_ignore( + options.remote, options.src, options.ignore, submodules + ) return str(current_sha), submodules @@ -352,40 +364,50 @@ def _move_src_folder_up(remote: str, src: str) -> None: remote (str): Name of the root src (str): Src folder to move up """ - matched_paths = sorted(glob.glob(src)) - - if not matched_paths: + if os.path.isabs(src): logger.warning( - f"The 'src:' filter '{src}' didn't match any files from '{remote}'" + f"The 'src:' filter '{src}' is an absolute path; skipping for '{remote}'" ) return - dirs = [] - for src_dir_path in matched_paths: - if os.path.isdir(src_dir_path): - dirs.append(src_dir_path) + repo_root = Path(os.getcwd()).resolve() + safe_matched: list[str] = [] + for p in sorted(glob.glob(src)): + if Path(p).resolve().is_relative_to(repo_root): + safe_matched.append(p) else: - if dir_path := os.path.dirname(src_dir_path): - dirs.append(dir_path) + logger.warning( + f"The 'src:' filter '{src}' matched '{p}' outside the repo root; skipping" + ) - unique_dirs = list(dict.fromkeys(dirs)) + if not safe_matched: + logger.warning( + f"The 'src:' filter '{src}' didn't match any files from '{remote}'" + ) + return - if len(unique_dirs) > 1: + # Resolve to canonical absolute paths so downstream steps use stable paths + # regardless of any '..' components in the original glob results. + resolved_dirs = [Path(d).resolve() for d in unique_parent_dirs(safe_matched)] + + if len(resolved_dirs) > 1: + display = resolved_dirs[0].relative_to(repo_root) logger.warning( f"The 'src:' filter '{src}' matches multiple directories from '{remote}'. " - f"Only considering files in '{unique_dirs[0]}'." + f"Only considering files in '{display}'." ) - for src_dir_path in unique_dirs[:1]: + if resolved_dirs: + chosen = resolved_dirs[0] try: - for file_to_copy in os.listdir(src_dir_path): - shutil.move(src_dir_path + "/" + file_to_copy, ".") - safe_rm(PurePath(src_dir_path).parts[0]) + move_directory_contents(str(chosen), ".") + parts = chosen.relative_to(repo_root).parts + if parts: + safe_rm(repo_root / parts[0], within=repo_root) except FileNotFoundError: logger.warning( - f"The 'src:' filter '{src_dir_path}' didn't match any files from '{remote}'" + f"The 'src:' filter '{chosen}' didn't match any files from '{remote}'" ) - continue @staticmethod def _determine_ignore_paths( @@ -435,6 +457,20 @@ def get_remote_url() -> str: return decoded_result + @staticmethod + def _build_hash_args(old_hash: str | None, new_hash: str | None) -> list[str]: + """Return the SHA positional arguments for git diff (zero, one, or two hashes).""" + if not old_hash: + return [] + return [old_hash, new_hash] if new_hash else [old_hash] + + @staticmethod + def _build_ignore_args(ignore: Sequence[str] | None) -> list[str]: + """Return git-diff pathspec arguments that exclude each pattern in *ignore*.""" + if not ignore: + return [] + return ["--", "."] + [f":(exclude){p}" for p in ignore] + def create_diff( self, old_hash: str | None, @@ -456,15 +492,9 @@ def create_diff( if reverse: cmd.extend(["-R", "--src-prefix=b/", "--dst-prefix=a/"]) - if old_hash: - cmd.append(old_hash) - if new_hash: - cmd.append(new_hash) + cmd.extend(GitLocalRepo._build_hash_args(old_hash, new_hash)) + cmd.extend(GitLocalRepo._build_ignore_args(ignore)) - if ignore: - cmd.extend(["--", "."]) - for ignore_path in ignore: - cmd.append(f":(exclude){ignore_path}") result = run_on_cmdline(logger, cmd) return str(result.stdout.decode()) @@ -494,7 +524,7 @@ def ignored_files(path: str) -> Sequence[str]: @staticmethod def any_changes_or_untracked(path: str) -> bool: - """List of any changed files.""" + """Return True if the repo at *path* has any changed or untracked files.""" if not Path(path).exists(): raise RuntimeError("Path does not exist.") @@ -646,26 +676,26 @@ def find_branch_containing_sha(self, sha: str) -> str: return "" if not branches else branches[0] - def get_username(self) -> str: - """Get the username of the local git repo.""" + def _get_git_config_value(self, key: str) -> str: + """Read a single git config value from the local repo. + + Args: + key: The git config key to query (e.g. ``user.name``). + + Returns: + The stripped config value, or an empty string if the key is absent. + """ try: with in_directory(self._path): - result = run_on_cmdline( - logger, - ["git", "config", "user.name"], - ) + result = run_on_cmdline(logger, ["git", "config", key]) return str(result.stdout.decode().strip()) except SubprocessCommandError: return "" + def get_username(self) -> str: + """Get the username of the local git repo.""" + return self._get_git_config_value("user.name") + def get_useremail(self) -> str: """Get the user email of the local git repo.""" - try: - with in_directory(self._path): - result = run_on_cmdline( - logger, - ["git", "config", "user.email"], - ) - return str(result.stdout.decode().strip()) - except SubprocessCommandError: - return "" + return self._get_git_config_value("user.email") diff --git a/features/steps/git_steps.py b/features/steps/git_steps.py index b0ea4a27..7254e07d 100644 --- a/features/steps/git_steps.py +++ b/features/steps/git_steps.py @@ -26,6 +26,7 @@ def create_repo(): subprocess.check_call(["git", "config", "user.email", "you@example.com"]) subprocess.check_call(["git", "config", "user.name", "John Doe"]) + subprocess.check_call(["git", "config", "commit.gpgsign", "false"]) if os.name == "nt": # Creates zombie fsmonitor-daemon process that holds files diff --git a/tests/test_git_vcs.py b/tests/test_git_vcs.py index 0089dcd4..d5ae3203 100644 --- a/tests/test_git_vcs.py +++ b/tests/test_git_vcs.py @@ -10,12 +10,99 @@ import pytest from dfetch.util.cmdline import SubprocessCommandError +from dfetch.util.util import unique_parent_dirs from dfetch.vcs.git import ( GitLocalRepo, GitRemote, _build_git_ssh_command, ) +# --------------------------------------------------------------------------- +# unique_parent_dirs (dfetch.util.util) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, paths, isdir_results, expected", + [ + ("empty input", [], [], []), + ("all directories", ["a", "b"], [True, True], ["a", "b"]), + ("all files with parent", ["a/x.c", "b/y.c"], [False, False], ["a", "b"]), + ("file at root — no parent dir", ["x.c"], [False], []), + ("mixed dir and file in same parent", ["a", "a/y.c"], [True, False], ["a"]), + ("deduplication preserves order", ["a/x.c", "a/y.c"], [False, False], ["a"]), + ], +) +def test_unique_parent_dirs(name, paths, isdir_results, expected): + with patch("dfetch.util.util.os.path.isdir", side_effect=isdir_results): + assert unique_parent_dirs(paths) == expected + + +# --------------------------------------------------------------------------- +# GitLocalRepo._build_hash_args +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, old_hash, new_hash, expected", + [ + ("no hashes", None, None, []), + ("old hash only", "abc123", None, ["abc123"]), + ("both hashes", "abc123", "def456", ["abc123", "def456"]), + ("new hash without old hash is ignored", None, "def456", []), + ], +) +def test_build_hash_args(name, old_hash, new_hash, expected): + assert GitLocalRepo._build_hash_args(old_hash, new_hash) == expected + + +# --------------------------------------------------------------------------- +# GitLocalRepo._build_ignore_args +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, ignore, expected", + [ + ("None", None, []), + ("empty sequence", [], []), + ("single pattern", ["*.txt"], ["--", ".", ":(exclude)*.txt"]), + ( + "multiple patterns", + ["*.txt", "build/"], + ["--", ".", ":(exclude)*.txt", ":(exclude)build/"], + ), + ], +) +def test_build_ignore_args(name, ignore, expected): + assert GitLocalRepo._build_ignore_args(ignore) == expected + + +# --------------------------------------------------------------------------- +# GitLocalRepo._move_src_folder_up — path-traversal guards +# --------------------------------------------------------------------------- + + +def test_move_src_folder_up_rejects_absolute_src(tmp_path): + """An absolute src pattern must be rejected without touching the filesystem.""" + with patch("dfetch.vcs.git.move_directory_contents") as mock_move: + with patch("dfetch.vcs.git.os.getcwd", return_value=str(tmp_path)): + GitLocalRepo._move_src_folder_up("my-remote", "/etc") + mock_move.assert_not_called() + + +def test_move_src_folder_up_rejects_traversal_src(tmp_path): + """A src pattern that resolves outside the repo root must be skipped.""" + outside = tmp_path.parent / "outside" + outside.mkdir() + (outside / "secret.txt").write_text("data") + + with patch("dfetch.vcs.git.move_directory_contents") as mock_move: + with patch("dfetch.vcs.git.glob.glob", return_value=[str(outside)]): + with patch("dfetch.vcs.git.os.getcwd", return_value=str(tmp_path)): + GitLocalRepo._move_src_folder_up("my-remote", "../outside") + mock_move.assert_not_called() + @pytest.mark.parametrize( "name, cmd_result, expectation",