Skip to content

Commit 09ee6d8

Browse files
committed
update __init__ files to add sana
1 parent aefb869 commit 09ee6d8

7 files changed

Lines changed: 2457 additions & 103 deletions

File tree

src/diffusers/__init__.py

Lines changed: 107 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
is_transformers_available,
2424
)
2525

26-
2726
# Lazy Import based on
2827
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
2928

@@ -60,7 +59,11 @@
6059
}
6160

6261
try:
63-
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
62+
if (
63+
not is_torch_available()
64+
and not is_accelerate_available()
65+
and not is_bitsandbytes_available()
66+
):
6467
raise OptionalDependencyNotAvailable()
6568
except OptionalDependencyNotAvailable:
6669
from .utils import dummy_bitsandbytes_objects
@@ -72,7 +75,11 @@
7275
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
7376

7477
try:
75-
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
78+
if (
79+
not is_torch_available()
80+
and not is_accelerate_available()
81+
and not is_gguf_available()
82+
):
7683
raise OptionalDependencyNotAvailable()
7784
except OptionalDependencyNotAvailable:
7885
from .utils import dummy_gguf_objects
@@ -84,7 +91,11 @@
8491
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
8592

8693
try:
87-
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
94+
if (
95+
not is_torch_available()
96+
and not is_accelerate_available()
97+
and not is_torchao_available()
98+
):
8899
raise OptionalDependencyNotAvailable()
89100
except OptionalDependencyNotAvailable:
90101
from .utils import dummy_torchao_objects
@@ -96,7 +107,11 @@
96107
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
97108

98109
try:
99-
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
110+
if (
111+
not is_torch_available()
112+
and not is_accelerate_available()
113+
and not is_optimum_quanto_available()
114+
):
100115
raise OptionalDependencyNotAvailable()
101116
except OptionalDependencyNotAvailable:
102117
from .utils import dummy_optimum_quanto_objects
@@ -126,7 +141,9 @@
126141
except OptionalDependencyNotAvailable:
127142
from .utils import dummy_pt_objects # noqa F403
128143

129-
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
144+
_import_structure["utils.dummy_pt_objects"] = [
145+
name for name in dir(dummy_pt_objects) if not name.startswith("_")
146+
]
130147

131148
else:
132149
_import_structure["hooks"].extend(
@@ -187,6 +204,7 @@
187204
"OmniGenTransformer2DModel",
188205
"PixArtTransformer2DModel",
189206
"PriorTransformer",
207+
"SanaControlNetModel",
190208
"SanaTransformer2DModel",
191209
"SD3ControlNetModel",
192210
"SD3MultiControlNetModel",
@@ -303,11 +321,15 @@
303321
from .utils import dummy_torch_and_torchsde_objects # noqa F403
304322

305323
_import_structure["utils.dummy_torch_and_torchsde_objects"] = [
306-
name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
324+
name
325+
for name in dir(dummy_torch_and_torchsde_objects)
326+
if not name.startswith("_")
307327
]
308328

309329
else:
310-
_import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
330+
_import_structure["schedulers"].extend(
331+
["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]
332+
)
311333

312334
try:
313335
if not (is_torch_available() and is_transformers_available()):
@@ -316,7 +338,9 @@
316338
from .utils import dummy_torch_and_transformers_objects # noqa F403
317339

318340
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
319-
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
341+
name
342+
for name in dir(dummy_torch_and_transformers_objects)
343+
if not name.startswith("_")
320344
]
321345

322346
else:
@@ -424,6 +448,7 @@
424448
"PixArtSigmaPAGPipeline",
425449
"PixArtSigmaPipeline",
426450
"ReduxImageEncoder",
451+
"SanaControlNetPipeline",
427452
"SanaPAGPipeline",
428453
"SanaPipeline",
429454
"SanaSprintPipeline",
@@ -517,39 +542,63 @@
517542
)
518543

519544
try:
520-
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
545+
if not (
546+
is_torch_available()
547+
and is_transformers_available()
548+
and is_k_diffusion_available()
549+
):
521550
raise OptionalDependencyNotAvailable()
522551
except OptionalDependencyNotAvailable:
523552
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
524553

525554
_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
526-
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
555+
name
556+
for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects)
557+
if not name.startswith("_")
527558
]
528559

529560
else:
530-
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
561+
_import_structure["pipelines"].extend(
562+
["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]
563+
)
531564

532565
try:
533-
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
566+
if not (
567+
is_torch_available()
568+
and is_transformers_available()
569+
and is_sentencepiece_available()
570+
):
534571
raise OptionalDependencyNotAvailable()
535572
except OptionalDependencyNotAvailable:
536-
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
573+
from .utils import ( # noqa F403
574+
dummy_torch_and_transformers_and_sentencepiece_objects,
575+
)
537576

538-
_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
539-
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
577+
_import_structure[
578+
"utils.dummy_torch_and_transformers_and_sentencepiece_objects"
579+
] = [
580+
name
581+
for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects)
582+
if not name.startswith("_")
540583
]
541584

