Skip to content

At which stages was Regional Soft Refinement implemented? #14

@xcltql666

Description

@xcltql666

Amazing work!
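I am a beginner and have some questions about the overall framework.
The attached figure shows the overall RAG framework from the original paper. According to it, HB is applied during the first t-r steps and SR (Regional Soft Refinement) during the last t-(r+1) steps. In the code below, however, SR seems to be applied throughout the entire denoising loop: the SR entries in `joint_attention_kwargs` are set on every iteration and passed to every transformer call, regardless of `HB_replace`. Is this intended, or am I misreading the code?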

[Image attachment: 1736996635(1)]

with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        if i <= HB_replace:
            latents = self.HB_replace_latents(latents, HB_latents_list_list[i], HB_m_offset_list, HB_n_offset_list,
                                              height, width)

        if Repainting_mask is not None and i == Repainting_HB_replace:
            Repainting_latents = self.Repainting_HB_replace_latents(latents, Repainting_HB_latents,
                                                                    Repainting_HB_m_offset, Repainting_HB_n_offset,
                                                                    Repainting, height, width)

        self._joint_attention_kwargs = {"SR_encoder_hidden_states_list": SR_prompt_embeds_list,
                                        "SR_norm_encoder_hidden_states_list": None, "SR_hidden_states_list": None,
                                        "SR_norm_hidden_states_list": None}

        if i < HB_replace:
            if Repainting_mask is None or i < Repainting_HB_replace:
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    HB_hidden_states_list_list=HB_hidden_states_list_list_list[i],
                    HB_m_offset_list=HB_m_offset_list,
                    HB_n_offset_list=HB_n_offset_list,
                    HB_m_scale_list=HB_m_scale_list,
                    HB_n_scale_list=HB_n_scale_list,
                    latent_h=height // 16,
                    latent_w=width // 16
                )[0]
            else:
                noise_pred, original_hidden_states_list = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    HB_hidden_states_list_list=HB_hidden_states_list_list_list[i],
                    HB_m_offset_list=HB_m_offset_list,
                    HB_n_offset_list=HB_n_offset_list,
                    HB_m_scale_list=HB_m_scale_list,
                    HB_n_scale_list=HB_n_scale_list,
                    latent_h=height // 16,
                    latent_w=width // 16,
                    return_hidden_states_list=True
                )
                noise_pred = noise_pred[0]

        if i >= HB_replace:
            if Repainting_mask is None or i < Repainting_HB_replace:
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
            else:
                noise_pred, original_hidden_states_list = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    return_hidden_states_list=True
                )
                noise_pred = noise_pred[0]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions