# Per-timestep loop over `timesteps`, wrapped in the pipeline's progress bar.
# Steps with i < HB_replace call the transformer with the HB_* regional
# arguments; steps with i >= HB_replace call it without them. The SR_* entries
# of the joint-attention kwargs are rebuilt on every iteration regardless of i.
# NOTE(review): this is an excerpt (indentation reconstructed). The loop body
# appears to continue beyond these lines: `Repainting_latents` and
# `original_hidden_states_list` are assigned here but not read in the visible
# code. Confirm against the full pipeline file.
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Skip the rest of this step while the pipeline's interrupt flag is set.
        if self.interrupt:
            continue
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        # Replace `latents` via HB_replace_latents, using this step's HB latents
        # (HB_latents_list_list[i]) and the per-region m/n offsets.
        # NOTE(review): this runs for i <= HB_replace. The HB-conditioned
        # transformer call below runs only for i < HB_replace, so at
        # i == HB_replace the latents are replaced but the transformer receives
        # no HB_* arguments. Presumably an intentional HB -> SR hand-off; confirm.
        if i <= HB_replace:
            latents = self.HB_replace_latents(latents, HB_latents_list_list[i], HB_m_offset_list, HB_n_offset_list,
                                              height, width)
        # When a repainting mask is supplied, compute `Repainting_latents` exactly
        # once, at step i == Repainting_HB_replace.
        if Repainting_mask is not None and i == Repainting_HB_replace:
            Repainting_latents = self.Repainting_HB_replace_latents(latents, Repainting_HB_latents,
                                                                    Repainting_HB_m_offset, Repainting_HB_n_offset,
                                                                    Repainting, height, width)
        # SR inputs: this dict is rebuilt on every iteration and is NOT gated on i.
        # So the SR prompt embeddings are passed through joint_attention_kwargs at
        # every step, including the HB steps (i < HB_replace).
        # NOTE(review): `self.joint_attention_kwargs` (read below) is presumably a
        # property returning `_joint_attention_kwargs`. Whether SR actually takes
        # effect at every step depends on the transformer's attention processors,
        # which are not visible here.
        self._joint_attention_kwargs = {"SR_encoder_hidden_states_list": SR_prompt_embeds_list,
                                        "SR_norm_encoder_hidden_states_list": None, "SR_hidden_states_list": None,
                                        "SR_norm_hidden_states_list": None}
        # NOTE(review): timesteps are divided by 1000 before being passed in,
        # presumably because the transformer expects a 0-1 timestep scale. Confirm.
        # latent_h / latent_w are height // 16 and width // 16, presumably 8x VAE
        # downsampling times 2x2 patch packing. Confirm.
        # HB phase: the transformer additionally receives this step's HB hidden
        # states plus the per-region offsets and scales.
        if i < HB_replace:
            if Repainting_mask is None or i < Repainting_HB_replace:
                # return_dict=False, so take the first element of the returned tuple.
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    HB_hidden_states_list_list=HB_hidden_states_list_list_list[i],
                    HB_m_offset_list=HB_m_offset_list,
                    HB_n_offset_list=HB_n_offset_list,
                    HB_m_scale_list=HB_m_scale_list,
                    HB_n_scale_list=HB_n_scale_list,
                    latent_h=height // 16,
                    latent_w=width // 16
                )[0]
            else:
                # Repainting active from step Repainting_HB_replace onward: also
                # request the hidden-states list. The call then returns a pair, and
                # its first element is indexed with [0] to obtain noise_pred.
                noise_pred, original_hidden_states_list = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    HB_hidden_states_list_list=HB_hidden_states_list_list_list[i],
                    HB_m_offset_list=HB_m_offset_list,
                    HB_n_offset_list=HB_n_offset_list,
                    HB_m_scale_list=HB_m_scale_list,
                    HB_n_scale_list=HB_n_scale_list,
                    latent_h=height // 16,
                    latent_w=width // 16,
                    return_hidden_states_list=True
                )
                noise_pred = noise_pred[0]
        # Post-HB phase (the complement of the branch above): same calls, but
        # without any HB_* / latent_h / latent_w arguments.
        if i >= HB_replace:
            if Repainting_mask is None or i < Repainting_HB_replace:
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
            else:
                # Repainting active: also request the hidden-states list (see the
                # matching branch above for the return shape).
                noise_pred, original_hidden_states_list = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    return_hidden_states_list=True
                )
                noise_pred = noise_pred[0]
Amazing work!
I'm a beginner and have a question about the overall framework.
The overall RAG framework figure in the original paper shows Hard Binding (HB) being applied during the first $t - r$ steps and Soft Refinement (SR) during the last $t - (r+1)$ steps. In the code, however, SR seems to be applied throughout the entire denoising process. Is that correct?