542585
else:
543-
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
586+
_import_structure["pipelines"].extend(
587+
["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"]
588+
)
544589

545590
try:
546-
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
591+
if not (
592+
is_torch_available() and is_transformers_available() and is_onnx_available()
593+
):
547594
raise OptionalDependencyNotAvailable()
548595
except OptionalDependencyNotAvailable:
549596
from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
550597

551598
_import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
552-
name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
599+
name
600+
for name in dir(dummy_torch_and_transformers_and_onnx_objects)
601+
if not name.startswith("_")
553602
]
554603

555604
else:
@@ -571,20 +620,26 @@
571620
from .utils import dummy_torch_and_librosa_objects # noqa F403
572621

573622
_import_structure["utils.dummy_torch_and_librosa_objects"] = [
574-
name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
623+
name
624+
for name in dir(dummy_torch_and_librosa_objects)
625+
if not name.startswith("_")
575626
]
576627

577628
else:
578629
_import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
579630

580631
try:
581-
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
632+
if not (
633+
is_transformers_available() and is_torch_available() and is_note_seq_available()
634+
):
582635
raise OptionalDependencyNotAvailable()
583636
except OptionalDependencyNotAvailable:
584637
from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
585638

586639
_import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
587-
name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
640+
name
641+
for name in dir(dummy_transformers_and_torch_and_note_seq_objects)
642+
if not name.startswith("_")
588643
]
589644

590645

@@ -605,7 +660,9 @@
605660
else:
606661
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
607662
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
608-
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
663+
_import_structure["models.unets.unet_2d_condition_flax"] = [
664+
"FlaxUNet2DConditionModel"
665+
]
609666
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
610667
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
611668
_import_structure["schedulers"].extend(
@@ -630,7 +687,9 @@
630687
from .utils import dummy_flax_and_transformers_objects # noqa F403
631688

632689
_import_structure["utils.dummy_flax_and_transformers_objects"] = [
633-
name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
690+
name
691+
for name in dir(dummy_flax_and_transformers_objects)
692+
if not name.startswith("_")
634693
]
635694

636695

@@ -763,6 +822,7 @@
763822
OmniGenTransformer2DModel,
764823
PixArtTransformer2DModel,
765824
PriorTransformer,
825+
SanaControlNetModel,
766826
SanaTransformer2DModel,
767827
SD3ControlNetModel,
768828
SD3MultiControlNetModel,
@@ -979,6 +1039,7 @@
9791039
PixArtSigmaPAGPipeline,
9801040
PixArtSigmaPipeline,
9811041
ReduxImageEncoder,
1042+
SanaControlNetPipeline,
9821043
SanaPAGPipeline,
9831044
SanaPipeline,
9841045
SanaSprintPipeline,
@@ -1070,22 +1131,35 @@
10701131
)
10711132

10721133
try:
1073-
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
1134+
if not (
1135+
is_torch_available()
1136+
and is_transformers_available()
1137+
and is_k_diffusion_available()
1138+
):
10741139
raise OptionalDependencyNotAvailable()
10751140
except OptionalDependencyNotAvailable:
10761141
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
10771142
else:
1078-
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
1143+
from .pipelines import (
1144+
StableDiffusionKDiffusionPipeline,
1145+
StableDiffusionXLKDiffusionPipeline,
1146+
)
10791147

10801148
try:
1081-
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
1149+
if not (
1150+
is_torch_available()
1151+
and is_transformers_available()
1152+
and is_sentencepiece_available()
1153+
):
10821154
raise OptionalDependencyNotAvailable()
10831155
except OptionalDependencyNotAvailable:
10841156
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
10851157
else:
10861158
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
10871159
try:
1088-
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
1160+
if not (
1161+
is_torch_available() and is_transformers_available() and is_onnx_available()
1162+
):
10891163
raise OptionalDependencyNotAvailable()
10901164
except OptionalDependencyNotAvailable:
10911165
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
@@ -1108,7 +1182,11 @@
11081182
from .pipelines import AudioDiffusionPipeline, Mel
11091183

11101184
try:
1111-
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
1185+
if not (
1186+
is_transformers_available()
1187+
and is_torch_available()
1188+
and is_note_seq_available()
1189+
):
11121190
raise OptionalDependencyNotAvailable()
11131191
except OptionalDependencyNotAvailable:
11141192
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403

0 commit comments

Comments
 (0)