Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/diffusers/utils/dynamic_modules_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from urllib import request

from huggingface_hub import hf_hub_download, model_info
from huggingface_hub.constants import HF_HUB_CACHE
from huggingface_hub.utils import RevisionNotFoundError, validate_hf_hub_args
from packaging import version

Expand Down Expand Up @@ -298,6 +299,7 @@ def get_cached_module_file(
"""
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
commit_hash = None

if subfolder is not None:
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
Expand All @@ -306,7 +308,18 @@ def get_cached_module_file(

if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
submodule = "local"
# When the local path is inside the HuggingFace Hub cache (e.g. a custom
# component downloaded as part of a whole pipeline via snapshot_download),
# extract the repo ID and commit hash so the submodule name and versioning
# match the behaviour of loading the component individually via AutoModel.
# HF cache layout: {hf_cache}/models--{org}--{repo}/snapshots/{hash}/…
hf_cache = str(cache_dir) if cache_dir is not None else HF_HUB_CACHE
hf_cache_prefix = os.path.join(hf_cache, "models--")
if module_file_or_url.startswith(hf_cache_prefix):
model_name, _, commit_hash, _ = module_file_or_url.replace(hf_cache_prefix, "").split(os.sep, 3)
submodule = os.path.join("local", model_name)
else:
submodule = "local"
elif pretrained_model_name_or_path.count("/") == 0:
available_versions = get_diffusers_versions()
# cut ".dev0"
Expand Down Expand Up @@ -395,7 +408,8 @@ def get_cached_module_file(
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
if commit_hash is None:
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
# benefit of versioning.
Expand Down
Loading