forked from shariqfarooq123/LooseControl
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathloosecontrol.py
More file actions
135 lines (110 loc) · 5.92 KB
/
Copy pathloosecontrol.py
File metadata and controls
135 lines (110 loc) · 5.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
import torch
import PIL
import PIL.Image
from diffusers.loaders import UNet2DConditionLoadersMixin
from typing import Dict
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
import functools
from cross_frame_attention import CrossFrameAttnProcessor
# Values that attach_loaders_mixin() assigns to a patched model as its
# `text_encoder_name` and `unet_name` attributes.
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
# Default value of the `negative_prompt` argument of LooseControlNet.edit().
NEGATIVE_PROMPT = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
def attach_loaders_mixin(model):
    r"""
    Attach the methods of [`UNet2DConditionLoadersMixin`] to a model instance.

    Hacky way to make ControlNet work with LoRA. This may not be required in
    future versions of diffusers.

    The model is patched in place: `text_encoder_name` and `unet_name` are set
    from the module constants, and every callable attribute defined directly on
    the mixin class is set on the instance.

    Args:
        model: The object to patch.

    Returns:
        The same `model` object, after patching.
    """
    model.text_encoder_name = TEXT_ENCODER_NAME
    model.unet_name = UNET_NAME

    for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items():
        if not callable(attr_value):
            continue
        if isinstance(attr_value, staticmethod):
            # staticmethod objects are callable on Python 3.10+; they take no
            # instance, so pre-binding `model` would shift every argument.
            setattr(model, attr_name, attr_value.__func__)
        else:
            # Pre-bind the instance so the attribute behaves like a bound method.
            setattr(model, attr_name, functools.partial(attr_value, model))
    return model
def set_attn_processor(module, processor, _remove_lora=False):
    """
    Set attention processors on every descendant of `module` that exposes a
    `set_processor` method.

    Args:
        module: Root module whose children are walked recursively. The root
            itself is not touched, only its descendants.
        processor: Either a single processor applied to every matching
            descendant, or a dict keyed by `"<dotted.child.path>.processor"`.
            When a dict is given, entries are popped from it as they are used.
        _remove_lora: Forwarded unchanged to each `set_processor` call.
    """
    per_layer = isinstance(processor, dict)

    def _apply(prefix: str, node: torch.nn.Module) -> None:
        if hasattr(node, "set_processor"):
            chosen = processor.pop(f"{prefix}.processor") if per_layer else processor
            node.set_processor(chosen, _remove_lora=_remove_lora)

        # Keep descending even below a node that accepted a processor.
        for child_name, child in node.named_children():
            _apply(f"{prefix}.{child_name}", child)

    for top_name, top_child in module.named_children():
        _apply(top_name, top_child)
