Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,20 @@ def _unwrap_model(model):


def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
cache_dir,
):
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error do we get without this change?

It would raise here because it's not real library but custom component/module

and as you can see from def get_class_obj_and_candidates, it's the same logic but with added condition to get_class_from_dynamic_module — I think it should've used this util from the beginning, maybe the util was added after this piece of code


# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name

class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
class_obj, class_candidates = get_class_obj_and_candidates(
library_name, class_name, importable_classes, None, is_pipeline_module, name, cache_dir
)

expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
Expand Down Expand Up @@ -668,6 +670,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
passed_class_obj,
name,
is_pipeline_module,
kwargs.get("cached_folder", None),
)
with accelerate.init_empty_weights():
loaded_sub_model = passed_class_obj[name]
Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,14 @@ def load_module(name, value):
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
cached_folder,
)

loaded_sub_model = passed_class_obj[name]
Expand Down
Loading