Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 3 additions & 18 deletions dfetch/project/archivesubproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@

from __future__ import annotations

import os
import pathlib
import tempfile

from dfetch.log import get_logger
from dfetch.manifest.project import ProjectEntry
from dfetch.manifest.version import Version
from dfetch.project.metadata import Dependency
from dfetch.project.subproject import SubProject
from dfetch.util.util import temp_file
from dfetch.vcs.archive import (
ARCHIVE_EXTENSIONS,
ArchiveLocalRepo,
Expand Down Expand Up @@ -126,16 +125,9 @@ def _download_and_compute_hash(
"""
effective_url = url if url is not None else self.remote
remote = ArchiveRemote(effective_url) if url is not None else self._remote_repo
fd, tmp_path = tempfile.mkstemp(suffix=_suffix_for_url(effective_url))
os.close(fd)
try:
with temp_file(_suffix_for_url(effective_url)) as tmp_path:
hex_digest = remote.download(tmp_path, algorithm=algorithm)
return IntegrityHash(algorithm, hex_digest)
finally:
try:
os.remove(tmp_path)
except OSError:
pass

def _does_revision_exist(self, revision: str) -> bool: # noqa: ARG002
"""Check whether the archive URL is still reachable.
Expand Down Expand Up @@ -182,9 +174,7 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]:

pathlib.Path(self.local_path).mkdir(parents=True, exist_ok=True)

fd, tmp_path = tempfile.mkstemp(suffix=_suffix_for_url(self.remote))
os.close(fd)
try:
with temp_file(_suffix_for_url(self.remote)) as tmp_path:
expected = IntegrityHash.parse(revision)
if expected:
actual_hex = self._remote_repo.download(
Expand All @@ -204,11 +194,6 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]:
src=self.source,
ignore=self.ignore,
)
finally:
try:
os.remove(tmp_path)
except OSError:
pass

return version, []

Expand Down
14 changes: 8 additions & 6 deletions dfetch/project/gitsubproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dfetch.project.metadata import Dependency
from dfetch.project.subproject import SubProject
from dfetch.util.util import LICENSE_GLOBS, safe_rm
from dfetch.vcs.git import GitLocalRepo, GitRemote, get_git_version
from dfetch.vcs.git import CheckoutOptions, GitLocalRepo, GitRemote, get_git_version

logger = get_logger(__name__)

Expand Down Expand Up @@ -70,11 +70,13 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]:

local_repo = GitLocalRepo(self.local_path)
fetched_sha, submodules = local_repo.checkout_version(
remote=self.remote,
version=rev_or_branch_or_tag,
src=self.source,
must_keeps=license_globs + [".gitmodules"],
ignore=self.ignore,
CheckoutOptions(
remote=self.remote,
version=rev_or_branch_or_tag,
src=self.source,
must_keeps=license_globs + [".gitmodules"],
ignore=self.ignore,
)
)

vcs_deps = []
Expand Down
7 changes: 3 additions & 4 deletions dfetch/project/svnsubproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,14 @@ def _fetch_impl(self, version: Version) -> tuple[Version, list[Dependency]]:

if self.source:
root_branch_path = "/".join([self.remote, branch_path]).strip("/")

for file in SvnSubProject._license_files(root_branch_path):
license_files = SvnSubProject._license_files(root_branch_path)
if license_files:
dest = (
self.local_path
if os.path.isdir(self.local_path)
else os.path.dirname(self.local_path)
)
SvnRepo.export(f"{root_branch_path}/{file}", rev_arg, dest)
break
SvnRepo.export(f"{root_branch_path}/{license_files[0]}", rev_arg, dest)

if self.ignore:
self._remove_ignored_files()
Expand Down
42 changes: 42 additions & 0 deletions dfetch/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import shutil
import stat
import tempfile
from collections.abc import Generator, Iterator, Sequence
from contextlib import contextmanager
from pathlib import Path, PurePath
Expand Down Expand Up @@ -346,3 +347,44 @@ def resolve_absolute_path(path: str | Path) -> Path:
- Handles Windows drive-relative paths and expands '~'.
"""
return Path(os.path.realpath(Path(path).expanduser()))


@contextmanager
def temp_file(suffix: str = "") -> Generator[str, None, None]:
    """Yield the path of a freshly created temporary file.

    The file exists on disk for the duration of the ``with`` body and is
    removed afterwards, even if the body raises. Removal failures are
    swallowed (best-effort cleanup).
    """
    handle, path = tempfile.mkstemp(suffix=suffix)
    # The caller only needs the path, so release the OS-level handle now.
    os.close(handle)
    try:
        yield path
    finally:
        try:
            os.unlink(path)
        except OSError:  # already gone or locked: best-effort cleanup
            pass


def unique_parent_dirs(paths: list[str]) -> list[str]:
    """Return the unique parent directories for a list of paths, preserving order.

    A path that is itself a directory contributes itself; a file path
    contributes its parent directory. Root-level files (empty parent) are
    skipped, and duplicates keep only their first occurrence.
    """
    seen: set[str] = set()
    ordered: list[str] = []
    for candidate in paths:
        directory = (
            candidate if os.path.isdir(candidate) else os.path.dirname(candidate)
        )
        # Skip root-level files (empty parent) and anything already recorded.
        if directory and directory not in seen:
            seen.add(directory)
            ordered.append(directory)
    return ordered


def move_directory_contents(src_dir: str, dest_dir: str) -> None:
    """Move every entry in *src_dir* directly into *dest_dir*.

    Only the top-level entries of *src_dir* are relocated; directories are
    moved whole. Complements :func:`copy_directory_contents`.
    """
    with os.scandir(src_dir) as entries:
        for entry in entries:
            shutil.move(entry.path, dest_dir)
19 changes: 14 additions & 5 deletions dfetch/vcs/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,19 @@ def download(self, dest_path: str, algorithm: str | None = None) -> str | None:

_MAX_REDIRECTS = 10

@staticmethod
def _stream_response_to_file(
resp: http.client.HTTPResponse,
dest_path: str,
hasher: hashlib._Hash | None,
) -> None:
"""Write the response body to *dest_path*, updating *hasher* for each chunk."""
with open(dest_path, "wb") as fh:
while chunk := resp.read(65536):
fh.write(chunk)
if hasher:
hasher.update(chunk)

def _http_download(
self,
parsed: urllib.parse.ParseResult,
Expand Down Expand Up @@ -232,11 +245,7 @@ def _http_download(
raise RuntimeError(
f"HTTP {resp.status} when downloading '{self.url}'"
)
with open(dest_path, "wb") as fh:
while chunk := resp.read(65536):
fh.write(chunk)
if hasher:
hasher.update(chunk)
self._stream_response_to_file(resp, dest_path, hasher)
return
except (OSError, http.client.HTTPException) as exc:
raise RuntimeError(
Expand Down
Loading
Loading