class ControlNetX(ControlNetModel, UNet2DConditionLoadersMixin):
    """
    `ControlNetModel` combined with `UNet2DConditionLoadersMixin`, plus an
    `attn_processors` property.
    """

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    # This may not be required in future versions of diffusers.
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: one entry per descendant module that
            exposes `get_processor`, keyed by `"<dotted.module.path>.processor"`.
        """
        collected: Dict[str, AttentionProcessor] = {}

        def _collect(prefix: str, node: torch.nn.Module) -> None:
            if hasattr(node, "get_processor"):
                collected[f"{prefix}.processor"] = node.get_processor(return_deprecated_lora=True)

            for child_name, child in node.named_children():
                _collect(f"{prefix}.{child_name}", child)

        # Walk each direct child; the model itself is never queried.
        for top_name, top_child in self.named_children():
            _collect(top_name, top_child)

        return collected
class ControlNetPipeline:
    """
    Thin wrapper around a `StableDiffusionControlNetPipeline` stored in
    `self.pipe`.
    """

    def __init__(self, checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
        """
        Load the ControlNet and Stable Diffusion checkpoints.

        Args:
            checkpoint: Identifier passed to `ControlNetX.from_pretrained`.
            sd_checkpoint: Identifier passed to
                `StableDiffusionControlNetPipeline.from_pretrained`.
        """
        cn_model = ControlNetX.from_pretrained(checkpoint)
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_checkpoint,
            controlnet=cn_model,
            requires_safety_checker=False,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        self.pipe = pipeline
        # Replace the default scheduler, reusing its configuration.
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)

    @torch.no_grad()
    def __call__(self,
                 prompt: str="",
                 height=512,
                 width=512,
                 control_image=None,
                 controlnet_conditioning_scale=1.0,
                 num_inference_steps: int=20,
                 **kwargs) -> PIL.Image.Image:
        """
        Run the wrapped pipeline.

        `control_image` is handed to the pipeline as its second positional
        argument; any extra keyword arguments are forwarded unchanged.

        Returns:
            The single generated image when exactly one was produced,
            otherwise the whole collection of generated images.
        """
        result = self.pipe(
            prompt,
            control_image,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            **kwargs,
        )
        images = result.images
        if len(images) == 1:
            return images[0]
        return images

    def to(self, *args, **kwargs):
        """Forward the arguments to `self.pipe.to` and return this wrapper."""
        self.pipe.to(*args, **kwargs)
        return self
class LooseControlNet(ControlNetPipeline):
    """
    `ControlNetPipeline` whose ControlNet additionally loads LooseControl
    attention-processor weights.
    """

    def __init__(self, loose_control_weights="shariqfarooq/loose-control-3dbox", cn_checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
        """
        Build the base pipeline and load the LooseControl weights.

        Args:
            loose_control_weights: Identifier passed to the ControlNet's
                `load_attn_procs`.
            cn_checkpoint: ControlNet checkpoint for the base pipeline.
            sd_checkpoint: Stable Diffusion checkpoint for the base pipeline.
        """
        super().__init__(cn_checkpoint, sd_checkpoint)
        # Patch the ControlNet so that `load_attn_procs` can be called on it.
        self.pipe.controlnet = attach_loaders_mixin(self.pipe.controlnet)
        self.pipe.controlnet.load_attn_procs(loose_control_weights)

    def set_normal_attention(self):
        """Set a plain `AttnProcessor` on the whole UNet."""
        self.pipe.unet.set_attn_processor(AttnProcessor())

    def set_cf_attention(self, _remove_lora=False):
        """
        Set `CrossFrameAttnProcessor` on the last two UNet up blocks.

        Args:
            _remove_lora: Forwarded to `set_attn_processor`.
        """
        for upblocks in self.pipe.unet.up_blocks[-2:]:
            set_attn_processor(upblocks, CrossFrameAttnProcessor(), _remove_lora=_remove_lora)

    def edit(self, depth, depth_edit, prompt, prompt_edit=None, seed=42, seed_edit=None, negative_prompt=NEGATIVE_PROMPT, controlnet_conditioning_scale=1.0, num_inference_steps=20, **kwargs):
        """
        Generate an (original, edited) pair in one batch and return the edit.

        Args:
            depth: Control image for the original generation
                (presumably a depth map; confirm against callers).
            depth_edit: Control image for the edited generation.
            prompt: Prompt for the original generation.
            prompt_edit: Prompt for the edited generation; defaults to `prompt`.
            seed: Seed for the original generation.
            seed_edit: Seed for the edited generation; defaults to `seed`.
            negative_prompt: Negative prompt forwarded to the pipeline.
            controlnet_conditioning_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            The last element of the generated batch, i.e. the result for
            `depth_edit` / `prompt_edit`.
        """
        if prompt_edit is None:
            prompt_edit = prompt
        if seed_edit is None:
            seed_edit = seed

        seed = int(seed)
        seed_edit = int(seed_edit)

        control_image = [depth, depth_edit]
        prompt = [prompt, prompt_edit]
        generator = [
            torch.Generator().manual_seed(seed),
            torch.Generator().manual_seed(seed_edit),
        ]

        # Go through the wrapper's own __call__, not self.pipe directly: the
        # wrapper accepts `control_image` and returns the list of images,
        # so `[-1]` picks the edited image.
        gen = self(
            prompt,
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            negative_prompt=negative_prompt,
            **kwargs,
        )[-1]
        return gen