diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index b2564d25505e..0446d3f8e7a2 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -353,18 +353,20 @@ def _unwrap_model(model): def maybe_raise_or_warn( - library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module + library_name, + library, + class_name, + importable_classes, + passed_class_obj, + name, + is_pipeline_module, + cache_dir, ): """Simple helper method to raise or warn in case incorrect module has been passed""" if not is_pipeline_module: - library = importlib.import_module(library_name) - - # Handle deprecated Transformers classes - if library_name == "transformers": - class_name = _maybe_remap_transformers_class(class_name) or class_name - - class_obj = getattr(library, class_name) - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + class_obj, class_candidates = get_class_obj_and_candidates( + library_name, class_name, importable_classes, None, is_pipeline_module, name, cache_dir + ) expected_class_obj = None for class_name, class_candidate in class_candidates.items(): @@ -668,6 +670,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic passed_class_obj, name, is_pipeline_module, + kwargs.get("cached_folder", None), ) with accelerate.init_empty_weights(): loaded_sub_model = passed_class_obj[name] diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..579037c53731 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1035,7 +1035,14 @@ def load_module(name, value): # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class maybe_raise_or_warn( - library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module + library_name, + library, + class_name, + importable_classes, + passed_class_obj, + name, + is_pipeline_module, + cached_folder, ) loaded_sub_model = passed_class_obj[name]