diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 38ce0ca42e..55460d26a3 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -32,6 +32,7 @@ from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline from QEfficient.diffusers.pipelines.wan.pipeline_wan_i2v import QEffWanImageToVideoPipeline +from QEfficient.diffusers.pipelines.qwen_image.pipeline_qwenimage import QEffQwenImagePipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -61,6 +62,7 @@ "QEffFluxPipeline", "QEffWanPipeline", "QEffWanImageToVideoPipeline", + "QEffQwenImagePipeline", ] diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000000..ce0653fea0 --- /dev/null +++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -0,0 +1,252 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Qeff modeling changes + - Changed upsampling mode from "nearest-exact" to "nearest" for ONNX compatibility. + - Used max(0, x.shape[2] - CACHE_T) instead of CACHE_T because x.shape[2] is either 1 or 4, + - CACHE_T = 2. 
# Number of trailing frames kept in each causal-convolution cache slot.
CACHE_T = 2


def _tail_frames_for_cache(x, prev):
    """Return the trailing frames of ``x`` to store in a cache slot.

    Args:
        x: Tensor indexed as ``x[:, :, t, :, :]`` (time on dim 2).
        prev: Tensor previously stored in the slot, or ``None``.

    Returns:
        A clone of the last ``CACHE_T`` frames of ``x``. When ``x`` holds fewer
        than two frames and ``prev`` is available, the last frame of ``prev``
        is prepended along dim 2.
    """
    # max(0, ...) keeps the slice start non-negative when x has fewer than
    # CACHE_T frames (a negative start is what the QEff change avoids).
    tail = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
    if tail.shape[2] < 2 and prev is not None:
        # cache last frame of last two chunk
        tail = torch.cat([prev[:, :, -1, :, :].unsqueeze(2).to(tail.device), tail], dim=2)
    return tail


def _cached_causal_conv(conv, x, feat_cache, feat_idx):
    """Apply ``conv`` to ``x``, consuming and refreshing one cache slot.

    Without a cache this is simply ``conv(x)``. With a cache, the slot at
    ``feat_idx[0]`` is passed to ``conv`` as its second argument, replaced by
    the trailing frames of ``x``, and ``feat_idx[0]`` is advanced by one.
    """
    if feat_cache is None:
        return conv(x)
    idx = feat_idx[0]
    prev = feat_cache[idx]
    tail = _tail_frames_for_cache(x, prev)
    x = conv(x, prev)
    feat_cache[idx] = tail
    feat_idx[0] += 1
    return x


class QEffQwenImageResample(QwenImageResample):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __qeff_init__(self):
        """Swap the upsampling layer for one that uses ``mode="nearest"``."""
        # Changed upsampling mode from "nearest-exact" to "nearest" for ONNX compatibility.
        # Since the scale factor is an integer, both modes behave the same.
        if self.mode in ("upsample2d", "upsample3d"):
            self.resample[0] = QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest")

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample ``x``, optionally using and updating the temporal cache.

        Args:
            x: 5-D tensor unpacked as ``(b, c, t, h, w)``.
            feat_cache: Optional list of cache slots; mutated in place.
            feat_idx: Optional single-element list holding the current slot
                index; incremented in place for every slot consumed. A fresh
                ``[0]`` is used when omitted.

        Returns:
            The resampled tensor.
        """
        # A list literal as the default would be shared (and mutated) across calls.
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == "upsample3d" and feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if prev is None:
                # First visit: only mark the slot; no temporal conv is applied.
                feat_cache[idx] = "Rep"
                feat_idx[0] += 1
            else:
                # isinstance guard avoids comparing a tensor against a string.
                slot_is_marker = isinstance(prev, str) and prev == "Rep"
                cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
                if cache_x.shape[2] < 2:
                    if slot_is_marker:
                        pad = torch.zeros_like(cache_x).to(cache_x.device)
                    else:
                        # cache last frame of last two chunk
                        pad = prev[:, :, -1, :, :].unsqueeze(2).to(cache_x.device)
                    cache_x = torch.cat([pad, cache_x], dim=2)

                x = self.time_conv(x) if slot_is_marker else self.time_conv(x, prev)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # NOTE(review): nesting reconstructed from the upstream diffusers
                # layout (SOURCE was whitespace-collapsed) - confirm.
                # Interleave the two channel halves produced by time_conv along time.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                x = x.reshape(b, c, t * 2, h, w)

        # Fold time into the batch dimension for the 2D resampling stack.
        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

        if self.mode == "downsample3d" and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                feat_cache[idx] = x.clone()
                feat_idx[0] += 1
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
        return x


class QEffQwenImageResidualBlock(QwenImageResidualBlock):
    r"""
    A custom residual block module.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Run norm -> activation -> conv twice and add the shortcut branch.

        ``feat_cache`` / ``feat_idx`` are updated in place; two cache slots are
        consumed when a cache is supplied.
        """
        if feat_idx is None:
            feat_idx = [0]

        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = _cached_causal_conv(self.conv1, x, feat_cache, feat_idx)

        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)

        # Dropout
        x = self.dropout(x)
        x = _cached_causal_conv(self.conv2, x, feat_cache, feat_idx)

        # Add residual connection
        return x + h


class QEffQwenImageEncoder3d(QwenImageEncoder3d):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode ``x``; ``feat_cache`` / ``feat_idx`` are updated in place when given."""
        if feat_idx is None:
            feat_idx = [0]

        x = _cached_causal_conv(self.conv_in, x, feat_cache, feat_idx)

        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        return _cached_causal_conv(self.conv_out, x, feat_cache, feat_idx)


class QEffQwenImageDecoder3d(QwenImageDecoder3d):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Decode ``x``; ``feat_cache`` / ``feat_idx`` are updated in place when given."""
        if feat_idx is None:
            feat_idx = [0]

        ## conv1
        x = _cached_causal_conv(self.conv_in, x, feat_cache, feat_idx)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        return _cached_causal_conv(self.conv_out, x, feat_cache, feat_idx)
QwenImageTransformer2DModel, + QwenImageTransformerBlock, +) from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor, WanTransformer3DModel from torch import nn from QEfficient.base.pytorch_transforms import ModuleMappingTransform from QEfficient.customop.rms_norm import CustomRMSNormAIC +from QEfficient.diffusers.models.autoencoders.autoencoder_kl_qwenimage import ( + QEffQwenImageDecoder3d, + QEffQwenImageEncoder3d, + QEffQwenImageResample, + QEffQwenImageResidualBlock, +) from QEfficient.diffusers.models.autoencoders.autoencoder_kl_wan import ( QEffAutoencoderKLWan, QEffWanDecoder3d, @@ -44,6 +61,11 @@ QEffFluxTransformer2DModel, QEffFluxTransformerBlock, ) +from QEfficient.diffusers.models.transformers.transformer_qwenimage import ( + QEffQwenImageAttention, + QEffQwenImageTransformer2DModel, + QEffQwenImageTransformerBlock, +) from QEfficient.diffusers.models.transformers.transformer_wan import ( QEffWanAttention, QEffWanAttnProcessor, @@ -69,10 +91,17 @@ class AttentionTransform(ModuleMappingTransform): WanAttention: QEffWanAttention, WanTransformer3DModel: QEffWanTransformer3DModel, AutoencoderKLWan: QEffAutoencoderKLWan, + QwenImageTransformer2DModel: QEffQwenImageTransformer2DModel, + QwenImageTransformerBlock: QEffQwenImageTransformerBlock, + Attention: QEffQwenImageAttention, WanDecoder3d: QEffWanDecoder3d, WanEncoder3d: QEffWanEncoder3d, WanResidualBlock: QEffWanResidualBlock, WanResample: QEffWanResample, + QwenImageResample: QEffQwenImageResample, + QwenImageResidualBlock: QEffQwenImageResidualBlock, + QwenImageEncoder3d: QEffQwenImageEncoder3d, + QwenImageDecoder3d: QEffQwenImageDecoder3d, } diff --git a/QEfficient/diffusers/models/transformers/transformer_qwenimage.py b/QEfficient/diffusers/models/transformers/transformer_qwenimage.py new file mode 100644 index 0000000000..08a99a0ea0 --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_qwenimage.py @@ -0,0 +1,554 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.models.transformers.transformer_qwenimage import ( + QwenDoubleStreamAttnProcessor2_0, + QwenImageTransformer2DModel, + QwenImageTransformerBlock, +) + + +def qeff_apply_rotary_emb_qwen(x, freqs_cos, freqs_sin): + """ + Apply rotary embeddings to query/key tensors using cosine and sine tables. + + Args: + x (`torch.Tensor`): + Query or key tensor with shape `[B, S, H, D]`. + freqs_cos (`torch.Tensor`): + Cosine frequencies with shape `[S, D/2]`. + freqs_sin (`torch.Tensor`): + Sine frequencies with shape `[S, D/2]`. + + Returns: + `torch.Tensor`: + Tensor with rotary embedding applied, with the same shape and dtype as `x`. 
+ """ + B, C, H, D = x.shape + x = x.float() + x_reshaped = x.reshape(B, -1, H, D // 2, 2) + x1 = x_reshaped[..., 0] # [B, S, H, D//2] + x2 = x_reshaped[..., 1] # [B, S, H, D//2] + + # Reshape for broadcasting: [S, D//2] -> [1, S, 1, D//2] + freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2) + freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2) + + # Apply rotation + x_out1 = x1 * freqs_cos - x2 * freqs_sin # Real part + x_out2 = x1 * freqs_sin + x2 * freqs_cos # Imaginary part + + # Stack and reshape back + x_out = torch.stack([x_out1, x_out2], dim=-1) # [B, S, H, D//2, 2] + x_out = x_out.flatten(-2) # [B, S, H, D] + return x_out.type_as(x) + + +class QEffQwenEmbedRope(nn.Module): + """ + Rotary embedding helper for Qwen Image video/text positional encodings. + + The module precomputes positive and negative frequency tables and returns + per-sample image and text RoPE tensors expected by the Qwen transformer. + """ + + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + """ + Initialize RoPE frequency caches. + + Args: + theta (`int`): + Base frequency used in rotary embedding computation. + axes_dim (`List[int]`): + RoPE dimensions for `(frame, height, width)` axes. + scale_rope (`bool`, *optional*, defaults to `False`): + Enables centered/negative indexing strategy for spatial axes. 
+ """ + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.scale_rope = scale_rope + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + + # Store cos and sin separately instead of complex numbers + pos_freqs_list = [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ] + self.pos_freqs_cos = torch.cat([f[0] for f in pos_freqs_list], dim=1) + self.pos_freqs_sin = torch.cat([f[1] for f in pos_freqs_list], dim=1) + + neg_freqs_list = [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ] + self.neg_freqs_cos = torch.cat([f[0] for f in neg_freqs_list], dim=1) + self.neg_freqs_sin = torch.cat([f[1] for f in neg_freqs_list], dim=1) + + self.rope_cache = {} + + def rope_params(self, index, dim, theta=10000): + """ + Compute cosine/sine RoPE parameters for a single axis. + + Args: + index (`torch.Tensor`): + 1D position indices for the axis. + dim (`int`): + Rotary dimension for the axis; must be even. + theta (`int`, *optional*, defaults to `10000`): + Base frequency used for geometric progression. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: + Cosine and sine tables for `index`. + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + # Return cos and sin separately instead of complex tensor + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def _compute_video_freqs(self, frame, height, width, idx=0): + """ + Compute image RoPE cosine/sine tables for one `(frame, height, width)` tuple. + + Args: + frame (`int`): Number of latent frames. + height (`int`): Latent height. + width (`int`): Latent width. 
+ idx (`int`, *optional*): Offset used for frame indexing. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: + Cosine and sine tables with shape `[frame*height*width, rope_dim/2]`. + """ + seq_lens = frame * height * width + freqs_pos_cos = self.pos_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1) + freqs_pos_sin = self.pos_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg_cos = self.neg_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg_sin = self.neg_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1) + + # Frame dimension + freqs_frame_cos = freqs_pos_cos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame_sin = freqs_pos_sin[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + + if self.scale_rope: + freqs_height_cos = torch.cat( + [freqs_neg_cos[1][-(height - height // 2) :], freqs_pos_cos[1][: height // 2]], dim=0 + ) + freqs_height_sin = torch.cat( + [freqs_neg_sin[1][-(height - height // 2) :], freqs_pos_sin[1][: height // 2]], dim=0 + ) + freqs_height_cos = freqs_height_cos.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = freqs_height_sin.view(1, height, 1, -1).expand(frame, height, width, -1) + + freqs_width_cos = torch.cat( + [freqs_neg_cos[2][-(width - width // 2) :], freqs_pos_cos[2][: width // 2]], dim=0 + ) + freqs_width_sin = torch.cat( + [freqs_neg_sin[2][-(width - width // 2) :], freqs_pos_sin[2][: width // 2]], dim=0 + ) + freqs_width_cos = freqs_width_cos.view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width_sin = freqs_width_sin.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height_cos = freqs_pos_cos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = freqs_pos_sin[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width_cos = freqs_pos_cos[2][:width].view(1, 1, width, -1).expand(frame, height, width, 
-1) + freqs_width_sin = freqs_pos_sin[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs_cos = torch.cat([freqs_frame_cos, freqs_height_cos, freqs_width_cos], dim=-1).reshape(seq_lens, -1) + freqs_sin = torch.cat([freqs_frame_sin, freqs_height_sin, freqs_width_sin], dim=-1).reshape(seq_lens, -1) + + return freqs_cos.clone().contiguous(), freqs_sin.clone().contiguous() + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: + video_fhw: + Video latent shape description. Supports one or many + `(frame, height, width)` tuples. + txt_seq_lens: + Text sequence lengths for each sample in the batch. + device: + Target device for returned frequency tensors. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`, `torch.Tensor`]: + `(img_cos, img_sin, txt_cos, txt_sin)` RoPE frequency tables. + """ + if self.pos_freqs_cos.device != device: + self.pos_freqs_cos = self.pos_freqs_cos.to(device) + self.pos_freqs_sin = self.pos_freqs_sin.to(device) + self.neg_freqs_cos = self.neg_freqs_cos.to(device) + self.neg_freqs_sin = self.neg_freqs_sin.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs_cos_list = [] + vid_freqs_sin_list = [] + max_vid_index = 0 + + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq_cos, video_freq_sin = self.rope_cache[rope_key] + else: + video_freq_cos, video_freq_sin = self._compute_video_freqs(frame, height, width, idx) + + video_freq_cos = video_freq_cos.to(device) + video_freq_sin = video_freq_sin.to(device) + vid_freqs_cos_list.append(video_freq_cos) + vid_freqs_sin_list.append(video_freq_sin) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, 
max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs_cos = self.pos_freqs_cos[max_vid_index : max_vid_index + max_len, ...] + txt_freqs_sin = self.pos_freqs_sin[max_vid_index : max_vid_index + max_len, ...] + + vid_freqs_cos = torch.cat(vid_freqs_cos_list, dim=0) + vid_freqs_sin = torch.cat(vid_freqs_sin_list, dim=0) + + return vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin + + +class QEffQwenDoubleStreamAttnProcessor2_0(QwenDoubleStreamAttnProcessor2_0): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + img_rotary_emb: torch.FloatTensor = None, + txt_rotary_emb: torch.FloatTensor = None, + ) -> torch.FloatTensor: + """ + Run joint text-image attention for a Qwen dual-stream block. + + Args: + attn (`Attention`): Attention module instance. + hidden_states (`torch.FloatTensor`): Image-stream states `[B, S_img, C]`. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Text-stream states `[B, S_txt, C]`. + encoder_hidden_states_mask (`torch.FloatTensor`, *optional*): + Text mask passed through the attention API. + attention_mask (`torch.FloatTensor`, *optional*): + Additional attention mask. + img_rotary_emb (`torch.Tensor`, *optional*): + Image RoPE tensor where last dim packs `(cos, sin)`. + txt_rotary_emb (`torch.Tensor`, *optional*): + Text RoPE tensor where last dim packs `(cos, sin)`. + + Returns: + Tuple[`torch.FloatTensor`, `torch.FloatTensor`]: + Image and text attention outputs. 
+ """ + + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + encoder_hidden_states = encoder_hidden_states / 2 + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value * 2 + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = txt_query * 2 # FP32 #INP #MUL + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = txt_key * 2 # FP32 #INP #MUL + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if img_rotary_emb is not None and txt_rotary_emb is not None: + # Unpack the 4 tensors (cos and sin for both img and txt) + img_freqs_cos, img_freqs_sin = torch.chunk(img_rotary_emb, 2, dim=-1) # [6032,64] each + txt_freqs_cos, txt_freqs_sin = torch.chunk(txt_rotary_emb, 2, dim=-1) # [126,64] each + + img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs_cos, img_freqs_sin) + img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs_cos, img_freqs_sin) + txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs_cos, txt_freqs_sin) + txt_key = 
qeff_apply_rotary_emb_qwen(txt_key, txt_freqs_cos, txt_freqs_sin) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + img_attn_output = img_attn_output / 4 + txt_attn_output = txt_attn_output / 64 + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + # FP32 #INP and MUL + img_attn_output = img_attn_output * 4 + txt_attn_output = txt_attn_output * 32 + return img_attn_output, txt_attn_output + + +class QEffQwenImageAttention(Attention): + def __qeff_init__(self): + self.processor = QEffQwenDoubleStreamAttnProcessor2_0() + + +class QEffQwenImageTransformerBlock(QwenImageTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + img_rotary_emb: torch.Tensor = None, + txt_rotary_emb: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for one Qwen dual-stream transformer block. + + Args: + hidden_states (`torch.Tensor`): Image-stream hidden states. 
+ encoder_hidden_states (`torch.Tensor`): Text-stream hidden states. + encoder_hidden_states_mask (`torch.Tensor`): Text attention mask. + temb (`torch.Tensor`): Timestep embedding. + img_rotary_emb (`torch.Tensor`, *optional*): Image RoPE frequencies. + txt_rotary_emb (`torch.Tensor`, *optional*): Text RoPE frequencies. + joint_attention_kwargs (`Dict[str, Any]`, *optional*): + Additional kwargs forwarded to the attention processor. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: + Updated `(encoder_hidden_states, hidden_states)`. + """ + global sf_value + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_normed = self.img_norm1(hidden_states) # FP32 #INP #OUTPUT + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) # FP32 #INP + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) # FP32 #INP #OUTPUT + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + img_attn_output = img_attn_output / (sf_value * sf_value * 4) # FP32 + hidden_states = hidden_states / (sf_value * sf_value * 4) # FP32 + + hidden_states = hidden_states + img_gate1 * img_attn_output + + txt_attn_output = txt_attn_output / (sf_value * sf_value * 64) # FP32 + encoder_hidden_states = encoder_hidden_states / (sf_value * sf_value * 64) # FP32 + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + hidden_states = hidden_states * (sf_value * sf_value * 4) # FP32 + img_normed2 = self.img_norm2(hidden_states) # FP32 #INP #OUT + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) # FP32 #INP #OUT + img_modulated2 = img_modulated2 / (sf_value) # FP32 + img_mlp_output = self.img_mlp(img_modulated2) # FP16 + img_mlp_output = img_mlp_output / (sf_value * 4) # FP16 + hidden_states = hidden_states / (sf_value * sf_value * 4) # FP32 + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + encoder_hidden_states = encoder_hidden_states * (sf_value * sf_value * 64) # FP32 + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_modulated2 = txt_modulated2 / (sf_value) # FP32 + txt_mlp_output = self.txt_mlp(txt_modulated2) # FP16 + 
class QEffQwenImageTransformer2DModel(QwenImageTransformer2DModel):
    """
    QEfficient variant of `QwenImageTransformer2DModel`.

    What this subclass changes, as visible in this class:
      - `__qeff_init__` replaces `self.pos_embed` with a `QEffQwenEmbedRope`.
      - `forward` receives the rotary embeddings (`img_rotary_emb`, `txt_rotary_emb`) as arguments and hands
        them to every transformer block; it never calls `self.pos_embed` itself.
      - `forward` writes a per-block scale factor into the module-level global `sf_value` before each block
        call and divides the block outputs by a power of that factor once the loop is done.
    """

    def __qeff_init__(self):
        """
        Swap `self.pos_embed` for a `QEffQwenEmbedRope` built from this model's `axes_dims_rope`.

        NOTE(review): presumably invoked by the QEfficient module-replacement machinery after the class has
        been swapped in, not by `__init__` -- confirm against the caller.
        """
        self.pos_embed = QEffQwenEmbedRope(theta=10000, axes_dim=list(self.axes_dims_rope), scale_rope=True)

    def get_submodules_for_export(self) -> Type[nn.Module]:
        """
        Return the classes used as the repeated layer across the model for subfunction extraction.

        Returns:
            A `set` holding one element, the `QEffQwenImageTransformerBlock` *class object* (not an
            instance). Downstream code can use it to find/build subfunctions for the repeated blocks.

        NOTE(review): the return annotation says `Type[nn.Module]`, but the value returned is a set of
        classes; the annotation should eventually be corrected to match.
        """
        return {QEffQwenImageTransformerBlock}

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        txt_seq_lens: torch.Tensor = None,
        img_rotary_emb: torch.Tensor = None,
        txt_rotary_emb: torch.Tensor = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QEffQwenImageTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions. Passed unchanged to every transformer block.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step. Although the default is `None`, it is dereferenced
                unconditionally (`timestep.to(...)`), so a value must be supplied.
            txt_seq_lens (`torch.Tensor`, *optional*):
                Converted to a Python list when it is a tensor; not read again in this method.
            img_rotary_emb (`torch.Tensor`):
                Rotary embedding for the image stream, passed unchanged to every transformer block.
            txt_rotary_emb (`torch.Tensor`):
                Rotary embedding for the text stream, passed unchanged to every transformer block.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                Here it is forwarded to each block as `joint_attention_kwargs`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # `sf_value` is module-level state that this method rewrites before every block call.
        # NOTE(review): presumably QEffQwenImageTransformerBlock.forward reads this global to pick its
        # scaling divisors -- confirm. Being a module global, the value is shared by every model instance
        # in the process.
        global sf_value
        # Convert txt_seq_lens to list if it's a tensor
        # NOTE(review): the converted value is not used anywhere below in this method.
        if isinstance(txt_seq_lens, torch.Tensor):
            txt_seq_lens = txt_seq_lens.tolist()

        # Project the image tokens with the input layer.
        hidden_states = self.img_in(hidden_states)

        # Match the timestep dtype to the projected image tokens; normalize, then project, the text tokens.
        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        # Conditioning embedding handed to every block and to the output norm.
        temb = self.time_text_embed(timestep, hidden_states)

        for index_block, block in enumerate(self.transformer_blocks):
            # Blocks with index 0..58 run with scale factor 32; index 59 and above run with 256.
            # NOTE(review): 59 is hard-coded rather than derived from len(self.transformer_blocks);
            # presumably it is the last block of a 60-block checkpoint -- confirm.
            if index_block < 59:
                sf_value = 32
            else:
                sf_value = 256

            # Each block returns the text stream first and the image stream second.
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                img_rotary_emb=img_rotary_emb,
                txt_rotary_emb=txt_rotary_emb,
                joint_attention_kwargs=attention_kwargs,
            )

        # Scale both streams down using the factor left over from the last loop iteration
        # (text by sf_value^2 * 64, image by sf_value^2 * 4).
        # NOTE(review): presumably this keeps the input of norm_out inside a reduced-precision range --
        # confirm. encoder_hidden_states is not read again after this point.
        encoder_hidden_states = encoder_hidden_states / (sf_value * sf_value * 64)
        hidden_states = hidden_states / (sf_value * sf_value * 4)

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        # Plain-tuple return path for callers that opt out of the output dataclass.
        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
'/transformer_blocks.0/Add_4_output_0', '/transformer_blocks.0/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.0/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.0/Div_7_output_0', '/transformer_blocks.0/Add_9_output_0', '/transformer_blocks.0/Mul_11_output_0', '/transformer_blocks.0/Add_7_output_0', '/transformer_blocks.0/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/Mul_1_output_0', '/transformer_blocks.0/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.0/attn/Mul_2_output_0', '/transformer_blocks.0/attn/Mul_21_output_0', '/transformer_blocks.0/Div_4_output_0', '/transformer_blocks.0/Mul_12_output_0', '/transformer_blocks.0/attn/Mul_22_output_0', '/transformer_blocks.0/Div_6_output_0', '/transformer_blocks.0/Mul_13_output_0', '/transformer_blocks.0/img_norm2/LayerNormalization_output_0', '/transformer_blocks.0/img_norm2/LayerNormalization_output_0', '/transformer_blocks.0/Mul_14_output_0', '/transformer_blocks.0/Div_11_output_0', '/transformer_blocks.0/Add_13_output_0', '/transformer_blocks.0/Mul_18_output_0', '/transformer_blocks.0/Add_12_output_0', '/transformer_blocks.0/Div_9_output_0', '/transformer_blocks.0/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.0/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.0/Mul_20_output_0', '/transformer_blocks.0/Div_15_output_0', '/transformer_blocks.0/Add_17_output_0', '/transformer_blocks.0/Mul_24_output_0', '/transformer_blocks.0/Add_16_output_0', '/transformer_blocks.0/Div_13_output_0', '/transformer_blocks.1/img_norm1/LayerNormalization_output_0', '/transformer_blocks.1/img_norm1/LayerNormalization_output_0', '/transformer_blocks.0/Mul_26_output_0', 
'/transformer_blocks.1/Div_5_output_0', '/transformer_blocks.1/Add_8_output_0', '/transformer_blocks.1/Mul_7_output_0', '/transformer_blocks.1/Add_4_output_0', '/transformer_blocks.1/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.1/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.0/Mul_27_output_0', '/transformer_blocks.1/Div_7_output_0', '/transformer_blocks.1/Add_9_output_0', '/transformer_blocks.1/Mul_11_output_0', '/transformer_blocks.1/Add_7_output_0', '/transformer_blocks.1/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/Mul_1_output_0', '/transformer_blocks.1/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.1/attn/Mul_2_output_0', '/transformer_blocks.1/attn/Mul_21_output_0', '/transformer_blocks.1/Div_4_output_0', '/transformer_blocks.1/Mul_12_output_0', '/transformer_blocks.1/attn/Mul_22_output_0', '/transformer_blocks.1/Div_6_output_0', '/transformer_blocks.1/Mul_13_output_0', '/transformer_blocks.1/img_norm2/LayerNormalization_output_0', '/transformer_blocks.1/img_norm2/LayerNormalization_output_0', '/transformer_blocks.1/Mul_14_output_0', '/transformer_blocks.1/Div_11_output_0', '/transformer_blocks.1/Add_13_output_0', '/transformer_blocks.1/Mul_18_output_0', '/transformer_blocks.1/Add_12_output_0', '/transformer_blocks.1/Div_9_output_0', '/transformer_blocks.1/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.1/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.1/Mul_20_output_0', '/transformer_blocks.1/Div_15_output_0', '/transformer_blocks.1/Add_17_output_0', '/transformer_blocks.1/Mul_24_output_0', '/transformer_blocks.1/Add_16_output_0', '/transformer_blocks.1/Div_13_output_0', 
'/transformer_blocks.2/img_norm1/LayerNormalization_output_0', '/transformer_blocks.2/img_norm1/LayerNormalization_output_0', '/transformer_blocks.1/Mul_26_output_0', '/transformer_blocks.2/Div_5_output_0', '/transformer_blocks.2/Add_8_output_0', '/transformer_blocks.2/Mul_7_output_0', '/transformer_blocks.2/Add_4_output_0', '/transformer_blocks.2/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.2/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.1/Mul_27_output_0', '/transformer_blocks.2/Div_7_output_0', '/transformer_blocks.2/Add_9_output_0', '/transformer_blocks.2/Mul_11_output_0', '/transformer_blocks.2/Add_7_output_0', '/transformer_blocks.2/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/Mul_1_output_0', '/transformer_blocks.2/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.2/attn/Mul_2_output_0', '/transformer_blocks.2/attn/Mul_21_output_0', '/transformer_blocks.2/Div_4_output_0', '/transformer_blocks.2/Mul_12_output_0', '/transformer_blocks.2/attn/Mul_22_output_0', '/transformer_blocks.2/Div_6_output_0', '/transformer_blocks.2/Mul_13_output_0', '/transformer_blocks.2/img_norm2/LayerNormalization_output_0', '/transformer_blocks.2/img_norm2/LayerNormalization_output_0', '/transformer_blocks.2/Mul_14_output_0', '/transformer_blocks.2/Div_11_output_0', '/transformer_blocks.2/Add_13_output_0', '/transformer_blocks.2/Mul_18_output_0', '/transformer_blocks.2/Add_12_output_0', '/transformer_blocks.2/Div_9_output_0', '/transformer_blocks.2/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.2/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.2/Mul_20_output_0', '/transformer_blocks.2/Div_15_output_0', 
'/transformer_blocks.2/Add_17_output_0', '/transformer_blocks.2/Mul_24_output_0', '/transformer_blocks.2/Add_16_output_0', '/transformer_blocks.2/Div_13_output_0', '/transformer_blocks.3/img_norm1/LayerNormalization_output_0', '/transformer_blocks.3/img_norm1/LayerNormalization_output_0', '/transformer_blocks.2/Mul_26_output_0', '/transformer_blocks.3/Div_5_output_0', '/transformer_blocks.3/Add_8_output_0', '/transformer_blocks.3/Mul_7_output_0', '/transformer_blocks.3/Add_4_output_0', '/transformer_blocks.3/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.3/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.2/Mul_27_output_0', '/transformer_blocks.3/Div_7_output_0', '/transformer_blocks.3/Add_9_output_0', '/transformer_blocks.3/Mul_11_output_0', '/transformer_blocks.3/Add_7_output_0', '/transformer_blocks.3/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/Mul_1_output_0', '/transformer_blocks.3/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.3/attn/Mul_2_output_0', '/transformer_blocks.3/attn/Mul_21_output_0', '/transformer_blocks.3/Div_4_output_0', '/transformer_blocks.3/Mul_12_output_0', '/transformer_blocks.3/attn/Mul_22_output_0', '/transformer_blocks.3/Div_6_output_0', '/transformer_blocks.3/Mul_13_output_0', '/transformer_blocks.3/img_norm2/LayerNormalization_output_0', '/transformer_blocks.3/img_norm2/LayerNormalization_output_0', '/transformer_blocks.3/Mul_14_output_0', '/transformer_blocks.3/Div_11_output_0', '/transformer_blocks.3/Add_13_output_0', '/transformer_blocks.3/Mul_18_output_0', '/transformer_blocks.3/Add_12_output_0', '/transformer_blocks.3/Div_9_output_0', '/transformer_blocks.3/txt_norm2/LayerNormalization_output_0', 
'/transformer_blocks.3/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.3/Mul_20_output_0', '/transformer_blocks.3/Div_15_output_0', '/transformer_blocks.3/Add_17_output_0', '/transformer_blocks.3/Mul_24_output_0', '/transformer_blocks.3/Add_16_output_0', '/transformer_blocks.3/Div_13_output_0', '/transformer_blocks.4/img_norm1/LayerNormalization_output_0', '/transformer_blocks.4/img_norm1/LayerNormalization_output_0', '/transformer_blocks.3/Mul_26_output_0', '/transformer_blocks.4/Div_5_output_0', '/transformer_blocks.4/Add_8_output_0', '/transformer_blocks.4/Mul_7_output_0', '/transformer_blocks.4/Add_4_output_0', '/transformer_blocks.4/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.4/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.3/Mul_27_output_0', '/transformer_blocks.4/Div_7_output_0', '/transformer_blocks.4/Add_9_output_0', '/transformer_blocks.4/Mul_11_output_0', '/transformer_blocks.4/Add_7_output_0', '/transformer_blocks.4/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/Mul_1_output_0', '/transformer_blocks.4/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.4/attn/Mul_2_output_0', '/transformer_blocks.4/attn/Mul_21_output_0', '/transformer_blocks.4/Div_4_output_0', '/transformer_blocks.4/Mul_12_output_0', '/transformer_blocks.4/attn/Mul_22_output_0', '/transformer_blocks.4/Div_6_output_0', '/transformer_blocks.4/Mul_13_output_0', '/transformer_blocks.4/img_norm2/LayerNormalization_output_0', '/transformer_blocks.4/img_norm2/LayerNormalization_output_0', '/transformer_blocks.4/Mul_14_output_0', '/transformer_blocks.4/Div_11_output_0', '/transformer_blocks.4/Add_13_output_0', '/transformer_blocks.4/Mul_18_output_0', 
'/transformer_blocks.4/Add_12_output_0', '/transformer_blocks.4/Div_9_output_0', '/transformer_blocks.4/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.4/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.4/Mul_20_output_0', '/transformer_blocks.4/Div_15_output_0', '/transformer_blocks.4/Add_17_output_0', '/transformer_blocks.4/Mul_24_output_0', '/transformer_blocks.4/Add_16_output_0', '/transformer_blocks.4/Div_13_output_0', '/transformer_blocks.5/img_norm1/LayerNormalization_output_0', '/transformer_blocks.5/img_norm1/LayerNormalization_output_0', '/transformer_blocks.4/Mul_26_output_0', '/transformer_blocks.5/Div_5_output_0', '/transformer_blocks.5/Add_8_output_0', '/transformer_blocks.5/Mul_7_output_0', '/transformer_blocks.5/Add_4_output_0', '/transformer_blocks.5/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.5/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.4/Mul_27_output_0', '/transformer_blocks.5/Div_7_output_0', '/transformer_blocks.5/Add_9_output_0', '/transformer_blocks.5/Mul_11_output_0', '/transformer_blocks.5/Add_7_output_0', '/transformer_blocks.5/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/Mul_1_output_0', '/transformer_blocks.5/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.5/attn/Mul_2_output_0', '/transformer_blocks.5/attn/Mul_21_output_0', '/transformer_blocks.5/Div_4_output_0', '/transformer_blocks.5/Mul_12_output_0', '/transformer_blocks.5/attn/Mul_22_output_0', '/transformer_blocks.5/Div_6_output_0', '/transformer_blocks.5/Mul_13_output_0', '/transformer_blocks.5/img_norm2/LayerNormalization_output_0', '/transformer_blocks.5/img_norm2/LayerNormalization_output_0', 
'/transformer_blocks.5/Mul_14_output_0', '/transformer_blocks.5/Div_11_output_0', '/transformer_blocks.5/Add_13_output_0', '/transformer_blocks.5/Mul_18_output_0', '/transformer_blocks.5/Add_12_output_0', '/transformer_blocks.5/Div_9_output_0', '/transformer_blocks.5/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.5/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.5/Mul_20_output_0', '/transformer_blocks.5/Div_15_output_0', '/transformer_blocks.5/Add_17_output_0', '/transformer_blocks.5/Mul_24_output_0', '/transformer_blocks.5/Add_16_output_0', '/transformer_blocks.5/Div_13_output_0', '/transformer_blocks.6/img_norm1/LayerNormalization_output_0', '/transformer_blocks.6/img_norm1/LayerNormalization_output_0', '/transformer_blocks.5/Mul_26_output_0', '/transformer_blocks.6/Div_5_output_0', '/transformer_blocks.6/Add_8_output_0', '/transformer_blocks.6/Mul_7_output_0', '/transformer_blocks.6/Add_4_output_0', '/transformer_blocks.6/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.6/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.5/Mul_27_output_0', '/transformer_blocks.6/Div_7_output_0', '/transformer_blocks.6/Add_9_output_0', '/transformer_blocks.6/Mul_11_output_0', '/transformer_blocks.6/Add_7_output_0', '/transformer_blocks.6/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/Mul_1_output_0', '/transformer_blocks.6/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.6/attn/Mul_2_output_0', '/transformer_blocks.6/attn/Mul_21_output_0', '/transformer_blocks.6/Div_4_output_0', '/transformer_blocks.6/Mul_12_output_0', '/transformer_blocks.6/attn/Mul_22_output_0', '/transformer_blocks.6/Div_6_output_0', '/transformer_blocks.6/Mul_13_output_0', 
'/transformer_blocks.6/img_norm2/LayerNormalization_output_0', '/transformer_blocks.6/img_norm2/LayerNormalization_output_0', '/transformer_blocks.6/Mul_14_output_0', '/transformer_blocks.6/Div_11_output_0', '/transformer_blocks.6/Add_13_output_0', '/transformer_blocks.6/Mul_18_output_0', '/transformer_blocks.6/Add_12_output_0', '/transformer_blocks.6/Div_9_output_0', '/transformer_blocks.6/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.6/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.6/Mul_20_output_0', '/transformer_blocks.6/Div_15_output_0', '/transformer_blocks.6/Add_17_output_0', '/transformer_blocks.6/Mul_24_output_0', '/transformer_blocks.6/Add_16_output_0', '/transformer_blocks.6/Div_13_output_0', '/transformer_blocks.7/img_norm1/LayerNormalization_output_0', '/transformer_blocks.7/img_norm1/LayerNormalization_output_0', '/transformer_blocks.6/Mul_26_output_0', '/transformer_blocks.7/Div_5_output_0', '/transformer_blocks.7/Add_8_output_0', '/transformer_blocks.7/Mul_7_output_0', '/transformer_blocks.7/Add_4_output_0', '/transformer_blocks.7/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.7/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.6/Mul_27_output_0', '/transformer_blocks.7/Div_7_output_0', '/transformer_blocks.7/Add_9_output_0', '/transformer_blocks.7/Mul_11_output_0', '/transformer_blocks.7/Add_7_output_0', '/transformer_blocks.7/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/Mul_1_output_0', '/transformer_blocks.7/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.7/attn/Mul_2_output_0', '/transformer_blocks.7/attn/Mul_21_output_0', '/transformer_blocks.7/Div_4_output_0', '/transformer_blocks.7/Mul_12_output_0', 
'/transformer_blocks.7/attn/Mul_22_output_0', '/transformer_blocks.7/Div_6_output_0', '/transformer_blocks.7/Mul_13_output_0', '/transformer_blocks.7/img_norm2/LayerNormalization_output_0', '/transformer_blocks.7/img_norm2/LayerNormalization_output_0', '/transformer_blocks.7/Mul_14_output_0', '/transformer_blocks.7/Div_11_output_0', '/transformer_blocks.7/Add_13_output_0', '/transformer_blocks.7/Mul_18_output_0', '/transformer_blocks.7/Add_12_output_0', '/transformer_blocks.7/Div_9_output_0', '/transformer_blocks.7/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.7/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.7/Mul_20_output_0', '/transformer_blocks.7/Div_15_output_0', '/transformer_blocks.7/Add_17_output_0', '/transformer_blocks.7/Mul_24_output_0', '/transformer_blocks.7/Add_16_output_0', '/transformer_blocks.7/Div_13_output_0', '/transformer_blocks.8/img_norm1/LayerNormalization_output_0', '/transformer_blocks.8/img_norm1/LayerNormalization_output_0', '/transformer_blocks.7/Mul_26_output_0', '/transformer_blocks.8/Div_5_output_0', '/transformer_blocks.8/Add_8_output_0', '/transformer_blocks.8/Mul_7_output_0', '/transformer_blocks.8/Add_4_output_0', '/transformer_blocks.8/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.8/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.7/Mul_27_output_0', '/transformer_blocks.8/Div_7_output_0', '/transformer_blocks.8/Add_9_output_0', '/transformer_blocks.8/Mul_11_output_0', '/transformer_blocks.8/Add_7_output_0', '/transformer_blocks.8/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/Mul_1_output_0', '/transformer_blocks.8/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.8/attn/Mul_2_output_0', 
'/transformer_blocks.8/attn/Mul_21_output_0', '/transformer_blocks.8/Div_4_output_0', '/transformer_blocks.8/Mul_12_output_0', '/transformer_blocks.8/attn/Mul_22_output_0', '/transformer_blocks.8/Div_6_output_0', '/transformer_blocks.8/Mul_13_output_0', '/transformer_blocks.8/img_norm2/LayerNormalization_output_0', '/transformer_blocks.8/img_norm2/LayerNormalization_output_0', '/transformer_blocks.8/Mul_14_output_0', '/transformer_blocks.8/Div_11_output_0', '/transformer_blocks.8/Add_13_output_0', '/transformer_blocks.8/Mul_18_output_0', '/transformer_blocks.8/Add_12_output_0', '/transformer_blocks.8/Div_9_output_0', '/transformer_blocks.8/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.8/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.8/Mul_20_output_0', '/transformer_blocks.8/Div_15_output_0', '/transformer_blocks.8/Add_17_output_0', '/transformer_blocks.8/Mul_24_output_0', '/transformer_blocks.8/Add_16_output_0', '/transformer_blocks.8/Div_13_output_0', '/transformer_blocks.9/img_norm1/LayerNormalization_output_0', '/transformer_blocks.9/img_norm1/LayerNormalization_output_0', '/transformer_blocks.8/Mul_26_output_0', '/transformer_blocks.9/Div_5_output_0', '/transformer_blocks.9/Add_8_output_0', '/transformer_blocks.9/Mul_7_output_0', '/transformer_blocks.9/Add_4_output_0', '/transformer_blocks.9/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.9/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.8/Mul_27_output_0', '/transformer_blocks.9/Div_7_output_0', '/transformer_blocks.9/Add_9_output_0', '/transformer_blocks.9/Mul_11_output_0', '/transformer_blocks.9/Add_7_output_0', '/transformer_blocks.9/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/Mul_1_output_0', 
'/transformer_blocks.9/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.9/attn/Mul_2_output_0', '/transformer_blocks.9/attn/Mul_21_output_0', '/transformer_blocks.9/Div_4_output_0', '/transformer_blocks.9/Mul_12_output_0', '/transformer_blocks.9/attn/Mul_22_output_0', '/transformer_blocks.9/Div_6_output_0', '/transformer_blocks.9/Mul_13_output_0', '/transformer_blocks.9/img_norm2/LayerNormalization_output_0', '/transformer_blocks.9/img_norm2/LayerNormalization_output_0', '/transformer_blocks.9/Mul_14_output_0', '/transformer_blocks.9/Div_11_output_0', '/transformer_blocks.9/Add_13_output_0', '/transformer_blocks.9/Mul_18_output_0', '/transformer_blocks.9/Add_12_output_0', '/transformer_blocks.9/Div_9_output_0', '/transformer_blocks.9/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.9/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.9/Mul_20_output_0', '/transformer_blocks.9/Div_15_output_0', '/transformer_blocks.9/Add_17_output_0', '/transformer_blocks.9/Mul_24_output_0', '/transformer_blocks.9/Add_16_output_0', '/transformer_blocks.9/Div_13_output_0', '/transformer_blocks.10/img_norm1/LayerNormalization_output_0', '/transformer_blocks.10/img_norm1/LayerNormalization_output_0', '/transformer_blocks.9/Mul_26_output_0', '/transformer_blocks.10/Div_5_output_0', '/transformer_blocks.10/Add_8_output_0', '/transformer_blocks.10/Mul_7_output_0', '/transformer_blocks.10/Add_4_output_0', '/transformer_blocks.10/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.10/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.9/Mul_27_output_0', '/transformer_blocks.10/Div_7_output_0', '/transformer_blocks.10/Add_9_output_0', '/transformer_blocks.10/Mul_11_output_0', '/transformer_blocks.10/Add_7_output_0', '/transformer_blocks.10/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.10/attn/norm_k/CustomRMSNorm_output_0', 
'/transformer_blocks.10/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.10/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.10/attn/Mul_1_output_0', '/transformer_blocks.10/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.10/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.10/attn/Mul_2_output_0', '/transformer_blocks.10/attn/Mul_21_output_0', '/transformer_blocks.10/Div_4_output_0', '/transformer_blocks.10/Mul_12_output_0', '/transformer_blocks.10/attn/Mul_22_output_0', '/transformer_blocks.10/Div_6_output_0', '/transformer_blocks.10/Mul_13_output_0', '/transformer_blocks.10/img_norm2/LayerNormalization_output_0', '/transformer_blocks.10/img_norm2/LayerNormalization_output_0', '/transformer_blocks.10/Mul_14_output_0', '/transformer_blocks.10/Div_11_output_0', '/transformer_blocks.10/Add_13_output_0', '/transformer_blocks.10/Mul_18_output_0', '/transformer_blocks.10/Add_12_output_0', '/transformer_blocks.10/Div_9_output_0', '/transformer_blocks.10/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.10/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.10/Mul_20_output_0', '/transformer_blocks.10/Div_15_output_0', '/transformer_blocks.10/Add_17_output_0', '/transformer_blocks.10/Mul_24_output_0', '/transformer_blocks.10/Add_16_output_0', '/transformer_blocks.10/Div_13_output_0', '/transformer_blocks.11/img_norm1/LayerNormalization_output_0', '/transformer_blocks.11/img_norm1/LayerNormalization_output_0', '/transformer_blocks.10/Mul_26_output_0', '/transformer_blocks.11/Div_5_output_0', '/transformer_blocks.11/Add_8_output_0', '/transformer_blocks.11/Mul_7_output_0', '/transformer_blocks.11/Add_4_output_0', '/transformer_blocks.11/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.11/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.10/Mul_27_output_0', '/transformer_blocks.11/Div_7_output_0', '/transformer_blocks.11/Add_9_output_0', 
'/transformer_blocks.11/Mul_11_output_0', '/transformer_blocks.11/Add_7_output_0', '/transformer_blocks.11/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/Mul_1_output_0', '/transformer_blocks.11/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.11/attn/Mul_2_output_0', '/transformer_blocks.11/attn/Mul_21_output_0', '/transformer_blocks.11/Div_4_output_0', '/transformer_blocks.11/Mul_12_output_0', '/transformer_blocks.11/attn/Mul_22_output_0', '/transformer_blocks.11/Div_6_output_0', '/transformer_blocks.11/Mul_13_output_0', '/transformer_blocks.11/img_norm2/LayerNormalization_output_0', '/transformer_blocks.11/img_norm2/LayerNormalization_output_0', '/transformer_blocks.11/Mul_14_output_0', '/transformer_blocks.11/Div_11_output_0', '/transformer_blocks.11/Add_13_output_0', '/transformer_blocks.11/Mul_18_output_0', '/transformer_blocks.11/Add_12_output_0', '/transformer_blocks.11/Div_9_output_0', '/transformer_blocks.11/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.11/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.11/Mul_20_output_0', '/transformer_blocks.11/Div_15_output_0', '/transformer_blocks.11/Add_17_output_0', '/transformer_blocks.11/Mul_24_output_0', '/transformer_blocks.11/Add_16_output_0', '/transformer_blocks.11/Div_13_output_0', '/transformer_blocks.12/img_norm1/LayerNormalization_output_0', '/transformer_blocks.12/img_norm1/LayerNormalization_output_0', '/transformer_blocks.11/Mul_26_output_0', '/transformer_blocks.12/Div_5_output_0', '/transformer_blocks.12/Add_8_output_0', '/transformer_blocks.12/Mul_7_output_0', '/transformer_blocks.12/Add_4_output_0', '/transformer_blocks.12/txt_norm1/LayerNormalization_output_0', 
'/transformer_blocks.12/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.11/Mul_27_output_0', '/transformer_blocks.12/Div_7_output_0', '/transformer_blocks.12/Add_9_output_0', '/transformer_blocks.12/Mul_11_output_0', '/transformer_blocks.12/Add_7_output_0', '/transformer_blocks.12/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/Mul_1_output_0', '/transformer_blocks.12/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.12/attn/Mul_2_output_0', '/transformer_blocks.12/attn/Mul_21_output_0', '/transformer_blocks.12/Div_4_output_0', '/transformer_blocks.12/Mul_12_output_0', '/transformer_blocks.12/attn/Mul_22_output_0', '/transformer_blocks.12/Div_6_output_0', '/transformer_blocks.12/Mul_13_output_0', '/transformer_blocks.12/img_norm2/LayerNormalization_output_0', '/transformer_blocks.12/img_norm2/LayerNormalization_output_0', '/transformer_blocks.12/Mul_14_output_0', '/transformer_blocks.12/Div_11_output_0', '/transformer_blocks.12/Add_13_output_0', '/transformer_blocks.12/Mul_18_output_0', '/transformer_blocks.12/Add_12_output_0', '/transformer_blocks.12/Div_9_output_0', '/transformer_blocks.12/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.12/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.12/Mul_20_output_0', '/transformer_blocks.12/Div_15_output_0', '/transformer_blocks.12/Add_17_output_0', '/transformer_blocks.12/Mul_24_output_0', '/transformer_blocks.12/Add_16_output_0', '/transformer_blocks.12/Div_13_output_0', '/transformer_blocks.13/img_norm1/LayerNormalization_output_0', '/transformer_blocks.13/img_norm1/LayerNormalization_output_0', '/transformer_blocks.12/Mul_26_output_0', '/transformer_blocks.13/Div_5_output_0', 
'/transformer_blocks.13/Add_8_output_0', '/transformer_blocks.13/Mul_7_output_0', '/transformer_blocks.13/Add_4_output_0', '/transformer_blocks.13/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.13/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.12/Mul_27_output_0', '/transformer_blocks.13/Div_7_output_0', '/transformer_blocks.13/Add_9_output_0', '/transformer_blocks.13/Mul_11_output_0', '/transformer_blocks.13/Add_7_output_0', '/transformer_blocks.13/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/Mul_1_output_0', '/transformer_blocks.13/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.13/attn/Mul_2_output_0', '/transformer_blocks.13/attn/Mul_21_output_0', '/transformer_blocks.13/Div_4_output_0', '/transformer_blocks.13/Mul_12_output_0', '/transformer_blocks.13/attn/Mul_22_output_0', '/transformer_blocks.13/Div_6_output_0', '/transformer_blocks.13/Mul_13_output_0', '/transformer_blocks.13/img_norm2/LayerNormalization_output_0', '/transformer_blocks.13/img_norm2/LayerNormalization_output_0', '/transformer_blocks.13/Mul_14_output_0', '/transformer_blocks.13/Div_11_output_0', '/transformer_blocks.13/Add_13_output_0', '/transformer_blocks.13/Mul_18_output_0', '/transformer_blocks.13/Add_12_output_0', '/transformer_blocks.13/Div_9_output_0', '/transformer_blocks.13/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.13/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.13/Mul_20_output_0', '/transformer_blocks.13/Div_15_output_0', '/transformer_blocks.13/Add_17_output_0', '/transformer_blocks.13/Mul_24_output_0', '/transformer_blocks.13/Add_16_output_0', '/transformer_blocks.13/Div_13_output_0', 
'/transformer_blocks.14/img_norm1/LayerNormalization_output_0', '/transformer_blocks.14/img_norm1/LayerNormalization_output_0', '/transformer_blocks.13/Mul_26_output_0', '/transformer_blocks.14/Div_5_output_0', '/transformer_blocks.14/Add_8_output_0', '/transformer_blocks.14/Mul_7_output_0', '/transformer_blocks.14/Add_4_output_0', '/transformer_blocks.14/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.14/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.13/Mul_27_output_0', '/transformer_blocks.14/Div_7_output_0', '/transformer_blocks.14/Add_9_output_0', '/transformer_blocks.14/Mul_11_output_0', '/transformer_blocks.14/Add_7_output_0', '/transformer_blocks.14/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/Mul_1_output_0', '/transformer_blocks.14/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.14/attn/Mul_2_output_0', '/transformer_blocks.14/attn/Mul_21_output_0', '/transformer_blocks.14/Div_4_output_0', '/transformer_blocks.14/Mul_12_output_0', '/transformer_blocks.14/attn/Mul_22_output_0', '/transformer_blocks.14/Div_6_output_0', '/transformer_blocks.14/Mul_13_output_0', '/transformer_blocks.14/img_norm2/LayerNormalization_output_0', '/transformer_blocks.14/img_norm2/LayerNormalization_output_0', '/transformer_blocks.14/Mul_14_output_0', '/transformer_blocks.14/Div_11_output_0', '/transformer_blocks.14/Add_13_output_0', '/transformer_blocks.14/Mul_18_output_0', '/transformer_blocks.14/Add_12_output_0', '/transformer_blocks.14/Div_9_output_0', '/transformer_blocks.14/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.14/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.14/Mul_20_output_0', 
'/transformer_blocks.14/Div_15_output_0', '/transformer_blocks.14/Add_17_output_0', '/transformer_blocks.14/Mul_24_output_0', '/transformer_blocks.14/Add_16_output_0', '/transformer_blocks.14/Div_13_output_0', '/transformer_blocks.15/img_norm1/LayerNormalization_output_0', '/transformer_blocks.15/img_norm1/LayerNormalization_output_0', '/transformer_blocks.14/Mul_26_output_0', '/transformer_blocks.15/Div_5_output_0', '/transformer_blocks.15/Add_8_output_0', '/transformer_blocks.15/Mul_7_output_0', '/transformer_blocks.15/Add_4_output_0', '/transformer_blocks.15/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.15/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.14/Mul_27_output_0', '/transformer_blocks.15/Div_7_output_0', '/transformer_blocks.15/Add_9_output_0', '/transformer_blocks.15/Mul_11_output_0', '/transformer_blocks.15/Add_7_output_0', '/transformer_blocks.15/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/Mul_1_output_0', '/transformer_blocks.15/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.15/attn/Mul_2_output_0', '/transformer_blocks.15/attn/Mul_21_output_0', '/transformer_blocks.15/Div_4_output_0', '/transformer_blocks.15/Mul_12_output_0', '/transformer_blocks.15/attn/Mul_22_output_0', '/transformer_blocks.15/Div_6_output_0', '/transformer_blocks.15/Mul_13_output_0', '/transformer_blocks.15/img_norm2/LayerNormalization_output_0', '/transformer_blocks.15/img_norm2/LayerNormalization_output_0', '/transformer_blocks.15/Mul_14_output_0', '/transformer_blocks.15/Div_11_output_0', '/transformer_blocks.15/Add_13_output_0', '/transformer_blocks.15/Mul_18_output_0', '/transformer_blocks.15/Add_12_output_0', '/transformer_blocks.15/Div_9_output_0', 
'/transformer_blocks.15/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.15/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.15/Mul_20_output_0', '/transformer_blocks.15/Div_15_output_0', '/transformer_blocks.15/Add_17_output_0', '/transformer_blocks.15/Mul_24_output_0', '/transformer_blocks.15/Add_16_output_0', '/transformer_blocks.15/Div_13_output_0', '/transformer_blocks.16/img_norm1/LayerNormalization_output_0', '/transformer_blocks.16/img_norm1/LayerNormalization_output_0', '/transformer_blocks.15/Mul_26_output_0', '/transformer_blocks.16/Div_5_output_0', '/transformer_blocks.16/Add_8_output_0', '/transformer_blocks.16/Mul_7_output_0', '/transformer_blocks.16/Add_4_output_0', '/transformer_blocks.16/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.16/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.15/Mul_27_output_0', '/transformer_blocks.16/Div_7_output_0', '/transformer_blocks.16/Add_9_output_0', '/transformer_blocks.16/Mul_11_output_0', '/transformer_blocks.16/Add_7_output_0', '/transformer_blocks.16/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/Mul_1_output_0', '/transformer_blocks.16/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.16/attn/Mul_2_output_0', '/transformer_blocks.16/attn/Mul_21_output_0', '/transformer_blocks.16/Div_4_output_0', '/transformer_blocks.16/Mul_12_output_0', '/transformer_blocks.16/attn/Mul_22_output_0', '/transformer_blocks.16/Div_6_output_0', '/transformer_blocks.16/Mul_13_output_0', '/transformer_blocks.16/img_norm2/LayerNormalization_output_0', '/transformer_blocks.16/img_norm2/LayerNormalization_output_0', '/transformer_blocks.16/Mul_14_output_0', 
'/transformer_blocks.16/Div_11_output_0', '/transformer_blocks.16/Add_13_output_0', '/transformer_blocks.16/Mul_18_output_0', '/transformer_blocks.16/Add_12_output_0', '/transformer_blocks.16/Div_9_output_0', '/transformer_blocks.16/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.16/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.16/Mul_20_output_0', '/transformer_blocks.16/Div_15_output_0', '/transformer_blocks.16/Add_17_output_0', '/transformer_blocks.16/Mul_24_output_0', '/transformer_blocks.16/Add_16_output_0', '/transformer_blocks.16/Div_13_output_0', '/transformer_blocks.17/img_norm1/LayerNormalization_output_0', '/transformer_blocks.17/img_norm1/LayerNormalization_output_0', '/transformer_blocks.16/Mul_26_output_0', '/transformer_blocks.17/Div_5_output_0', '/transformer_blocks.17/Add_8_output_0', '/transformer_blocks.17/Mul_7_output_0', '/transformer_blocks.17/Add_4_output_0', '/transformer_blocks.17/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.17/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.16/Mul_27_output_0', '/transformer_blocks.17/Div_7_output_0', '/transformer_blocks.17/Add_9_output_0', '/transformer_blocks.17/Mul_11_output_0', '/transformer_blocks.17/Add_7_output_0', '/transformer_blocks.17/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/Mul_1_output_0', '/transformer_blocks.17/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.17/attn/Mul_2_output_0', '/transformer_blocks.17/attn/Mul_21_output_0', '/transformer_blocks.17/Div_4_output_0', '/transformer_blocks.17/Mul_12_output_0', '/transformer_blocks.17/attn/Mul_22_output_0', '/transformer_blocks.17/Div_6_output_0', '/transformer_blocks.17/Mul_13_output_0', 
'/transformer_blocks.17/img_norm2/LayerNormalization_output_0', '/transformer_blocks.17/img_norm2/LayerNormalization_output_0', '/transformer_blocks.17/Mul_14_output_0', '/transformer_blocks.17/Div_11_output_0', '/transformer_blocks.17/Add_13_output_0', '/transformer_blocks.17/Mul_18_output_0', '/transformer_blocks.17/Add_12_output_0', '/transformer_blocks.17/Div_9_output_0', '/transformer_blocks.17/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.17/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.17/Mul_20_output_0', '/transformer_blocks.17/Div_15_output_0', '/transformer_blocks.17/Add_17_output_0', '/transformer_blocks.17/Mul_24_output_0', '/transformer_blocks.17/Add_16_output_0', '/transformer_blocks.17/Div_13_output_0', '/transformer_blocks.18/img_norm1/LayerNormalization_output_0', '/transformer_blocks.18/img_norm1/LayerNormalization_output_0', '/transformer_blocks.17/Mul_26_output_0', '/transformer_blocks.18/Div_5_output_0', '/transformer_blocks.18/Add_8_output_0', '/transformer_blocks.18/Mul_7_output_0', '/transformer_blocks.18/Add_4_output_0', '/transformer_blocks.18/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.18/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.17/Mul_27_output_0', '/transformer_blocks.18/Div_7_output_0', '/transformer_blocks.18/Add_9_output_0', '/transformer_blocks.18/Mul_11_output_0', '/transformer_blocks.18/Add_7_output_0', '/transformer_blocks.18/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/Mul_1_output_0', '/transformer_blocks.18/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.18/attn/Mul_2_output_0', '/transformer_blocks.18/attn/Mul_21_output_0', '/transformer_blocks.18/Div_4_output_0', 
'/transformer_blocks.18/Mul_12_output_0', '/transformer_blocks.18/attn/Mul_22_output_0', '/transformer_blocks.18/Div_6_output_0', '/transformer_blocks.18/Mul_13_output_0', '/transformer_blocks.18/img_norm2/LayerNormalization_output_0', '/transformer_blocks.18/img_norm2/LayerNormalization_output_0', '/transformer_blocks.18/Mul_14_output_0', '/transformer_blocks.18/Div_11_output_0', '/transformer_blocks.18/Add_13_output_0', '/transformer_blocks.18/Mul_18_output_0', '/transformer_blocks.18/Add_12_output_0', '/transformer_blocks.18/Div_9_output_0', '/transformer_blocks.18/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.18/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.18/Mul_20_output_0', '/transformer_blocks.18/Div_15_output_0', '/transformer_blocks.18/Add_17_output_0', '/transformer_blocks.18/Mul_24_output_0', '/transformer_blocks.18/Add_16_output_0', '/transformer_blocks.18/Div_13_output_0', '/transformer_blocks.19/img_norm1/LayerNormalization_output_0', '/transformer_blocks.19/img_norm1/LayerNormalization_output_0', '/transformer_blocks.18/Mul_26_output_0', '/transformer_blocks.19/Div_5_output_0', '/transformer_blocks.19/Add_8_output_0', '/transformer_blocks.19/Mul_7_output_0', '/transformer_blocks.19/Add_4_output_0', '/transformer_blocks.19/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.19/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.18/Mul_27_output_0', '/transformer_blocks.19/Div_7_output_0', '/transformer_blocks.19/Add_9_output_0', '/transformer_blocks.19/Mul_11_output_0', '/transformer_blocks.19/Add_7_output_0', '/transformer_blocks.19/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.19/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.19/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.19/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.19/attn/Mul_1_output_0', '/transformer_blocks.19/attn/norm_added_k/CustomRMSNorm_output_0', 
'/transformer_blocks.19/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.19/attn/Mul_2_output_0', '/transformer_blocks.19/attn/Mul_21_output_0', '/transformer_blocks.19/Div_4_output_0', '/transformer_blocks.19/Mul_12_output_0', '/transformer_blocks.19/attn/Mul_22_output_0', '/transformer_blocks.19/Div_6_output_0', '/transformer_blocks.19/Mul_13_output_0', '/transformer_blocks.19/img_norm2/LayerNormalization_output_0', '/transformer_blocks.19/img_norm2/LayerNormalization_output_0', '/transformer_blocks.19/Mul_14_output_0', '/transformer_blocks.19/Div_11_output_0', '/transformer_blocks.19/Add_13_output_0', '/transformer_blocks.19/Mul_18_output_0', '/transformer_blocks.19/Add_12_output_0', '/transformer_blocks.19/Div_9_output_0', '/transformer_blocks.19/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.19/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.19/Mul_20_output_0', '/transformer_blocks.19/Div_15_output_0', '/transformer_blocks.19/Add_17_output_0', '/transformer_blocks.19/Mul_24_output_0', '/transformer_blocks.19/Add_16_output_0', '/transformer_blocks.19/Div_13_output_0', '/transformer_blocks.20/img_norm1/LayerNormalization_output_0', '/transformer_blocks.20/img_norm1/LayerNormalization_output_0', '/transformer_blocks.19/Mul_26_output_0', '/transformer_blocks.20/Div_5_output_0', '/transformer_blocks.20/Add_8_output_0', '/transformer_blocks.20/Mul_7_output_0', '/transformer_blocks.20/Add_4_output_0', '/transformer_blocks.20/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.20/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.19/Mul_27_output_0', '/transformer_blocks.20/Div_7_output_0', '/transformer_blocks.20/Add_9_output_0', '/transformer_blocks.20/Mul_11_output_0', '/transformer_blocks.20/Add_7_output_0', '/transformer_blocks.20/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.20/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.20/attn/norm_added_q/CustomRMSNorm_output_0', 
'/transformer_blocks.20/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.20/attn/Mul_1_output_0', '/transformer_blocks.20/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.20/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.20/attn/Mul_2_output_0', '/transformer_blocks.20/attn/Mul_21_output_0', '/transformer_blocks.20/Div_4_output_0', '/transformer_blocks.20/Mul_12_output_0', '/transformer_blocks.20/attn/Mul_22_output_0', '/transformer_blocks.20/Div_6_output_0', '/transformer_blocks.20/Mul_13_output_0', '/transformer_blocks.20/img_norm2/LayerNormalization_output_0', '/transformer_blocks.20/img_norm2/LayerNormalization_output_0', '/transformer_blocks.20/Mul_14_output_0', '/transformer_blocks.20/Div_11_output_0', '/transformer_blocks.20/Add_13_output_0', '/transformer_blocks.20/Mul_18_output_0', '/transformer_blocks.20/Add_12_output_0', '/transformer_blocks.20/Div_9_output_0', '/transformer_blocks.20/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.20/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.20/Mul_20_output_0', '/transformer_blocks.20/Div_15_output_0', '/transformer_blocks.20/Add_17_output_0', '/transformer_blocks.20/Mul_24_output_0', '/transformer_blocks.20/Add_16_output_0', '/transformer_blocks.20/Div_13_output_0', '/transformer_blocks.21/img_norm1/LayerNormalization_output_0', '/transformer_blocks.21/img_norm1/LayerNormalization_output_0', '/transformer_blocks.20/Mul_26_output_0', '/transformer_blocks.21/Div_5_output_0', '/transformer_blocks.21/Add_8_output_0', '/transformer_blocks.21/Mul_7_output_0', '/transformer_blocks.21/Add_4_output_0', '/transformer_blocks.21/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.21/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.20/Mul_27_output_0', '/transformer_blocks.21/Div_7_output_0', '/transformer_blocks.21/Add_9_output_0', '/transformer_blocks.21/Mul_11_output_0', '/transformer_blocks.21/Add_7_output_0', 
'/transformer_blocks.21/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/Mul_1_output_0', '/transformer_blocks.21/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.21/attn/Mul_2_output_0', '/transformer_blocks.21/attn/Mul_21_output_0', '/transformer_blocks.21/Div_4_output_0', '/transformer_blocks.21/Mul_12_output_0', '/transformer_blocks.21/attn/Mul_22_output_0', '/transformer_blocks.21/Div_6_output_0', '/transformer_blocks.21/Mul_13_output_0', '/transformer_blocks.21/img_norm2/LayerNormalization_output_0', '/transformer_blocks.21/img_norm2/LayerNormalization_output_0', '/transformer_blocks.21/Mul_14_output_0', '/transformer_blocks.21/Div_11_output_0', '/transformer_blocks.21/Add_13_output_0', '/transformer_blocks.21/Mul_18_output_0', '/transformer_blocks.21/Add_12_output_0', '/transformer_blocks.21/Div_9_output_0', '/transformer_blocks.21/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.21/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.21/Mul_20_output_0', '/transformer_blocks.21/Div_15_output_0', '/transformer_blocks.21/Add_17_output_0', '/transformer_blocks.21/Mul_24_output_0', '/transformer_blocks.21/Add_16_output_0', '/transformer_blocks.21/Div_13_output_0', '/transformer_blocks.22/img_norm1/LayerNormalization_output_0', '/transformer_blocks.22/img_norm1/LayerNormalization_output_0', '/transformer_blocks.21/Mul_26_output_0', '/transformer_blocks.22/Div_5_output_0', '/transformer_blocks.22/Add_8_output_0', '/transformer_blocks.22/Mul_7_output_0', '/transformer_blocks.22/Add_4_output_0', '/transformer_blocks.22/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.22/txt_norm1/LayerNormalization_output_0', 
'/transformer_blocks.21/Mul_27_output_0', '/transformer_blocks.22/Div_7_output_0', '/transformer_blocks.22/Add_9_output_0', '/transformer_blocks.22/Mul_11_output_0', '/transformer_blocks.22/Add_7_output_0', '/transformer_blocks.22/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/Mul_1_output_0', '/transformer_blocks.22/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.22/attn/Mul_2_output_0', '/transformer_blocks.22/attn/Mul_21_output_0', '/transformer_blocks.22/Div_4_output_0', '/transformer_blocks.22/Mul_12_output_0', '/transformer_blocks.22/attn/Mul_22_output_0', '/transformer_blocks.22/Div_6_output_0', '/transformer_blocks.22/Mul_13_output_0', '/transformer_blocks.22/img_norm2/LayerNormalization_output_0', '/transformer_blocks.22/img_norm2/LayerNormalization_output_0', '/transformer_blocks.22/Mul_14_output_0', '/transformer_blocks.22/Div_11_output_0', '/transformer_blocks.22/Add_13_output_0', '/transformer_blocks.22/Mul_18_output_0', '/transformer_blocks.22/Add_12_output_0', '/transformer_blocks.22/Div_9_output_0', '/transformer_blocks.22/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.22/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.22/Mul_20_output_0', '/transformer_blocks.22/Div_15_output_0', '/transformer_blocks.22/Add_17_output_0', '/transformer_blocks.22/Mul_24_output_0', '/transformer_blocks.22/Add_16_output_0', '/transformer_blocks.22/Div_13_output_0', '/transformer_blocks.23/img_norm1/LayerNormalization_output_0', '/transformer_blocks.23/img_norm1/LayerNormalization_output_0', '/transformer_blocks.22/Mul_26_output_0', '/transformer_blocks.23/Div_5_output_0', '/transformer_blocks.23/Add_8_output_0', 
'/transformer_blocks.23/Mul_7_output_0', '/transformer_blocks.23/Add_4_output_0', '/transformer_blocks.23/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.23/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.22/Mul_27_output_0', '/transformer_blocks.23/Div_7_output_0', '/transformer_blocks.23/Add_9_output_0', '/transformer_blocks.23/Mul_11_output_0', '/transformer_blocks.23/Add_7_output_0', '/transformer_blocks.23/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/Mul_1_output_0', '/transformer_blocks.23/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.23/attn/Mul_2_output_0', '/transformer_blocks.23/attn/Mul_21_output_0', '/transformer_blocks.23/Div_4_output_0', '/transformer_blocks.23/Mul_12_output_0', '/transformer_blocks.23/attn/Mul_22_output_0', '/transformer_blocks.23/Div_6_output_0', '/transformer_blocks.23/Mul_13_output_0', '/transformer_blocks.23/img_norm2/LayerNormalization_output_0', '/transformer_blocks.23/img_norm2/LayerNormalization_output_0', '/transformer_blocks.23/Mul_14_output_0', '/transformer_blocks.23/Div_11_output_0', '/transformer_blocks.23/Add_13_output_0', '/transformer_blocks.23/Mul_18_output_0', '/transformer_blocks.23/Add_12_output_0', '/transformer_blocks.23/Div_9_output_0', '/transformer_blocks.23/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.23/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.23/Mul_20_output_0', '/transformer_blocks.23/Div_15_output_0', '/transformer_blocks.23/Add_17_output_0', '/transformer_blocks.23/Mul_24_output_0', '/transformer_blocks.23/Add_16_output_0', '/transformer_blocks.23/Div_13_output_0', '/transformer_blocks.24/img_norm1/LayerNormalization_output_0', 
'/transformer_blocks.24/img_norm1/LayerNormalization_output_0', '/transformer_blocks.23/Mul_26_output_0', '/transformer_blocks.24/Div_5_output_0', '/transformer_blocks.24/Add_8_output_0', '/transformer_blocks.24/Mul_7_output_0', '/transformer_blocks.24/Add_4_output_0', '/transformer_blocks.24/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.24/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.23/Mul_27_output_0', '/transformer_blocks.24/Div_7_output_0', '/transformer_blocks.24/Add_9_output_0', '/transformer_blocks.24/Mul_11_output_0', '/transformer_blocks.24/Add_7_output_0', '/transformer_blocks.24/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/Mul_1_output_0', '/transformer_blocks.24/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.24/attn/Mul_2_output_0', '/transformer_blocks.24/attn/Mul_21_output_0', '/transformer_blocks.24/Div_4_output_0', '/transformer_blocks.24/Mul_12_output_0', '/transformer_blocks.24/attn/Mul_22_output_0', '/transformer_blocks.24/Div_6_output_0', '/transformer_blocks.24/Mul_13_output_0', '/transformer_blocks.24/img_norm2/LayerNormalization_output_0', '/transformer_blocks.24/img_norm2/LayerNormalization_output_0', '/transformer_blocks.24/Mul_14_output_0', '/transformer_blocks.24/Div_11_output_0', '/transformer_blocks.24/Add_13_output_0', '/transformer_blocks.24/Mul_18_output_0', '/transformer_blocks.24/Add_12_output_0', '/transformer_blocks.24/Div_9_output_0', '/transformer_blocks.24/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.24/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.24/Mul_20_output_0', '/transformer_blocks.24/Div_15_output_0', '/transformer_blocks.24/Add_17_output_0', 
'/transformer_blocks.24/Mul_24_output_0', '/transformer_blocks.24/Add_16_output_0', '/transformer_blocks.24/Div_13_output_0', '/transformer_blocks.25/img_norm1/LayerNormalization_output_0', '/transformer_blocks.25/img_norm1/LayerNormalization_output_0', '/transformer_blocks.24/Mul_26_output_0', '/transformer_blocks.25/Div_5_output_0', '/transformer_blocks.25/Add_8_output_0', '/transformer_blocks.25/Mul_7_output_0', '/transformer_blocks.25/Add_4_output_0', '/transformer_blocks.25/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.25/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.24/Mul_27_output_0', '/transformer_blocks.25/Div_7_output_0', '/transformer_blocks.25/Add_9_output_0', '/transformer_blocks.25/Mul_11_output_0', '/transformer_blocks.25/Add_7_output_0', '/transformer_blocks.25/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/Mul_1_output_0', '/transformer_blocks.25/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.25/attn/Mul_2_output_0', '/transformer_blocks.25/attn/Mul_21_output_0', '/transformer_blocks.25/Div_4_output_0', '/transformer_blocks.25/Mul_12_output_0', '/transformer_blocks.25/attn/Mul_22_output_0', '/transformer_blocks.25/Div_6_output_0', '/transformer_blocks.25/Mul_13_output_0', '/transformer_blocks.25/img_norm2/LayerNormalization_output_0', '/transformer_blocks.25/img_norm2/LayerNormalization_output_0', '/transformer_blocks.25/Mul_14_output_0', '/transformer_blocks.25/Div_11_output_0', '/transformer_blocks.25/Add_13_output_0', '/transformer_blocks.25/Mul_18_output_0', '/transformer_blocks.25/Add_12_output_0', '/transformer_blocks.25/Div_9_output_0', '/transformer_blocks.25/txt_norm2/LayerNormalization_output_0', 
'/transformer_blocks.25/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.25/Mul_20_output_0', '/transformer_blocks.25/Div_15_output_0', '/transformer_blocks.25/Add_17_output_0', '/transformer_blocks.25/Mul_24_output_0', '/transformer_blocks.25/Add_16_output_0', '/transformer_blocks.25/Div_13_output_0', '/transformer_blocks.26/img_norm1/LayerNormalization_output_0', '/transformer_blocks.26/img_norm1/LayerNormalization_output_0', '/transformer_blocks.25/Mul_26_output_0', '/transformer_blocks.26/Div_5_output_0', '/transformer_blocks.26/Add_8_output_0', '/transformer_blocks.26/Mul_7_output_0', '/transformer_blocks.26/Add_4_output_0', '/transformer_blocks.26/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.26/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.25/Mul_27_output_0', '/transformer_blocks.26/Div_7_output_0', '/transformer_blocks.26/Add_9_output_0', '/transformer_blocks.26/Mul_11_output_0', '/transformer_blocks.26/Add_7_output_0', '/transformer_blocks.26/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/Mul_1_output_0', '/transformer_blocks.26/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.26/attn/Mul_2_output_0', '/transformer_blocks.26/attn/Mul_21_output_0', '/transformer_blocks.26/Div_4_output_0', '/transformer_blocks.26/Mul_12_output_0', '/transformer_blocks.26/attn/Mul_22_output_0', '/transformer_blocks.26/Div_6_output_0', '/transformer_blocks.26/Mul_13_output_0', '/transformer_blocks.26/img_norm2/LayerNormalization_output_0', '/transformer_blocks.26/img_norm2/LayerNormalization_output_0', '/transformer_blocks.26/Mul_14_output_0', '/transformer_blocks.26/Div_11_output_0', '/transformer_blocks.26/Add_13_output_0', 
'/transformer_blocks.26/Mul_18_output_0', '/transformer_blocks.26/Add_12_output_0', '/transformer_blocks.26/Div_9_output_0', '/transformer_blocks.26/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.26/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.26/Mul_20_output_0', '/transformer_blocks.26/Div_15_output_0', '/transformer_blocks.26/Add_17_output_0', '/transformer_blocks.26/Mul_24_output_0', '/transformer_blocks.26/Add_16_output_0', '/transformer_blocks.26/Div_13_output_0', '/transformer_blocks.27/img_norm1/LayerNormalization_output_0', '/transformer_blocks.27/img_norm1/LayerNormalization_output_0', '/transformer_blocks.26/Mul_26_output_0', '/transformer_blocks.27/Div_5_output_0', '/transformer_blocks.27/Add_8_output_0', '/transformer_blocks.27/Mul_7_output_0', '/transformer_blocks.27/Add_4_output_0', '/transformer_blocks.27/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.27/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.26/Mul_27_output_0', '/transformer_blocks.27/Div_7_output_0', '/transformer_blocks.27/Add_9_output_0', '/transformer_blocks.27/Mul_11_output_0', '/transformer_blocks.27/Add_7_output_0', '/transformer_blocks.27/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/Mul_1_output_0', '/transformer_blocks.27/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.27/attn/Mul_2_output_0', '/transformer_blocks.27/attn/Mul_21_output_0', '/transformer_blocks.27/Div_4_output_0', '/transformer_blocks.27/Mul_12_output_0', '/transformer_blocks.27/attn/Mul_22_output_0', '/transformer_blocks.27/Div_6_output_0', '/transformer_blocks.27/Mul_13_output_0', '/transformer_blocks.27/img_norm2/LayerNormalization_output_0', 
'/transformer_blocks.27/img_norm2/LayerNormalization_output_0', '/transformer_blocks.27/Mul_14_output_0', '/transformer_blocks.27/Div_11_output_0', '/transformer_blocks.27/Add_13_output_0', '/transformer_blocks.27/Mul_18_output_0', '/transformer_blocks.27/Add_12_output_0', '/transformer_blocks.27/Div_9_output_0', '/transformer_blocks.27/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.27/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.27/Mul_20_output_0', '/transformer_blocks.27/Div_15_output_0', '/transformer_blocks.27/Add_17_output_0', '/transformer_blocks.27/Mul_24_output_0', '/transformer_blocks.27/Add_16_output_0', '/transformer_blocks.27/Div_13_output_0', '/transformer_blocks.28/img_norm1/LayerNormalization_output_0', '/transformer_blocks.28/img_norm1/LayerNormalization_output_0', '/transformer_blocks.27/Mul_26_output_0', '/transformer_blocks.28/Div_5_output_0', '/transformer_blocks.28/Add_8_output_0', '/transformer_blocks.28/Mul_7_output_0', '/transformer_blocks.28/Add_4_output_0', '/transformer_blocks.28/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.28/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.27/Mul_27_output_0', '/transformer_blocks.28/Div_7_output_0', '/transformer_blocks.28/Add_9_output_0', '/transformer_blocks.28/Mul_11_output_0', '/transformer_blocks.28/Add_7_output_0', '/transformer_blocks.28/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/Mul_1_output_0', '/transformer_blocks.28/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.28/attn/Mul_2_output_0', '/transformer_blocks.28/attn/Mul_21_output_0', '/transformer_blocks.28/Div_4_output_0', '/transformer_blocks.28/Mul_12_output_0', 
'/transformer_blocks.28/attn/Mul_22_output_0', '/transformer_blocks.28/Div_6_output_0', '/transformer_blocks.28/Mul_13_output_0', '/transformer_blocks.28/img_norm2/LayerNormalization_output_0', '/transformer_blocks.28/img_norm2/LayerNormalization_output_0', '/transformer_blocks.28/Mul_14_output_0', '/transformer_blocks.28/Div_11_output_0', '/transformer_blocks.28/Add_13_output_0', '/transformer_blocks.28/Mul_18_output_0', '/transformer_blocks.28/Add_12_output_0', '/transformer_blocks.28/Div_9_output_0', '/transformer_blocks.28/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.28/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.28/Mul_20_output_0', '/transformer_blocks.28/Div_15_output_0', '/transformer_blocks.28/Add_17_output_0', '/transformer_blocks.28/Mul_24_output_0', '/transformer_blocks.28/Add_16_output_0', '/transformer_blocks.28/Div_13_output_0', '/transformer_blocks.29/img_norm1/LayerNormalization_output_0', '/transformer_blocks.29/img_norm1/LayerNormalization_output_0', '/transformer_blocks.28/Mul_26_output_0', '/transformer_blocks.29/Div_5_output_0', '/transformer_blocks.29/Add_8_output_0', '/transformer_blocks.29/Mul_7_output_0', '/transformer_blocks.29/Add_4_output_0', '/transformer_blocks.29/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.29/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.28/Mul_27_output_0', '/transformer_blocks.29/Div_7_output_0', '/transformer_blocks.29/Add_9_output_0', '/transformer_blocks.29/Mul_11_output_0', '/transformer_blocks.29/Add_7_output_0', '/transformer_blocks.29/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.29/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.29/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.29/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.29/attn/Mul_1_output_0', '/transformer_blocks.29/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.29/attn/norm_added_k/CustomRMSNorm_output_0', 
'/transformer_blocks.29/attn/Mul_2_output_0', '/transformer_blocks.29/attn/Mul_21_output_0', '/transformer_blocks.29/Div_4_output_0', '/transformer_blocks.29/Mul_12_output_0', '/transformer_blocks.29/attn/Mul_22_output_0', '/transformer_blocks.29/Div_6_output_0', '/transformer_blocks.29/Mul_13_output_0', '/transformer_blocks.29/img_norm2/LayerNormalization_output_0', '/transformer_blocks.29/img_norm2/LayerNormalization_output_0', '/transformer_blocks.29/Mul_14_output_0', '/transformer_blocks.29/Div_11_output_0', '/transformer_blocks.29/Add_13_output_0', '/transformer_blocks.29/Mul_18_output_0', '/transformer_blocks.29/Add_12_output_0', '/transformer_blocks.29/Div_9_output_0', '/transformer_blocks.29/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.29/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.29/Mul_20_output_0', '/transformer_blocks.29/Div_15_output_0', '/transformer_blocks.29/Add_17_output_0', '/transformer_blocks.29/Mul_24_output_0', '/transformer_blocks.29/Add_16_output_0', '/transformer_blocks.29/Div_13_output_0', '/transformer_blocks.30/img_norm1/LayerNormalization_output_0', '/transformer_blocks.30/img_norm1/LayerNormalization_output_0', '/transformer_blocks.29/Mul_26_output_0', '/transformer_blocks.30/Div_5_output_0', '/transformer_blocks.30/Add_8_output_0', '/transformer_blocks.30/Mul_7_output_0', '/transformer_blocks.30/Add_4_output_0', '/transformer_blocks.30/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.30/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.29/Mul_27_output_0', '/transformer_blocks.30/Div_7_output_0', '/transformer_blocks.30/Add_9_output_0', '/transformer_blocks.30/Mul_11_output_0', '/transformer_blocks.30/Add_7_output_0', '/transformer_blocks.30/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.30/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.30/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.30/attn/norm_added_q/CustomRMSNorm_output_0', 
'/transformer_blocks.30/attn/Mul_1_output_0', '/transformer_blocks.30/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.30/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.30/attn/Mul_2_output_0', '/transformer_blocks.30/attn/Mul_21_output_0', '/transformer_blocks.30/Div_4_output_0', '/transformer_blocks.30/Mul_12_output_0', '/transformer_blocks.30/attn/Mul_22_output_0', '/transformer_blocks.30/Div_6_output_0', '/transformer_blocks.30/Mul_13_output_0', '/transformer_blocks.30/img_norm2/LayerNormalization_output_0', '/transformer_blocks.30/img_norm2/LayerNormalization_output_0', '/transformer_blocks.30/Mul_14_output_0', '/transformer_blocks.30/Div_11_output_0', '/transformer_blocks.30/Add_13_output_0', '/transformer_blocks.30/Mul_18_output_0', '/transformer_blocks.30/Add_12_output_0', '/transformer_blocks.30/Div_9_output_0', '/transformer_blocks.30/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.30/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.30/Mul_20_output_0', '/transformer_blocks.30/Div_15_output_0', '/transformer_blocks.30/Add_17_output_0', '/transformer_blocks.30/Mul_24_output_0', '/transformer_blocks.30/Add_16_output_0', '/transformer_blocks.30/Div_13_output_0', '/transformer_blocks.31/img_norm1/LayerNormalization_output_0', '/transformer_blocks.31/img_norm1/LayerNormalization_output_0', '/transformer_blocks.30/Mul_26_output_0', '/transformer_blocks.31/Div_5_output_0', '/transformer_blocks.31/Add_8_output_0', '/transformer_blocks.31/Mul_7_output_0', '/transformer_blocks.31/Add_4_output_0', '/transformer_blocks.31/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.31/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.30/Mul_27_output_0', '/transformer_blocks.31/Div_7_output_0', '/transformer_blocks.31/Add_9_output_0', '/transformer_blocks.31/Mul_11_output_0', '/transformer_blocks.31/Add_7_output_0', '/transformer_blocks.31/attn/norm_q/CustomRMSNorm_output_0', 
'/transformer_blocks.31/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.31/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.31/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.31/attn/Mul_1_output_0', '/transformer_blocks.31/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.31/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.31/attn/Mul_2_output_0', '/transformer_blocks.31/attn/Mul_21_output_0', '/transformer_blocks.31/Div_4_output_0', '/transformer_blocks.31/Mul_12_output_0', '/transformer_blocks.31/attn/Mul_22_output_0', '/transformer_blocks.31/Div_6_output_0', '/transformer_blocks.31/Mul_13_output_0', '/transformer_blocks.31/img_norm2/LayerNormalization_output_0', '/transformer_blocks.31/img_norm2/LayerNormalization_output_0', '/transformer_blocks.31/Mul_14_output_0', '/transformer_blocks.31/Div_11_output_0', '/transformer_blocks.31/Add_13_output_0', '/transformer_blocks.31/Mul_18_output_0', '/transformer_blocks.31/Add_12_output_0', '/transformer_blocks.31/Div_9_output_0', '/transformer_blocks.31/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.31/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.31/Mul_20_output_0', '/transformer_blocks.31/Div_15_output_0', '/transformer_blocks.31/Add_17_output_0', '/transformer_blocks.31/Mul_24_output_0', '/transformer_blocks.31/Add_16_output_0', '/transformer_blocks.31/Div_13_output_0', '/transformer_blocks.32/img_norm1/LayerNormalization_output_0', '/transformer_blocks.32/img_norm1/LayerNormalization_output_0', '/transformer_blocks.31/Mul_26_output_0', '/transformer_blocks.32/Div_5_output_0', '/transformer_blocks.32/Add_8_output_0', '/transformer_blocks.32/Mul_7_output_0', '/transformer_blocks.32/Add_4_output_0', '/transformer_blocks.32/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.32/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.31/Mul_27_output_0', '/transformer_blocks.32/Div_7_output_0', 
'/transformer_blocks.32/Add_9_output_0', '/transformer_blocks.32/Mul_11_output_0', '/transformer_blocks.32/Add_7_output_0', '/transformer_blocks.32/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/Mul_1_output_0', '/transformer_blocks.32/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.32/attn/Mul_2_output_0', '/transformer_blocks.32/attn/Mul_21_output_0', '/transformer_blocks.32/Div_4_output_0', '/transformer_blocks.32/Mul_12_output_0', '/transformer_blocks.32/attn/Mul_22_output_0', '/transformer_blocks.32/Div_6_output_0', '/transformer_blocks.32/Mul_13_output_0', '/transformer_blocks.32/img_norm2/LayerNormalization_output_0', '/transformer_blocks.32/img_norm2/LayerNormalization_output_0', '/transformer_blocks.32/Mul_14_output_0', '/transformer_blocks.32/Div_11_output_0', '/transformer_blocks.32/Add_13_output_0', '/transformer_blocks.32/Mul_18_output_0', '/transformer_blocks.32/Add_12_output_0', '/transformer_blocks.32/Div_9_output_0', '/transformer_blocks.32/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.32/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.32/Mul_20_output_0', '/transformer_blocks.32/Div_15_output_0', '/transformer_blocks.32/Add_17_output_0', '/transformer_blocks.32/Mul_24_output_0', '/transformer_blocks.32/Add_16_output_0', '/transformer_blocks.32/Div_13_output_0', '/transformer_blocks.33/img_norm1/LayerNormalization_output_0', '/transformer_blocks.33/img_norm1/LayerNormalization_output_0', '/transformer_blocks.32/Mul_26_output_0', '/transformer_blocks.33/Div_5_output_0', '/transformer_blocks.33/Add_8_output_0', '/transformer_blocks.33/Mul_7_output_0', '/transformer_blocks.33/Add_4_output_0', 
'/transformer_blocks.33/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.33/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.32/Mul_27_output_0', '/transformer_blocks.33/Div_7_output_0', '/transformer_blocks.33/Add_9_output_0', '/transformer_blocks.33/Mul_11_output_0', '/transformer_blocks.33/Add_7_output_0', '/transformer_blocks.33/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/Mul_1_output_0', '/transformer_blocks.33/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.33/attn/Mul_2_output_0', '/transformer_blocks.33/attn/Mul_21_output_0', '/transformer_blocks.33/Div_4_output_0', '/transformer_blocks.33/Mul_12_output_0', '/transformer_blocks.33/attn/Mul_22_output_0', '/transformer_blocks.33/Div_6_output_0', '/transformer_blocks.33/Mul_13_output_0', '/transformer_blocks.33/img_norm2/LayerNormalization_output_0', '/transformer_blocks.33/img_norm2/LayerNormalization_output_0', '/transformer_blocks.33/Mul_14_output_0', '/transformer_blocks.33/Div_11_output_0', '/transformer_blocks.33/Add_13_output_0', '/transformer_blocks.33/Mul_18_output_0', '/transformer_blocks.33/Add_12_output_0', '/transformer_blocks.33/Div_9_output_0', '/transformer_blocks.33/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.33/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.33/Mul_20_output_0', '/transformer_blocks.33/Div_15_output_0', '/transformer_blocks.33/Add_17_output_0', '/transformer_blocks.33/Mul_24_output_0', '/transformer_blocks.33/Add_16_output_0', '/transformer_blocks.33/Div_13_output_0', '/transformer_blocks.34/img_norm1/LayerNormalization_output_0', '/transformer_blocks.34/img_norm1/LayerNormalization_output_0', 
'/transformer_blocks.33/Mul_26_output_0', '/transformer_blocks.34/Div_5_output_0', '/transformer_blocks.34/Add_8_output_0', '/transformer_blocks.34/Mul_7_output_0', '/transformer_blocks.34/Add_4_output_0', '/transformer_blocks.34/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.34/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.33/Mul_27_output_0', '/transformer_blocks.34/Div_7_output_0', '/transformer_blocks.34/Add_9_output_0', '/transformer_blocks.34/Mul_11_output_0', '/transformer_blocks.34/Add_7_output_0', '/transformer_blocks.34/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/Mul_1_output_0', '/transformer_blocks.34/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.34/attn/Mul_2_output_0', '/transformer_blocks.34/attn/Mul_21_output_0', '/transformer_blocks.34/Div_4_output_0', '/transformer_blocks.34/Mul_12_output_0', '/transformer_blocks.34/attn/Mul_22_output_0', '/transformer_blocks.34/Div_6_output_0', '/transformer_blocks.34/Mul_13_output_0', '/transformer_blocks.34/img_norm2/LayerNormalization_output_0', '/transformer_blocks.34/img_norm2/LayerNormalization_output_0', '/transformer_blocks.34/Mul_14_output_0', '/transformer_blocks.34/Div_11_output_0', '/transformer_blocks.34/Add_13_output_0', '/transformer_blocks.34/Mul_18_output_0', '/transformer_blocks.34/Add_12_output_0', '/transformer_blocks.34/Div_9_output_0', '/transformer_blocks.34/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.34/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.34/Mul_20_output_0', '/transformer_blocks.34/Div_15_output_0', '/transformer_blocks.34/Add_17_output_0', '/transformer_blocks.34/Mul_24_output_0', '/transformer_blocks.34/Add_16_output_0', 
'/transformer_blocks.34/Div_13_output_0', '/transformer_blocks.35/img_norm1/LayerNormalization_output_0', '/transformer_blocks.35/img_norm1/LayerNormalization_output_0', '/transformer_blocks.34/Mul_26_output_0', '/transformer_blocks.35/Div_5_output_0', '/transformer_blocks.35/Add_8_output_0', '/transformer_blocks.35/Mul_7_output_0', '/transformer_blocks.35/Add_4_output_0', '/transformer_blocks.35/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.35/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.34/Mul_27_output_0', '/transformer_blocks.35/Div_7_output_0', '/transformer_blocks.35/Add_9_output_0', '/transformer_blocks.35/Mul_11_output_0', '/transformer_blocks.35/Add_7_output_0', '/transformer_blocks.35/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/Mul_1_output_0', '/transformer_blocks.35/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.35/attn/Mul_2_output_0', '/transformer_blocks.35/attn/Mul_21_output_0', '/transformer_blocks.35/Div_4_output_0', '/transformer_blocks.35/Mul_12_output_0', '/transformer_blocks.35/attn/Mul_22_output_0', '/transformer_blocks.35/Div_6_output_0', '/transformer_blocks.35/Mul_13_output_0', '/transformer_blocks.35/img_norm2/LayerNormalization_output_0', '/transformer_blocks.35/img_norm2/LayerNormalization_output_0', '/transformer_blocks.35/Mul_14_output_0', '/transformer_blocks.35/Div_11_output_0', '/transformer_blocks.35/Add_13_output_0', '/transformer_blocks.35/Mul_18_output_0', '/transformer_blocks.35/Add_12_output_0', '/transformer_blocks.35/Div_9_output_0', '/transformer_blocks.35/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.35/txt_norm2/LayerNormalization_output_0', 
'/transformer_blocks.35/Mul_20_output_0', '/transformer_blocks.35/Div_15_output_0', '/transformer_blocks.35/Add_17_output_0', '/transformer_blocks.35/Mul_24_output_0', '/transformer_blocks.35/Add_16_output_0', '/transformer_blocks.35/Div_13_output_0', '/transformer_blocks.36/img_norm1/LayerNormalization_output_0', '/transformer_blocks.36/img_norm1/LayerNormalization_output_0', '/transformer_blocks.35/Mul_26_output_0', '/transformer_blocks.36/Div_5_output_0', '/transformer_blocks.36/Add_8_output_0', '/transformer_blocks.36/Mul_7_output_0', '/transformer_blocks.36/Add_4_output_0', '/transformer_blocks.36/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.36/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.35/Mul_27_output_0', '/transformer_blocks.36/Div_7_output_0', '/transformer_blocks.36/Add_9_output_0', '/transformer_blocks.36/Mul_11_output_0', '/transformer_blocks.36/Add_7_output_0', '/transformer_blocks.36/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/Mul_1_output_0', '/transformer_blocks.36/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.36/attn/Mul_2_output_0', '/transformer_blocks.36/attn/Mul_21_output_0', '/transformer_blocks.36/Div_4_output_0', '/transformer_blocks.36/Mul_12_output_0', '/transformer_blocks.36/attn/Mul_22_output_0', '/transformer_blocks.36/Div_6_output_0', '/transformer_blocks.36/Mul_13_output_0', '/transformer_blocks.36/img_norm2/LayerNormalization_output_0', '/transformer_blocks.36/img_norm2/LayerNormalization_output_0', '/transformer_blocks.36/Mul_14_output_0', '/transformer_blocks.36/Div_11_output_0', '/transformer_blocks.36/Add_13_output_0', '/transformer_blocks.36/Mul_18_output_0', 
'/transformer_blocks.36/Add_12_output_0', '/transformer_blocks.36/Div_9_output_0', '/transformer_blocks.36/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.36/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.36/Mul_20_output_0', '/transformer_blocks.36/Div_15_output_0', '/transformer_blocks.36/Add_17_output_0', '/transformer_blocks.36/Mul_24_output_0', '/transformer_blocks.36/Add_16_output_0', '/transformer_blocks.36/Div_13_output_0', '/transformer_blocks.37/img_norm1/LayerNormalization_output_0', '/transformer_blocks.37/img_norm1/LayerNormalization_output_0', '/transformer_blocks.36/Mul_26_output_0', '/transformer_blocks.37/Div_5_output_0', '/transformer_blocks.37/Add_8_output_0', '/transformer_blocks.37/Mul_7_output_0', '/transformer_blocks.37/Add_4_output_0', '/transformer_blocks.37/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.37/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.36/Mul_27_output_0', '/transformer_blocks.37/Div_7_output_0', '/transformer_blocks.37/Add_9_output_0', '/transformer_blocks.37/Mul_11_output_0', '/transformer_blocks.37/Add_7_output_0', '/transformer_blocks.37/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/Mul_1_output_0', '/transformer_blocks.37/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.37/attn/Mul_2_output_0', '/transformer_blocks.37/attn/Mul_21_output_0', '/transformer_blocks.37/Div_4_output_0', '/transformer_blocks.37/Mul_12_output_0', '/transformer_blocks.37/attn/Mul_22_output_0', '/transformer_blocks.37/Div_6_output_0', '/transformer_blocks.37/Mul_13_output_0', '/transformer_blocks.37/img_norm2/LayerNormalization_output_0', 
'/transformer_blocks.37/img_norm2/LayerNormalization_output_0', '/transformer_blocks.37/Mul_14_output_0', '/transformer_blocks.37/Div_11_output_0', '/transformer_blocks.37/Add_13_output_0', '/transformer_blocks.37/Mul_18_output_0', '/transformer_blocks.37/Add_12_output_0', '/transformer_blocks.37/Div_9_output_0', '/transformer_blocks.37/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.37/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.37/Mul_20_output_0', '/transformer_blocks.37/Div_15_output_0', '/transformer_blocks.37/Add_17_output_0', '/transformer_blocks.37/Mul_24_output_0', '/transformer_blocks.37/Add_16_output_0', '/transformer_blocks.37/Div_13_output_0', '/transformer_blocks.38/img_norm1/LayerNormalization_output_0', '/transformer_blocks.38/img_norm1/LayerNormalization_output_0', '/transformer_blocks.37/Mul_26_output_0', '/transformer_blocks.38/Div_5_output_0', '/transformer_blocks.38/Add_8_output_0', '/transformer_blocks.38/Mul_7_output_0', '/transformer_blocks.38/Add_4_output_0', '/transformer_blocks.38/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.38/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.37/Mul_27_output_0', '/transformer_blocks.38/Div_7_output_0', '/transformer_blocks.38/Add_9_output_0', '/transformer_blocks.38/Mul_11_output_0', '/transformer_blocks.38/Add_7_output_0', '/transformer_blocks.38/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/Mul_1_output_0', '/transformer_blocks.38/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.38/attn/Mul_2_output_0', '/transformer_blocks.38/attn/Mul_21_output_0', '/transformer_blocks.38/Div_4_output_0', '/transformer_blocks.38/Mul_12_output_0', 
'/transformer_blocks.38/attn/Mul_22_output_0', '/transformer_blocks.38/Div_6_output_0', '/transformer_blocks.38/Mul_13_output_0', '/transformer_blocks.38/img_norm2/LayerNormalization_output_0', '/transformer_blocks.38/img_norm2/LayerNormalization_output_0', '/transformer_blocks.38/Mul_14_output_0', '/transformer_blocks.38/Div_11_output_0', '/transformer_blocks.38/Add_13_output_0', '/transformer_blocks.38/Mul_18_output_0', '/transformer_blocks.38/Add_12_output_0', '/transformer_blocks.38/Div_9_output_0', '/transformer_blocks.38/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.38/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.38/Mul_20_output_0', '/transformer_blocks.38/Div_15_output_0', '/transformer_blocks.38/Add_17_output_0', '/transformer_blocks.38/Mul_24_output_0', '/transformer_blocks.38/Add_16_output_0', '/transformer_blocks.38/Div_13_output_0', '/transformer_blocks.39/img_norm1/LayerNormalization_output_0', '/transformer_blocks.39/img_norm1/LayerNormalization_output_0', '/transformer_blocks.38/Mul_26_output_0', '/transformer_blocks.39/Div_5_output_0', '/transformer_blocks.39/Add_8_output_0', '/transformer_blocks.39/Mul_7_output_0', '/transformer_blocks.39/Add_4_output_0', '/transformer_blocks.39/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.39/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.38/Mul_27_output_0', '/transformer_blocks.39/Div_7_output_0', '/transformer_blocks.39/Add_9_output_0', '/transformer_blocks.39/Mul_11_output_0', '/transformer_blocks.39/Add_7_output_0', '/transformer_blocks.39/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.39/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.39/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.39/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.39/attn/Mul_1_output_0', '/transformer_blocks.39/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.39/attn/norm_added_k/CustomRMSNorm_output_0', 
'/transformer_blocks.39/attn/Mul_2_output_0', '/transformer_blocks.39/attn/Mul_21_output_0', '/transformer_blocks.39/Div_4_output_0', '/transformer_blocks.39/Mul_12_output_0', '/transformer_blocks.39/attn/Mul_22_output_0', '/transformer_blocks.39/Div_6_output_0', '/transformer_blocks.39/Mul_13_output_0', '/transformer_blocks.39/img_norm2/LayerNormalization_output_0', '/transformer_blocks.39/img_norm2/LayerNormalization_output_0', '/transformer_blocks.39/Mul_14_output_0', '/transformer_blocks.39/Div_11_output_0', '/transformer_blocks.39/Add_13_output_0', '/transformer_blocks.39/Mul_18_output_0', '/transformer_blocks.39/Add_12_output_0', '/transformer_blocks.39/Div_9_output_0', '/transformer_blocks.39/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.39/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.39/Mul_20_output_0', '/transformer_blocks.39/Div_15_output_0', '/transformer_blocks.39/Add_17_output_0', '/transformer_blocks.39/Mul_24_output_0', '/transformer_blocks.39/Add_16_output_0', '/transformer_blocks.39/Div_13_output_0', '/transformer_blocks.40/img_norm1/LayerNormalization_output_0', '/transformer_blocks.40/img_norm1/LayerNormalization_output_0', '/transformer_blocks.39/Mul_26_output_0', '/transformer_blocks.40/Div_5_output_0', '/transformer_blocks.40/Add_8_output_0', '/transformer_blocks.40/Mul_7_output_0', '/transformer_blocks.40/Add_4_output_0', '/transformer_blocks.40/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.40/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.39/Mul_27_output_0', '/transformer_blocks.40/Div_7_output_0', '/transformer_blocks.40/Add_9_output_0', '/transformer_blocks.40/Mul_11_output_0', '/transformer_blocks.40/Add_7_output_0', '/transformer_blocks.40/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.40/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.40/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.40/attn/norm_added_q/CustomRMSNorm_output_0', 
'/transformer_blocks.40/attn/Mul_1_output_0', '/transformer_blocks.40/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.40/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.40/attn/Mul_2_output_0', '/transformer_blocks.40/attn/Mul_21_output_0', '/transformer_blocks.40/Div_4_output_0', '/transformer_blocks.40/Mul_12_output_0', '/transformer_blocks.40/attn/Mul_22_output_0', '/transformer_blocks.40/Div_6_output_0', '/transformer_blocks.40/Mul_13_output_0', '/transformer_blocks.40/img_norm2/LayerNormalization_output_0', '/transformer_blocks.40/img_norm2/LayerNormalization_output_0', '/transformer_blocks.40/Mul_14_output_0', '/transformer_blocks.40/Div_11_output_0', '/transformer_blocks.40/Add_13_output_0', '/transformer_blocks.40/Mul_18_output_0', '/transformer_blocks.40/Add_12_output_0', '/transformer_blocks.40/Div_9_output_0', '/transformer_blocks.40/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.40/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.40/Mul_20_output_0', '/transformer_blocks.40/Div_15_output_0', '/transformer_blocks.40/Add_17_output_0', '/transformer_blocks.40/Mul_24_output_0', '/transformer_blocks.40/Add_16_output_0', '/transformer_blocks.40/Div_13_output_0', '/transformer_blocks.41/img_norm1/LayerNormalization_output_0', '/transformer_blocks.41/img_norm1/LayerNormalization_output_0', '/transformer_blocks.40/Mul_26_output_0', '/transformer_blocks.41/Div_5_output_0', '/transformer_blocks.41/Add_8_output_0', '/transformer_blocks.41/Mul_7_output_0', '/transformer_blocks.41/Add_4_output_0', '/transformer_blocks.41/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.41/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.40/Mul_27_output_0', '/transformer_blocks.41/Div_7_output_0', '/transformer_blocks.41/Add_9_output_0', '/transformer_blocks.41/Mul_11_output_0', '/transformer_blocks.41/Add_7_output_0', '/transformer_blocks.41/attn/norm_q/CustomRMSNorm_output_0', 
'/transformer_blocks.41/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.41/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.41/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.41/attn/Mul_1_output_0', '/transformer_blocks.41/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.41/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.41/attn/Mul_2_output_0', '/transformer_blocks.41/attn/Mul_21_output_0', '/transformer_blocks.41/Div_4_output_0', '/transformer_blocks.41/Mul_12_output_0', '/transformer_blocks.41/attn/Mul_22_output_0', '/transformer_blocks.41/Div_6_output_0', '/transformer_blocks.41/Mul_13_output_0', '/transformer_blocks.41/img_norm2/LayerNormalization_output_0', '/transformer_blocks.41/img_norm2/LayerNormalization_output_0', '/transformer_blocks.41/Mul_14_output_0', '/transformer_blocks.41/Div_11_output_0', '/transformer_blocks.41/Add_13_output_0', '/transformer_blocks.41/Mul_18_output_0', '/transformer_blocks.41/Add_12_output_0', '/transformer_blocks.41/Div_9_output_0', '/transformer_blocks.41/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.41/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.41/Mul_20_output_0', '/transformer_blocks.41/Div_15_output_0', '/transformer_blocks.41/Add_17_output_0', '/transformer_blocks.41/Mul_24_output_0', '/transformer_blocks.41/Add_16_output_0', '/transformer_blocks.41/Div_13_output_0', '/transformer_blocks.42/img_norm1/LayerNormalization_output_0', '/transformer_blocks.42/img_norm1/LayerNormalization_output_0', '/transformer_blocks.41/Mul_26_output_0', '/transformer_blocks.42/Div_5_output_0', '/transformer_blocks.42/Add_8_output_0', '/transformer_blocks.42/Mul_7_output_0', '/transformer_blocks.42/Add_4_output_0', '/transformer_blocks.42/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.42/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.41/Mul_27_output_0', '/transformer_blocks.42/Div_7_output_0', 
'/transformer_blocks.42/Add_9_output_0', '/transformer_blocks.42/Mul_11_output_0', '/transformer_blocks.42/Add_7_output_0', '/transformer_blocks.42/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/Mul_1_output_0', '/transformer_blocks.42/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.42/attn/Mul_2_output_0', '/transformer_blocks.42/attn/Mul_21_output_0', '/transformer_blocks.42/Div_4_output_0', '/transformer_blocks.42/Mul_12_output_0', '/transformer_blocks.42/attn/Mul_22_output_0', '/transformer_blocks.42/Div_6_output_0', '/transformer_blocks.42/Mul_13_output_0', '/transformer_blocks.42/img_norm2/LayerNormalization_output_0', '/transformer_blocks.42/img_norm2/LayerNormalization_output_0', '/transformer_blocks.42/Mul_14_output_0', '/transformer_blocks.42/Div_11_output_0', '/transformer_blocks.42/Add_13_output_0', '/transformer_blocks.42/Mul_18_output_0', '/transformer_blocks.42/Add_12_output_0', '/transformer_blocks.42/Div_9_output_0', '/transformer_blocks.42/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.42/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.42/Mul_20_output_0', '/transformer_blocks.42/Div_15_output_0', '/transformer_blocks.42/Add_17_output_0', '/transformer_blocks.42/Mul_24_output_0', '/transformer_blocks.42/Add_16_output_0', '/transformer_blocks.42/Div_13_output_0', '/transformer_blocks.43/img_norm1/LayerNormalization_output_0', '/transformer_blocks.43/img_norm1/LayerNormalization_output_0', '/transformer_blocks.42/Mul_26_output_0', '/transformer_blocks.43/Div_5_output_0', '/transformer_blocks.43/Add_8_output_0', '/transformer_blocks.43/Mul_7_output_0', '/transformer_blocks.43/Add_4_output_0', 
'/transformer_blocks.43/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.43/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.42/Mul_27_output_0', '/transformer_blocks.43/Div_7_output_0', '/transformer_blocks.43/Add_9_output_0', '/transformer_blocks.43/Mul_11_output_0', '/transformer_blocks.43/Add_7_output_0', '/transformer_blocks.43/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/Mul_1_output_0', '/transformer_blocks.43/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.43/attn/Mul_2_output_0', '/transformer_blocks.43/attn/Mul_21_output_0', '/transformer_blocks.43/Div_4_output_0', '/transformer_blocks.43/Mul_12_output_0', '/transformer_blocks.43/attn/Mul_22_output_0', '/transformer_blocks.43/Div_6_output_0', '/transformer_blocks.43/Mul_13_output_0', '/transformer_blocks.43/img_norm2/LayerNormalization_output_0', '/transformer_blocks.43/img_norm2/LayerNormalization_output_0', '/transformer_blocks.43/Mul_14_output_0', '/transformer_blocks.43/Div_11_output_0', '/transformer_blocks.43/Add_13_output_0', '/transformer_blocks.43/Mul_18_output_0', '/transformer_blocks.43/Add_12_output_0', '/transformer_blocks.43/Div_9_output_0', '/transformer_blocks.43/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.43/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.43/Mul_20_output_0', '/transformer_blocks.43/Div_15_output_0', '/transformer_blocks.43/Add_17_output_0', '/transformer_blocks.43/Mul_24_output_0', '/transformer_blocks.43/Add_16_output_0', '/transformer_blocks.43/Div_13_output_0', '/transformer_blocks.44/img_norm1/LayerNormalization_output_0', '/transformer_blocks.44/img_norm1/LayerNormalization_output_0', 
'/transformer_blocks.43/Mul_26_output_0', '/transformer_blocks.44/Div_5_output_0', '/transformer_blocks.44/Add_8_output_0', '/transformer_blocks.44/Mul_7_output_0', '/transformer_blocks.44/Add_4_output_0', '/transformer_blocks.44/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.44/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.43/Mul_27_output_0', '/transformer_blocks.44/Div_7_output_0', '/transformer_blocks.44/Add_9_output_0', '/transformer_blocks.44/Mul_11_output_0', '/transformer_blocks.44/Add_7_output_0', '/transformer_blocks.44/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/Mul_1_output_0', '/transformer_blocks.44/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.44/attn/Mul_2_output_0', '/transformer_blocks.44/attn/Mul_21_output_0', '/transformer_blocks.44/Div_4_output_0', '/transformer_blocks.44/Mul_12_output_0', '/transformer_blocks.44/attn/Mul_22_output_0', '/transformer_blocks.44/Div_6_output_0', '/transformer_blocks.44/Mul_13_output_0', '/transformer_blocks.44/img_norm2/LayerNormalization_output_0', '/transformer_blocks.44/img_norm2/LayerNormalization_output_0', '/transformer_blocks.44/Mul_14_output_0', '/transformer_blocks.44/Div_11_output_0', '/transformer_blocks.44/Add_13_output_0', '/transformer_blocks.44/Mul_18_output_0', '/transformer_blocks.44/Add_12_output_0', '/transformer_blocks.44/Div_9_output_0', '/transformer_blocks.44/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.44/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.44/Mul_20_output_0', '/transformer_blocks.44/Div_15_output_0', '/transformer_blocks.44/Add_17_output_0', '/transformer_blocks.44/Mul_24_output_0', '/transformer_blocks.44/Add_16_output_0', 
'/transformer_blocks.44/Div_13_output_0', '/transformer_blocks.45/img_norm1/LayerNormalization_output_0', '/transformer_blocks.45/img_norm1/LayerNormalization_output_0', '/transformer_blocks.44/Mul_26_output_0', '/transformer_blocks.45/Div_5_output_0', '/transformer_blocks.45/Add_8_output_0', '/transformer_blocks.45/Mul_7_output_0', '/transformer_blocks.45/Add_4_output_0', '/transformer_blocks.45/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.45/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.44/Mul_27_output_0', '/transformer_blocks.45/Div_7_output_0', '/transformer_blocks.45/Add_9_output_0', '/transformer_blocks.45/Mul_11_output_0', '/transformer_blocks.45/Add_7_output_0', '/transformer_blocks.45/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/Mul_1_output_0', '/transformer_blocks.45/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.45/attn/Mul_2_output_0', '/transformer_blocks.45/attn/Mul_21_output_0', '/transformer_blocks.45/Div_4_output_0', '/transformer_blocks.45/Mul_12_output_0', '/transformer_blocks.45/attn/Mul_22_output_0', '/transformer_blocks.45/Div_6_output_0', '/transformer_blocks.45/Mul_13_output_0', '/transformer_blocks.45/img_norm2/LayerNormalization_output_0', '/transformer_blocks.45/img_norm2/LayerNormalization_output_0', '/transformer_blocks.45/Mul_14_output_0', '/transformer_blocks.45/Div_11_output_0', '/transformer_blocks.45/Add_13_output_0', '/transformer_blocks.45/Mul_18_output_0', '/transformer_blocks.45/Add_12_output_0', '/transformer_blocks.45/Div_9_output_0', '/transformer_blocks.45/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.45/txt_norm2/LayerNormalization_output_0', 
'/transformer_blocks.45/Mul_20_output_0', '/transformer_blocks.45/Div_15_output_0', '/transformer_blocks.45/Add_17_output_0', '/transformer_blocks.45/Mul_24_output_0', '/transformer_blocks.45/Add_16_output_0', '/transformer_blocks.45/Div_13_output_0', '/transformer_blocks.46/img_norm1/LayerNormalization_output_0', '/transformer_blocks.46/img_norm1/LayerNormalization_output_0', '/transformer_blocks.45/Mul_26_output_0', '/transformer_blocks.46/Div_5_output_0', '/transformer_blocks.46/Add_8_output_0', '/transformer_blocks.46/Mul_7_output_0', '/transformer_blocks.46/Add_4_output_0', '/transformer_blocks.46/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.46/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.45/Mul_27_output_0', '/transformer_blocks.46/Div_7_output_0', '/transformer_blocks.46/Add_9_output_0', '/transformer_blocks.46/Mul_11_output_0', '/transformer_blocks.46/Add_7_output_0', '/transformer_blocks.46/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/Mul_1_output_0', '/transformer_blocks.46/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.46/attn/Mul_2_output_0', '/transformer_blocks.46/attn/Mul_21_output_0', '/transformer_blocks.46/Div_4_output_0', '/transformer_blocks.46/Mul_12_output_0', '/transformer_blocks.46/attn/Mul_22_output_0', '/transformer_blocks.46/Div_6_output_0', '/transformer_blocks.46/Mul_13_output_0', '/transformer_blocks.46/img_norm2/LayerNormalization_output_0', '/transformer_blocks.46/img_norm2/LayerNormalization_output_0', '/transformer_blocks.46/Mul_14_output_0', '/transformer_blocks.46/Div_11_output_0', '/transformer_blocks.46/Add_13_output_0', '/transformer_blocks.46/Mul_18_output_0', 
'/transformer_blocks.46/Add_12_output_0', '/transformer_blocks.46/Div_9_output_0', '/transformer_blocks.46/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.46/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.46/Mul_20_output_0', '/transformer_blocks.46/Div_15_output_0', '/transformer_blocks.46/Add_17_output_0', '/transformer_blocks.46/Mul_24_output_0', '/transformer_blocks.46/Add_16_output_0', '/transformer_blocks.46/Div_13_output_0', '/transformer_blocks.47/img_norm1/LayerNormalization_output_0', '/transformer_blocks.47/img_norm1/LayerNormalization_output_0', '/transformer_blocks.46/Mul_26_output_0', '/transformer_blocks.47/Div_5_output_0', '/transformer_blocks.47/Add_8_output_0', '/transformer_blocks.47/Mul_7_output_0', '/transformer_blocks.47/Add_4_output_0', '/transformer_blocks.47/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.47/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.46/Mul_27_output_0', '/transformer_blocks.47/Div_7_output_0', '/transformer_blocks.47/Add_9_output_0', '/transformer_blocks.47/Mul_11_output_0', '/transformer_blocks.47/Add_7_output_0', '/transformer_blocks.47/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/Mul_1_output_0', '/transformer_blocks.47/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.47/attn/Mul_2_output_0', '/transformer_blocks.47/attn/Mul_21_output_0', '/transformer_blocks.47/Div_4_output_0', '/transformer_blocks.47/Mul_12_output_0', '/transformer_blocks.47/attn/Mul_22_output_0', '/transformer_blocks.47/Div_6_output_0', '/transformer_blocks.47/Mul_13_output_0', '/transformer_blocks.47/img_norm2/LayerNormalization_output_0', 
'/transformer_blocks.47/img_norm2/LayerNormalization_output_0', '/transformer_blocks.47/Mul_14_output_0', '/transformer_blocks.47/Div_11_output_0', '/transformer_blocks.47/Add_13_output_0', '/transformer_blocks.47/Mul_18_output_0', '/transformer_blocks.47/Add_12_output_0', '/transformer_blocks.47/Div_9_output_0', '/transformer_blocks.47/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.47/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.47/Mul_20_output_0', '/transformer_blocks.47/Div_15_output_0', '/transformer_blocks.47/Add_17_output_0', '/transformer_blocks.47/Mul_24_output_0', '/transformer_blocks.47/Add_16_output_0', '/transformer_blocks.47/Div_13_output_0', '/transformer_blocks.48/img_norm1/LayerNormalization_output_0', '/transformer_blocks.48/img_norm1/LayerNormalization_output_0', '/transformer_blocks.47/Mul_26_output_0', '/transformer_blocks.48/Div_5_output_0', '/transformer_blocks.48/Add_8_output_0', '/transformer_blocks.48/Mul_7_output_0', '/transformer_blocks.48/Add_4_output_0', '/transformer_blocks.48/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.48/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.47/Mul_27_output_0', '/transformer_blocks.48/Div_7_output_0', '/transformer_blocks.48/Add_9_output_0', '/transformer_blocks.48/Mul_11_output_0', '/transformer_blocks.48/Add_7_output_0', '/transformer_blocks.48/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/Mul_1_output_0', '/transformer_blocks.48/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.48/attn/Mul_2_output_0', '/transformer_blocks.48/attn/Mul_21_output_0', '/transformer_blocks.48/Div_4_output_0', '/transformer_blocks.48/Mul_12_output_0', 
'/transformer_blocks.48/attn/Mul_22_output_0', '/transformer_blocks.48/Div_6_output_0', '/transformer_blocks.48/Mul_13_output_0', '/transformer_blocks.48/img_norm2/LayerNormalization_output_0', '/transformer_blocks.48/img_norm2/LayerNormalization_output_0', '/transformer_blocks.48/Mul_14_output_0', '/transformer_blocks.48/Div_11_output_0', '/transformer_blocks.48/Add_13_output_0', '/transformer_blocks.48/Mul_18_output_0', '/transformer_blocks.48/Add_12_output_0', '/transformer_blocks.48/Div_9_output_0', '/transformer_blocks.48/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.48/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.48/Mul_20_output_0', '/transformer_blocks.48/Div_15_output_0', '/transformer_blocks.48/Add_17_output_0', '/transformer_blocks.48/Mul_24_output_0', '/transformer_blocks.48/Add_16_output_0', '/transformer_blocks.48/Div_13_output_0', '/transformer_blocks.49/img_norm1/LayerNormalization_output_0', '/transformer_blocks.49/img_norm1/LayerNormalization_output_0', '/transformer_blocks.48/Mul_26_output_0', '/transformer_blocks.49/Div_5_output_0', '/transformer_blocks.49/Add_8_output_0', '/transformer_blocks.49/Mul_7_output_0', '/transformer_blocks.49/Add_4_output_0', '/transformer_blocks.49/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.49/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.48/Mul_27_output_0', '/transformer_blocks.49/Div_7_output_0', '/transformer_blocks.49/Add_9_output_0', '/transformer_blocks.49/Mul_11_output_0', '/transformer_blocks.49/Add_7_output_0', '/transformer_blocks.49/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.49/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.49/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.49/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.49/attn/Mul_1_output_0', '/transformer_blocks.49/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.49/attn/norm_added_k/CustomRMSNorm_output_0', 
'/transformer_blocks.49/attn/Mul_2_output_0', '/transformer_blocks.49/attn/Mul_21_output_0', '/transformer_blocks.49/Div_4_output_0', '/transformer_blocks.49/Mul_12_output_0', '/transformer_blocks.49/attn/Mul_22_output_0', '/transformer_blocks.49/Div_6_output_0', '/transformer_blocks.49/Mul_13_output_0', '/transformer_blocks.49/img_norm2/LayerNormalization_output_0', '/transformer_blocks.49/img_norm2/LayerNormalization_output_0', '/transformer_blocks.49/Mul_14_output_0', '/transformer_blocks.49/Div_11_output_0', '/transformer_blocks.49/Add_13_output_0', '/transformer_blocks.49/Mul_18_output_0', '/transformer_blocks.49/Add_12_output_0', '/transformer_blocks.49/Div_9_output_0', '/transformer_blocks.49/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.49/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.49/Mul_20_output_0', '/transformer_blocks.49/Div_15_output_0', '/transformer_blocks.49/Add_17_output_0', '/transformer_blocks.49/Mul_24_output_0', '/transformer_blocks.49/Add_16_output_0', '/transformer_blocks.49/Div_13_output_0', '/transformer_blocks.50/img_norm1/LayerNormalization_output_0', '/transformer_blocks.50/img_norm1/LayerNormalization_output_0', '/transformer_blocks.49/Mul_26_output_0', '/transformer_blocks.50/Div_5_output_0', '/transformer_blocks.50/Add_8_output_0', '/transformer_blocks.50/Mul_7_output_0', '/transformer_blocks.50/Add_4_output_0', '/transformer_blocks.50/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.50/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.49/Mul_27_output_0', '/transformer_blocks.50/Div_7_output_0', '/transformer_blocks.50/Add_9_output_0', '/transformer_blocks.50/Mul_11_output_0', '/transformer_blocks.50/Add_7_output_0', '/transformer_blocks.50/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.50/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.50/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.50/attn/norm_added_q/CustomRMSNorm_output_0', 
'/transformer_blocks.50/attn/Mul_1_output_0', '/transformer_blocks.50/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.50/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.50/attn/Mul_2_output_0', '/transformer_blocks.50/attn/Mul_21_output_0', '/transformer_blocks.50/Div_4_output_0', '/transformer_blocks.50/Mul_12_output_0', '/transformer_blocks.50/attn/Mul_22_output_0', '/transformer_blocks.50/Div_6_output_0', '/transformer_blocks.50/Mul_13_output_0', '/transformer_blocks.50/img_norm2/LayerNormalization_output_0', '/transformer_blocks.50/img_norm2/LayerNormalization_output_0', '/transformer_blocks.50/Mul_14_output_0', '/transformer_blocks.50/Div_11_output_0', '/transformer_blocks.50/Add_13_output_0', '/transformer_blocks.50/Mul_18_output_0', '/transformer_blocks.50/Add_12_output_0', '/transformer_blocks.50/Div_9_output_0', '/transformer_blocks.50/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.50/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.50/Mul_20_output_0', '/transformer_blocks.50/Div_15_output_0', '/transformer_blocks.50/Add_17_output_0', '/transformer_blocks.50/Mul_24_output_0', '/transformer_blocks.50/Add_16_output_0', '/transformer_blocks.50/Div_13_output_0', '/transformer_blocks.51/img_norm1/LayerNormalization_output_0', '/transformer_blocks.51/img_norm1/LayerNormalization_output_0', '/transformer_blocks.50/Mul_26_output_0', '/transformer_blocks.51/Div_5_output_0', '/transformer_blocks.51/Add_8_output_0', '/transformer_blocks.51/Mul_7_output_0', '/transformer_blocks.51/Add_4_output_0', '/transformer_blocks.51/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.51/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.50/Mul_27_output_0', '/transformer_blocks.51/Div_7_output_0', '/transformer_blocks.51/Add_9_output_0', '/transformer_blocks.51/Mul_11_output_0', '/transformer_blocks.51/Add_7_output_0', '/transformer_blocks.51/attn/norm_q/CustomRMSNorm_output_0', 
'/transformer_blocks.51/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.51/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.51/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.51/attn/Mul_1_output_0', '/transformer_blocks.51/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.51/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.51/attn/Mul_2_output_0', '/transformer_blocks.51/attn/Mul_21_output_0', '/transformer_blocks.51/Div_4_output_0', '/transformer_blocks.51/Mul_12_output_0', '/transformer_blocks.51/attn/Mul_22_output_0', '/transformer_blocks.51/Div_6_output_0', '/transformer_blocks.51/Mul_13_output_0', '/transformer_blocks.51/img_norm2/LayerNormalization_output_0', '/transformer_blocks.51/img_norm2/LayerNormalization_output_0', '/transformer_blocks.51/Mul_14_output_0', '/transformer_blocks.51/Div_11_output_0', '/transformer_blocks.51/Add_13_output_0', '/transformer_blocks.51/Mul_18_output_0', '/transformer_blocks.51/Add_12_output_0', '/transformer_blocks.51/Div_9_output_0', '/transformer_blocks.51/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.51/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.51/Mul_20_output_0', '/transformer_blocks.51/Div_15_output_0', '/transformer_blocks.51/Add_17_output_0', '/transformer_blocks.51/Mul_24_output_0', '/transformer_blocks.51/Add_16_output_0', '/transformer_blocks.51/Div_13_output_0', '/transformer_blocks.52/img_norm1/LayerNormalization_output_0', '/transformer_blocks.52/img_norm1/LayerNormalization_output_0', '/transformer_blocks.51/Mul_26_output_0', '/transformer_blocks.52/Div_5_output_0', '/transformer_blocks.52/Add_8_output_0', '/transformer_blocks.52/Mul_7_output_0', '/transformer_blocks.52/Add_4_output_0', '/transformer_blocks.52/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.52/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.51/Mul_27_output_0', '/transformer_blocks.52/Div_7_output_0', 
'/transformer_blocks.52/Add_9_output_0', '/transformer_blocks.52/Mul_11_output_0', '/transformer_blocks.52/Add_7_output_0', '/transformer_blocks.52/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/Mul_1_output_0', '/transformer_blocks.52/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.52/attn/Mul_2_output_0', '/transformer_blocks.52/attn/Mul_21_output_0', '/transformer_blocks.52/Div_4_output_0', '/transformer_blocks.52/Mul_12_output_0', '/transformer_blocks.52/attn/Mul_22_output_0', '/transformer_blocks.52/Div_6_output_0', '/transformer_blocks.52/Mul_13_output_0', '/transformer_blocks.52/img_norm2/LayerNormalization_output_0', '/transformer_blocks.52/img_norm2/LayerNormalization_output_0', '/transformer_blocks.52/Mul_14_output_0', '/transformer_blocks.52/Div_11_output_0', '/transformer_blocks.52/Add_13_output_0', '/transformer_blocks.52/Mul_18_output_0', '/transformer_blocks.52/Add_12_output_0', '/transformer_blocks.52/Div_9_output_0', '/transformer_blocks.52/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.52/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.52/Mul_20_output_0', '/transformer_blocks.52/Div_15_output_0', '/transformer_blocks.52/Add_17_output_0', '/transformer_blocks.52/Mul_24_output_0', '/transformer_blocks.52/Add_16_output_0', '/transformer_blocks.52/Div_13_output_0', '/transformer_blocks.53/img_norm1/LayerNormalization_output_0', '/transformer_blocks.53/img_norm1/LayerNormalization_output_0', '/transformer_blocks.52/Mul_26_output_0', '/transformer_blocks.53/Div_5_output_0', '/transformer_blocks.53/Add_8_output_0', '/transformer_blocks.53/Mul_7_output_0', '/transformer_blocks.53/Add_4_output_0', 
'/transformer_blocks.53/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.53/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.52/Mul_27_output_0', '/transformer_blocks.53/Div_7_output_0', '/transformer_blocks.53/Add_9_output_0', '/transformer_blocks.53/Mul_11_output_0', '/transformer_blocks.53/Add_7_output_0', '/transformer_blocks.53/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/Mul_1_output_0', '/transformer_blocks.53/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.53/attn/Mul_2_output_0', '/transformer_blocks.53/attn/Mul_21_output_0', '/transformer_blocks.53/Div_4_output_0', '/transformer_blocks.53/Mul_12_output_0', '/transformer_blocks.53/attn/Mul_22_output_0', '/transformer_blocks.53/Div_6_output_0', '/transformer_blocks.53/Mul_13_output_0', '/transformer_blocks.53/img_norm2/LayerNormalization_output_0', '/transformer_blocks.53/img_norm2/LayerNormalization_output_0', '/transformer_blocks.53/Mul_14_output_0', '/transformer_blocks.53/Div_11_output_0', '/transformer_blocks.53/Add_13_output_0', '/transformer_blocks.53/Mul_18_output_0', '/transformer_blocks.53/Add_12_output_0', '/transformer_blocks.53/Div_9_output_0', '/transformer_blocks.53/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.53/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.53/Mul_20_output_0', '/transformer_blocks.53/Div_15_output_0', '/transformer_blocks.53/Add_17_output_0', '/transformer_blocks.53/Mul_24_output_0', '/transformer_blocks.53/Add_16_output_0', '/transformer_blocks.53/Div_13_output_0', '/transformer_blocks.54/img_norm1/LayerNormalization_output_0', '/transformer_blocks.54/img_norm1/LayerNormalization_output_0', 
'/transformer_blocks.53/Mul_26_output_0', '/transformer_blocks.54/Div_5_output_0', '/transformer_blocks.54/Add_8_output_0', '/transformer_blocks.54/Mul_7_output_0', '/transformer_blocks.54/Add_4_output_0', '/transformer_blocks.54/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.54/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.53/Mul_27_output_0', '/transformer_blocks.54/Div_7_output_0', '/transformer_blocks.54/Add_9_output_0', '/transformer_blocks.54/Mul_11_output_0', '/transformer_blocks.54/Add_7_output_0', '/transformer_blocks.54/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/Mul_1_output_0', '/transformer_blocks.54/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.54/attn/Mul_2_output_0', '/transformer_blocks.54/attn/Mul_21_output_0', '/transformer_blocks.54/Div_4_output_0', '/transformer_blocks.54/Mul_12_output_0', '/transformer_blocks.54/attn/Mul_22_output_0', '/transformer_blocks.54/Div_6_output_0', '/transformer_blocks.54/Mul_13_output_0', '/transformer_blocks.54/img_norm2/LayerNormalization_output_0', '/transformer_blocks.54/img_norm2/LayerNormalization_output_0', '/transformer_blocks.54/Mul_14_output_0', '/transformer_blocks.54/Div_11_output_0', '/transformer_blocks.54/Add_13_output_0', '/transformer_blocks.54/Mul_18_output_0', '/transformer_blocks.54/Add_12_output_0', '/transformer_blocks.54/Div_9_output_0', '/transformer_blocks.54/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.54/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.54/Mul_20_output_0', '/transformer_blocks.54/Div_15_output_0', '/transformer_blocks.54/Add_17_output_0', '/transformer_blocks.54/Mul_24_output_0', '/transformer_blocks.54/Add_16_output_0', 
'/transformer_blocks.54/Div_13_output_0', '/transformer_blocks.55/img_norm1/LayerNormalization_output_0', '/transformer_blocks.55/img_norm1/LayerNormalization_output_0', '/transformer_blocks.54/Mul_26_output_0', '/transformer_blocks.55/Div_5_output_0', '/transformer_blocks.55/Add_8_output_0', '/transformer_blocks.55/Mul_7_output_0', '/transformer_blocks.55/Add_4_output_0', '/transformer_blocks.55/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.55/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.54/Mul_27_output_0', '/transformer_blocks.55/Div_7_output_0', '/transformer_blocks.55/Add_9_output_0', '/transformer_blocks.55/Mul_11_output_0', '/transformer_blocks.55/Add_7_output_0', '/transformer_blocks.55/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/Mul_1_output_0', '/transformer_blocks.55/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.55/attn/Mul_2_output_0', '/transformer_blocks.55/attn/Mul_21_output_0', '/transformer_blocks.55/Div_4_output_0', '/transformer_blocks.55/Mul_12_output_0', '/transformer_blocks.55/attn/Mul_22_output_0', '/transformer_blocks.55/Div_6_output_0', '/transformer_blocks.55/Mul_13_output_0', '/transformer_blocks.55/img_norm2/LayerNormalization_output_0', '/transformer_blocks.55/img_norm2/LayerNormalization_output_0', '/transformer_blocks.55/Mul_14_output_0', '/transformer_blocks.55/Div_11_output_0', '/transformer_blocks.55/Add_13_output_0', '/transformer_blocks.55/Mul_18_output_0', '/transformer_blocks.55/Add_12_output_0', '/transformer_blocks.55/Div_9_output_0', '/transformer_blocks.55/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.55/txt_norm2/LayerNormalization_output_0', 
'/transformer_blocks.55/Mul_20_output_0', '/transformer_blocks.55/Div_15_output_0', '/transformer_blocks.55/Add_17_output_0', '/transformer_blocks.55/Mul_24_output_0', '/transformer_blocks.55/Add_16_output_0', '/transformer_blocks.55/Div_13_output_0', '/transformer_blocks.56/img_norm1/LayerNormalization_output_0', '/transformer_blocks.56/img_norm1/LayerNormalization_output_0', '/transformer_blocks.55/Mul_26_output_0', '/transformer_blocks.56/Div_5_output_0', '/transformer_blocks.56/Add_8_output_0', '/transformer_blocks.56/Mul_7_output_0', '/transformer_blocks.56/Add_4_output_0', '/transformer_blocks.56/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.56/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.55/Mul_27_output_0', '/transformer_blocks.56/Div_7_output_0', '/transformer_blocks.56/Add_9_output_0', '/transformer_blocks.56/Mul_11_output_0', '/transformer_blocks.56/Add_7_output_0', '/transformer_blocks.56/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/Mul_1_output_0', '/transformer_blocks.56/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.56/attn/Mul_2_output_0', '/transformer_blocks.56/attn/Mul_21_output_0', '/transformer_blocks.56/Div_4_output_0', '/transformer_blocks.56/Mul_12_output_0', '/transformer_blocks.56/attn/Mul_22_output_0', '/transformer_blocks.56/Div_6_output_0', '/transformer_blocks.56/Mul_13_output_0', '/transformer_blocks.56/img_norm2/LayerNormalization_output_0', '/transformer_blocks.56/img_norm2/LayerNormalization_output_0', '/transformer_blocks.56/Mul_14_output_0', '/transformer_blocks.56/Div_11_output_0', '/transformer_blocks.56/Add_13_output_0', '/transformer_blocks.56/Mul_18_output_0', 
'/transformer_blocks.56/Add_12_output_0', '/transformer_blocks.56/Div_9_output_0', '/transformer_blocks.56/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.56/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.56/Mul_20_output_0', '/transformer_blocks.56/Div_15_output_0', '/transformer_blocks.56/Add_17_output_0', '/transformer_blocks.56/Mul_24_output_0', '/transformer_blocks.56/Add_16_output_0', '/transformer_blocks.56/Div_13_output_0', '/transformer_blocks.57/img_norm1/LayerNormalization_output_0', '/transformer_blocks.57/img_norm1/LayerNormalization_output_0', '/transformer_blocks.56/Mul_26_output_0', '/transformer_blocks.57/Div_5_output_0', '/transformer_blocks.57/Add_8_output_0', '/transformer_blocks.57/Mul_7_output_0', '/transformer_blocks.57/Add_4_output_0', '/transformer_blocks.57/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.57/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.56/Mul_27_output_0', '/transformer_blocks.57/Div_7_output_0', '/transformer_blocks.57/Add_9_output_0', '/transformer_blocks.57/Mul_11_output_0', '/transformer_blocks.57/Add_7_output_0', '/transformer_blocks.57/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/Mul_1_output_0', '/transformer_blocks.57/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.57/attn/Mul_2_output_0', '/transformer_blocks.57/attn/Mul_21_output_0', '/transformer_blocks.57/Div_4_output_0', '/transformer_blocks.57/Mul_12_output_0', '/transformer_blocks.57/attn/Mul_22_output_0', '/transformer_blocks.57/Div_6_output_0', '/transformer_blocks.57/Mul_13_output_0', '/transformer_blocks.57/img_norm2/LayerNormalization_output_0', 
'/transformer_blocks.57/img_norm2/LayerNormalization_output_0', '/transformer_blocks.57/Mul_14_output_0', '/transformer_blocks.57/Div_11_output_0', '/transformer_blocks.57/Add_13_output_0', '/transformer_blocks.57/Mul_18_output_0', '/transformer_blocks.57/Add_12_output_0', '/transformer_blocks.57/Div_9_output_0', '/transformer_blocks.57/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.57/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.57/Mul_20_output_0', '/transformer_blocks.57/Div_15_output_0', '/transformer_blocks.57/Add_17_output_0', '/transformer_blocks.57/Mul_24_output_0', '/transformer_blocks.57/Add_16_output_0', '/transformer_blocks.57/Div_13_output_0', '/transformer_blocks.58/img_norm1/LayerNormalization_output_0', '/transformer_blocks.58/img_norm1/LayerNormalization_output_0', '/transformer_blocks.57/Mul_26_output_0', '/transformer_blocks.58/Div_5_output_0', '/transformer_blocks.58/Add_8_output_0', '/transformer_blocks.58/Mul_7_output_0', '/transformer_blocks.58/Add_4_output_0', '/transformer_blocks.58/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.58/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.57/Mul_27_output_0', '/transformer_blocks.58/Div_7_output_0', '/transformer_blocks.58/Add_9_output_0', '/transformer_blocks.58/Mul_11_output_0', '/transformer_blocks.58/Add_7_output_0', '/transformer_blocks.58/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/Mul_1_output_0', '/transformer_blocks.58/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.58/attn/Mul_2_output_0', '/transformer_blocks.58/attn/Mul_21_output_0', '/transformer_blocks.58/Div_4_output_0', '/transformer_blocks.58/Mul_12_output_0', 
'/transformer_blocks.58/attn/Mul_22_output_0', '/transformer_blocks.58/Div_6_output_0', '/transformer_blocks.58/Mul_13_output_0', '/transformer_blocks.58/img_norm2/LayerNormalization_output_0', '/transformer_blocks.58/img_norm2/LayerNormalization_output_0', '/transformer_blocks.58/Mul_14_output_0', '/transformer_blocks.58/Div_11_output_0', '/transformer_blocks.58/Add_13_output_0', '/transformer_blocks.58/Mul_18_output_0', '/transformer_blocks.58/Add_12_output_0', '/transformer_blocks.58/Div_9_output_0', '/transformer_blocks.58/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.58/txt_norm2/LayerNormalization_output_0', '/transformer_blocks.58/Mul_20_output_0', '/transformer_blocks.58/Div_15_output_0', '/transformer_blocks.58/Add_17_output_0', '/transformer_blocks.58/Mul_24_output_0', '/transformer_blocks.58/Add_16_output_0', '/transformer_blocks.58/Div_13_output_0', '/transformer_blocks.59/img_norm1/LayerNormalization_output_0', '/transformer_blocks.59/img_norm1/LayerNormalization_output_0', '/transformer_blocks.58/Mul_26_output_0', '/transformer_blocks.59/Div_5_output_0', '/transformer_blocks.59/Add_8_output_0', '/transformer_blocks.59/Mul_6_output_0', '/transformer_blocks.59/Add_4_output_0', '/transformer_blocks.59/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.59/txt_norm1/LayerNormalization_output_0', '/transformer_blocks.58/Mul_27_output_0', '/transformer_blocks.59/Mul_9_output_0', '/transformer_blocks.59/Add_7_output_0', '/transformer_blocks.59/attn/norm_q/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/norm_k/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/norm_added_q/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/Mul_1_output_0', '/transformer_blocks.59/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/norm_added_k/CustomRMSNorm_output_0', '/transformer_blocks.59/attn/Mul_2_output_0', 
'/transformer_blocks.59/attn/Mul_21_output_0', '/transformer_blocks.59/Div_4_output_0', '/transformer_blocks.59/Mul_10_output_0', '/transformer_blocks.59/img_norm2/LayerNormalization_output_0', '/transformer_blocks.59/img_norm2/LayerNormalization_output_0', '/transformer_blocks.59/Mul_11_output_0', '/transformer_blocks.59/Div_9_output_0', '/transformer_blocks.59/Add_12_output_0', '/transformer_blocks.59/Mul_15_output_0', '/transformer_blocks.59/Add_11_output_0', '/transformer_blocks.59/Div_7_output_0', '/norm_out/norm/LayerNormalization_output_0', '/Div_output_0', '/transformer_blocks.59/Mul_17_output_0', '/norm_out/Mul_2_output_0', '/norm_out/Add_2_output_0'] \ No newline at end of file diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 9fa4d5a328..117691e5d1 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -265,20 +265,27 @@ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tu - output_names (List[str]): Names of model outputs """ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - - # VAE decoder takes latent representation as input - example_inputs = { - "latent_sample": torch.randn(bs, 16, latent_height, latent_width), - "return_dict": False, - } - + if self.model.__class__.__name__ == "AutoencoderKLQwenImage": + # for models with latent frame + example_inputs = { + "latent_sample": torch.randn(bs, 16, 1, latent_height, latent_width), + "return_dict": False, + } + dynamic_axes = { + "latent_sample": {0: "batch_size", 3: "latent_height", 4: "latent_width"}, + } + else: + # VAE decoder takes latent representation as input + example_inputs = { + "latent_sample": torch.randn(bs, 16, latent_height, latent_width), + "return_dict": False, + } + # All dimensions except channels can be dynamic + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"}, + 
class QEffQwenImageTransformer2DModel(QEFFBaseModel):
    """
    QEfficient wrapper around a QwenImage Transformer2D model.

    Adds the ONNX-export and compile entry points on top of ``QEFFBaseModel``
    and declares which PyTorch / ONNX transforms are applied to the wrapped
    module.
    """

    _pytorch_transforms = [AttentionTransform, CustomOpsTransform]
    _onnx_transforms = [SplitTensorsTransform]  # No FP16 clip, to preserve scale factors changes in modeling

    def __init__(self, model: nn.Module):
        """Wrap ``model``; all set-up is delegated to ``QEFFBaseModel``."""
        super().__init__(model)

    def get_onnx_params(self):
        """
        Assemble example inputs, dynamic axes and output names for ONNX export.

        Returns:
            Tuple[Dict, Dict, List[str]]:
                - example model inputs (insertion order is kept as-is)
                - ONNX dynamic axis configuration
                - ONNX output names
        """
        config = self.model.config
        batch = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE

        # Example sizes come from project constants rather than from the model.
        latent_tokens = constants.QWEN_IMAGE_CL
        text_tokens = constants.QWEN_IMAGE_SL
        rotary_width = sum(config.axes_dims_rope)

        # Random tensors are created in the same order as the dict keys so the
        # RNG stream is consumed deterministically.
        example_inputs = {}
        example_inputs["hidden_states"] = torch.randn(batch, latent_tokens, config.in_channels, dtype=torch.float32)
        example_inputs["encoder_hidden_states"] = torch.randn(
            batch, text_tokens, config.joint_attention_dim, dtype=torch.float32
        )
        example_inputs["encoder_hidden_states_mask"] = torch.ones(batch, text_tokens, dtype=torch.int64)
        example_inputs["txt_seq_lens"] = torch.tensor([text_tokens], dtype=torch.int64)
        example_inputs["img_rotary_emb"] = torch.randn(latent_tokens, rotary_width, dtype=torch.float32)
        example_inputs["txt_rotary_emb"] = torch.randn(text_tokens, rotary_width, dtype=torch.float32)
        example_inputs["timestep"] = torch.tensor([1.0], dtype=torch.float32)

        batch_and_cl = {0: "batch_size", 1: "cl"}
        dynamic_axes = {
            "hidden_states": batch_and_cl,
            "encoder_hidden_states": {0: "batch_size", 1: "seq_length"},
            "encoder_hidden_states_mask": {0: "batch_size", 1: "seq_length"},
            "img_rotary_emb": {0: "cl"},
            "txt_rotary_emb": {0: "seq_length"},
        }

        return example_inputs, dynamic_axes, ["output"]

    def export(
        self,
        inputs: Dict,
        output_names: List[str],
        dynamic_axes: Dict,
        export_dir: str = None,
        use_onnx_subfunctions: bool = False,
    ) -> str:
        """
        Export the wrapped transformer to ONNX via ``self._export``.

        Args:
            inputs (`Dict`):
                Example input tensors used for tracing.
            output_names (`List[str]`):
                Ordered ONNX output tensor names.
            dynamic_axes (`Dict`):
                Dynamic axis configuration for the ONNX graph.
            export_dir (`str`, *optional*):
                Directory to write the ONNX model into.
            use_onnx_subfunctions (`bool`, *optional*, defaults to `False`):
                Forwarded to ``_export`` unchanged.

        Returns:
            `str`: Whatever ``self._export`` returns (the ONNX model path).
        """
        export_kwargs = {
            "example_inputs": inputs,
            "output_names": output_names,
            "dynamic_axes": dynamic_axes,
            "export_dir": export_dir,
            # PyTorch weights are always offloaded for this module.
            "offload_pt_weights": True,
            "use_onnx_subfunctions": use_onnx_subfunctions,
        }
        return self._export(**export_kwargs)

    def compile(self, specializations: List[Dict], **compiler_options) -> None:
        """
        Compile the exported ONNX model by delegating to ``self._compile``.

        Args:
            specializations (List[Dict]): Model specialization configurations.
            **compiler_options: Extra compiler options passed through untouched.
        """
        self._compile(specializations=specializations, **compiler_options)

    @property
    def get_model_config(self) -> Dict:
        """
        Expose the wrapped model's configuration.

        Returns:
            Dict: The ``__dict__`` of ``self.model.config`` (not a copy).
        """
        return vars(self.model.config)
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
"""Runtime MagCache helpers for Qwen-Image pipelines.

This module implements a pipeline-level (graph-agnostic) MagCache controller.
It does not modify ONNX/QPC graph signatures.
"""

from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch

# Qwen-Image mag ratios from MagCache4QwenImage.
# Values come in adjacent pairs; `_prepare_ratios` reads even positions as the
# "cond" stream and odd positions as the "uncond" stream.
DEFAULT_QWEN_IMAGE_MAG_RATIOS = [
    1.64062, 1.64062,
    1.45312, 1.45312,
    1.1875, 1.1875,
    1.17188, 1.17188,
    1.05469, 1.05469,
    1.21094, 1.21094,
    1.10938, 1.10938,
    1.11719, 1.11719,
    1.13281, 1.13281,
    1.11719, 1.11719,
    1.07812, 1.07812,
    1.07031, 1.07031,
    1.08594, 1.08594,
    1.08594, 1.08594,
    1.07812, 1.07812,
    1.03906, 1.03906,
    1.04688, 1.04688,
    1.07812, 1.07812,
    1.07031, 1.07031,
    1.03125, 1.03125,
    1.07812, 1.07812,
    1.04688, 1.04688,
    1.04688, 1.04688,
    1.04688, 1.04688,
    1.03906, 1.03906,
    1.01562, 1.01562,
    1.03125, 1.03125,
    1.02344, 1.02344,
    1.02344, 1.02344,
    1.02344, 1.02344,
    1.03906, 1.03906,
    1.0, 1.0,
    1.01562, 1.01562,
    1.0, 1.0,
    0.99219, 0.99219,
    1.00781, 1.00781,
    0.98047, 0.98047,
    0.95703, 0.95703,
    0.96875, 0.96875,
    0.99219, 0.99219,
    0.92578, 0.92578,
    0.92578, 0.92578,
    0.90625, 0.90625,
    0.85938, 0.85938,
    0.80469, 0.80469,
    0.87891, 0.87891,
    0.75, 0.75,
    0.60938, 0.60938,
    0.55078, 0.55078,
]


def nearest_interp(src_array: np.ndarray, target_length: int) -> np.ndarray:
    """Resample a 1-D ratio profile to ``target_length`` by nearest neighbour.

    Args:
        src_array: Source values. Any 1-D sequence is accepted (it is converted
            with ``np.asarray``), not only ``np.ndarray``.
        target_length: Number of samples to produce.

    Returns:
        A ``float32`` array of length ``target_length``. A length of 1 yields
        the *last* source value; a length <= 0 yields an empty array.

    Raises:
        ValueError: If ``src_array`` is empty while ``target_length`` > 0.
    """
    src = np.asarray(src_array, dtype=np.float32)
    if target_length <= 0:
        return np.empty(0, dtype=np.float32)

    src_length = len(src)
    if src_length == 0:
        raise ValueError("`src_array` must contain at least one value to interpolate from.")
    if target_length == 1:
        return np.array([src[-1]], dtype=np.float32)

    # Map target position i onto the source index range [0, src_length - 1].
    # np.round rounds halves to the nearest even index.
    scale = (src_length - 1) / (target_length - 1)
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return src[mapped_indices]


@dataclass
class _StreamState:
    """Per-stream (e.g. "cond" / "uncond") cache and skip bookkeeping."""

    # Last residual stored by `complete_call`; None until the first executed call.
    cached_residual: Optional[torch.Tensor] = None
    # Running product of the ratios since the last executed call.
    accumulated_ratio: float = 1.0
    # Running sum of |1 - accumulated_ratio| since the last executed call.
    accumulated_err: float = 0.0
    # Number of skip candidates evaluated since the last executed call.
    accumulated_steps: int = 0

    def reset_accumulators(self) -> None:
        """Reset the skip bookkeeping while keeping the cached residual."""
        self.accumulated_ratio = 1.0
        self.accumulated_err = 0.0
        self.accumulated_steps = 0

    def reset_all(self) -> None:
        """Drop the cached residual and reset the skip bookkeeping."""
        self.cached_residual = None
        self.reset_accumulators()


@dataclass
class QwenImageMagCacheRuntime:
    """Runtime state machine for Qwen-Image MagCache.

    Intended call pattern per transformer invocation: ``should_skip(stream)``;
    if it returns True use ``get_cached_residual(stream)`` followed by
    ``complete_skip(stream)``, otherwise run the model and report the new
    residual through ``complete_call(stream, residual)``.

    Raises:
        ValueError: From ``__post_init__`` when a configuration value is out of
            range.
    """

    num_inference_steps: int
    do_classifier_free_guidance: bool
    threshold: float
    max_skip_steps: int
    retention_ratio: float
    ratios: Optional[Sequence[float]] = None
    verbose: bool = False

    # Position inside the current image's call sequence (wraps to 0 at the end).
    call_index: int = 0
    # Counters are cumulative; `_reset_for_next_image` does not clear them.
    skipped_calls: int = 0
    executed_calls: int = 0
    # Rebuilt in `__post_init__`; any value passed by the caller is replaced.
    stream_states: Dict[str, _StreamState] = field(default_factory=dict)

    def _debug_print(self, message: str) -> None:
        """Print ``message`` only when ``verbose`` is enabled."""
        if self.verbose:
            print(message)

    def __post_init__(self) -> None:
        """Validate the configuration and derive the per-call ratio table."""
        if self.num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be >= 1, got {self.num_inference_steps}.")
        if self.threshold < 0:
            raise ValueError(f"`magcache_thresh` must be >= 0, got {self.threshold}.")
        if self.max_skip_steps < 0:
            raise ValueError(f"`magcache_K` must be >= 0, got {self.max_skip_steps}.")
        if not 0.0 <= self.retention_ratio <= 1.0:
            raise ValueError(f"`magcache_retention_ratio` must be in [0, 1], got {self.retention_ratio}.")

        # With classifier-free guidance every step issues a cond and an uncond call.
        self.calls_per_step = 2 if self.do_classifier_free_guidance else 1
        self.total_calls = self.num_inference_steps * self.calls_per_step
        self._prepared_ratios = self._prepare_ratios(
            self.ratios,
            num_steps=self.num_inference_steps,
            calls_per_step=self.calls_per_step,
        )

        self.stream_states = {"cond": _StreamState()}
        if self.do_classifier_free_guidance:
            self.stream_states["uncond"] = _StreamState()

    @staticmethod
    def _prepare_ratios(
        ratios: Optional[Sequence[float]],
        num_steps: int,
        calls_per_step: int,
    ) -> np.ndarray:
        """Build the ratio table that `should_skip` indexes with ``call_index``.

        Args:
            ratios: Custom profile, or None to use the module default.
            num_steps: Number of denoising steps.
            calls_per_step: 1 without CFG, 2 with CFG.

        Returns:
            ``float32`` array of length ``num_steps * calls_per_step``; the
            first call of each stream is given a ratio of 1.0.
        """
        raw = np.asarray(
            DEFAULT_QWEN_IMAGE_MAG_RATIOS if ratios is None else list(ratios),
            dtype=np.float32,
        )

        if calls_per_step == 1:
            # An even-length profile is treated as interleaved pairs and only
            # the even positions are kept. NOTE(review): this also applies to a
            # custom profile that merely happens to have even length.
            if raw.size % 2 == 0 and raw.size > 0:
                raw = raw[0::2]
            prepared = np.concatenate([np.array([1.0], dtype=np.float32), raw])
            if len(prepared) != num_steps:
                prepared = nearest_interp(prepared, num_steps)
            return prepared

        prepared = np.concatenate([np.array([1.0, 1.0], dtype=np.float32), raw])
        if len(prepared) != num_steps * 2:
            # Resample each stream separately, then re-interleave.
            cond_ratios = nearest_interp(prepared[0::2], num_steps)
            uncond_ratios = nearest_interp(prepared[1::2], num_steps)
            prepared = np.empty(num_steps * 2, dtype=np.float32)
            prepared[0::2] = cond_ratios
            prepared[1::2] = uncond_ratios
        return prepared

    def _state_for(self, stream_name: str) -> _StreamState:
        """Return the state of ``stream_name`` or raise a descriptive KeyError."""
        try:
            return self.stream_states[stream_name]
        except KeyError:
            raise KeyError(f"Unknown stream name '{stream_name}'.") from None

    def _advance(self) -> None:
        """Move to the next call, wrapping around once the image is finished."""
        self.call_index += 1
        if self.call_index >= self.total_calls:
            self._reset_for_next_image()

    def _cache_allowed_for_call(self, call_index: int) -> bool:
        """Return True once ``call_index`` is past the retention window."""
        return call_index >= int(self.total_calls * self.retention_ratio)

    def should_skip(self, stream_name: str) -> bool:
        """Decide whether the current call of ``stream_name`` may reuse the cache.

        Updates the stream's accumulators; on a "run" decision they are reset,
        on a "skip" decision ``skipped_calls`` is incremented.

        Raises:
            KeyError: If ``stream_name`` is not a known stream.
        """
        state = self._state_for(stream_name)

        if not self._cache_allowed_for_call(self.call_index):
            self._debug_print(
                f"[MagCache] call={self.call_index} stream={stream_name} diff=N/A "
                f"thresh={self.threshold:.6f} decision=run (retention window)"
            )
            return False
        if state.cached_residual is None:
            self._debug_print(
                f"[MagCache] call={self.call_index} stream={stream_name} diff=N/A "
                f"thresh={self.threshold:.6f} decision=run (cache cold start)"
            )
            return False

        ratio = float(self._prepared_ratios[self.call_index])
        state.accumulated_ratio *= ratio
        state.accumulated_steps += 1
        state.accumulated_err += abs(1.0 - state.accumulated_ratio)

        should_skip = state.accumulated_err < self.threshold and state.accumulated_steps <= self.max_skip_steps
        self._debug_print(
            f"[MagCache] call={self.call_index} stream={stream_name} diff={state.accumulated_err:.6f} "
            f"thresh={self.threshold:.6f} k={state.accumulated_steps}/{self.max_skip_steps} "
            f"decision={'skip' if should_skip else 'run'}"
        )

        if should_skip:
            self.skipped_calls += 1
            self._debug_print(f"[MagCache] stream={stream_name} diff<{self.threshold:.6f}; skipping this step for now.")
            return True

        state.reset_accumulators()
        return False

    def get_cached_residual(self, stream_name: str) -> torch.Tensor:
        """Return the cached residual of ``stream_name``.

        Raises:
            KeyError: If ``stream_name`` is not a known stream.
            RuntimeError: If nothing has been cached for the stream yet.
        """
        cached = self._state_for(stream_name).cached_residual
        if cached is None:
            raise RuntimeError(f"MagCache residual is empty for stream '{stream_name}'.")
        return cached

    def complete_call(self, stream_name: str, residual: torch.Tensor) -> None:
        """Record an executed call: cache ``residual`` (detached) and advance.

        Raises:
            KeyError: If ``stream_name`` is not a known stream.
        """
        state = self._state_for(stream_name)
        state.cached_residual = residual.detach()
        self.executed_calls += 1
        self._advance()

    def complete_skip(self, stream_name: str) -> None:
        """Record a skipped call and advance without touching the cache.

        Raises:
            KeyError: If ``stream_name`` is not a known stream.
        """
        self._state_for(stream_name)
        self._advance()

    def _reset_for_next_image(self) -> None:
        """Rewind ``call_index`` and clear every stream for the next image."""
        self.call_index = 0
        for state in self.stream_states.values():
            state.reset_all()
Update Qwen text encoder to Qaic; present running on cpu +""" + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import QwenImagePipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.qwenimage.pipeline_qwenimage import calculate_shift +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from tqdm import tqdm + +from QEfficient.diffusers.pipelines.pipeline_module import ( + QEffQwenImageTransformer2DModel, + QEffVAE, +) +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + calculate_compressed_latent_dimension, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_execute_params, +) +from QEfficient.diffusers.pipelines.qwen_image.magcache import QwenImageMagCacheRuntime +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class QEffQwenImagePipeline: + """ + QEfficient-optimized Qwen Image text-to-image pipeline. + + This wrapper integrates QEfficient model components for ONNX export, QAIC + compilation, and inference while preserving the familiar Diffusers API. + """ + + _hf_auto_class = QwenImagePipeline + + def __init__(self, model, **kwargs): + """ + Initialize the QEfficient Qwen Image pipeline wrapper. + + Args: + model (`QwenImagePipeline`): + The underlying Diffusers Qwen Image pipeline instance. + **kwargs: + Additional metadata passed from `from_pretrained`. 
+ """ + self.model = model + self.kwargs = kwargs + + self.text_encoder = model.text_encoder # TODO: Text encoder on QAIC + self.transformer = QEffQwenImageTransformer2DModel(model.transformer) + self.vae_decoder = QEffVAE(model.vae, "decoder") + + # Store all modules in a dictionary for easy iteration during export/compile + self.modules = { + "transformer": self.transformer, + "vae_decoder": self.vae_decoder, + } + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + self.prompt_template_encode = model.prompt_template_encode + self.prompt_template_encode_start_idx = model.prompt_template_encode_start_idx + + self.vae_decoder.model.forward = lambda latent_sample, return_dict: self.vae_decoder.model.decode( + latent_sample, return_dict + ) + + self.vae_scale_factor = 2 ** len(model.vae.temperal_downsample) if getattr(model, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = model.default_sample_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Instantiate a QEffQwenImagePipeline from a pretrained checkpoint. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Local path or model identifier. + **kwargs: + Additional keyword arguments passed to + `QwenImagePipeline.from_pretrained`. + + Returns: + `QEffQwenImagePipeline`: Initialized QEff pipeline instance. 
+ """ + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + device_map="cpu", + **kwargs, + ) + return cls( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export all pipeline modules to ONNX format for deployment preparation. + + This method exports the Qwen Image transformer and VAE decoder with + their module-specific dynamic axes and export settings. + + Args: + export_dir (`str`, *optional*): + Target directory for ONNX artifacts. If `None`, the default + export structure is used. + use_onnx_subfunctions (`bool`, *optional*, defaults to `False`): + Enables ONNX subfunction export for supported modules. + + Returns: + `str`: Absolute path to the export directory. + + Raises: + RuntimeError: If ONNX export fails for any module + OSError: If there are issues creating the export directory or writing files + ValueError: If module configurations are invalid + + Example: + >>> pipeline = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + >>> export_path = pipeline.export( + ... export_dir="/path/to/export", + ... use_onnx_subfunctions=True + ... 
) + """ + # Export each module with video-specific parameters + for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"): + # Get ONNX export configuration with video dimensions + example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params() + + # Prepare export parameters + export_params = { + "inputs": example_inputs, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "export_dir": export_dir, + } + + # Enable ONNX subfunctions for supported modules if requested + if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE: + export_params["use_onnx_subfunctions"] = True + + if module_obj.qpc_path is None: + module_obj.export(**export_params) + + @staticmethod + def get_default_config_path(): + """ + Get the default configuration file path for the Qwen Image pipeline. + + Returns: + `str`: Path to the default Qwen Image configuration JSON file. + """ + return os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs/qwen_config.json") + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = 1024, + width: int = 1024, + use_onnx_subfunctions: bool = False, + ) -> None: + """ + Compile ONNX models into optimized QPC binaries for QAIC. + + Args: + compile_config (`str`, *optional*): + Path to a JSON compilation config. If `None`, the default + config is used. + parallel (`bool`, *optional*, defaults to `False`): + If `True`, compile modules in parallel. + height (`int`, *optional*): + Target image height in pixel space. + width (`int`, *optional*): + Target image width in pixel space. + use_onnx_subfunctions (`bool`, *optional*, defaults to `False`): + Re-export modules with ONNX subfunctions before compile when needed. + + Raises: + RuntimeError: If compilation fails for any module. + FileNotFoundError: If required files/configuration are missing. + ValueError: If configuration parameters are invalid. + OSError: If file I/O fails. 
+ + Example: + >>> pipeline = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=480, + ... width=832, + ... ) + """ + # Load compilation configuration + config_manager(self, config_source=compile_config, use_onnx_subfunctions=use_onnx_subfunctions) + + # Set device IDs, qpc path if precompiled qpc exist + set_execute_params(self) + + # Skip export/compile entirely when all modules already have precompiled QPCs. + if all(module_obj.qpc_path is not None for module_obj in self.modules.values()): + logger.info("All module `qpc_path`s are provided; skipping export and compile.") + return + + # Ensure modules that still need compilation are exported to ONNX. + if any(module_obj.qpc_path is None and module_obj.onnx_path is None for module_obj in self.modules.values()): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Calculate compressed latent dimension using utility function + cl, latent_height, latent_width = calculate_compressed_latent_dimension( + height, width, self.model.vae_scale_factor + ) + + # Prepare dynamic specialization updates based on video dimensions + specialization_updates = { + "transformer": { + "cl": cl, # Compressed latent dimension + }, + "vae_decoder": { + "latent_frames": 1, + "latent_height": latent_height, + "latent_width": latent_width, + }, + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + Tokenize prompt text and return encoder hidden states plus attention mask. 
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    """
    Prepare prompt embeddings and masks for image generation.

    Args:
        prompt (`str` or `List[str]`):
            Prompt text to encode. Ignored when `prompt_embeds` is given.
        device (`torch.device`, *optional*):
            Device used when the prompt has to be encoded. Falls back to
            `self._execution_device` only in that case.
        num_images_per_prompt (`int`):
            Number of images to generate per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Precomputed prompt embeddings. If omitted, embeddings are
            computed from `prompt`.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask aligned with `prompt_embeds`; required whenever
            `prompt_embeds` is supplied.
        max_sequence_length (`int`, *optional*):
            Maximum text sequence length retained from the embeddings.

    Returns:
        Tuple[`torch.Tensor`, `torch.Tensor`]:
            Prompt embeddings and masks, each prompt repeated
            `num_images_per_prompt` times in consecutive rows, so row `i` of
            the mask always belongs to row `i` of the embeddings.

    Raises:
        ValueError: If `prompt_embeds` is given without `prompt_embeds_mask`.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        # The device is only needed on this path, so resolve it lazily.
        device = device or self._execution_device
        # NOTE(review): the view() below assumes the helper returns one row per
        # prompt - confirm for multi-prompt input.
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
    elif prompt_embeds_mask is None:
        raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed.")

    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    _, seq_len, _ = prompt_embeds.shape
    total = batch_size * num_images_per_prompt

    # Tile along the sequence axis, then fold the copies into the batch axis:
    # rows come out as p0, p0, ..., p1, p1, ...
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(total, seq_len, -1)

    # The mask is 2-D, so it must be tiled with a 2-argument repeat. A
    # 3-argument repeat would tile along the batch axis instead and give rows
    # ordered p0, p1, p0, p1, which no longer line up with the embeddings.
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt)
    prompt_embeds_mask = prompt_embeds_mask.view(total, seq_len)

    return prompt_embeds, prompt_embeds_mask
Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + use_onnx_subfunctions: bool = False, + use_magcache: bool = False, + magcache_thresh: float = 0.06, + magcache_K: int = 2, + magcache_retention_ratio: float = 0.2, + magcache_ratios: Optional[List[float]] = None, + magcache_verbose: bool = False, + ): + """ + Generate images from text prompts with the QEfficient Qwen Image pipeline. + This method performs text-to-image generation by encoding the input prompts through the + Qwen text encoder, running the diffusion process with the transformer model, and decoding + the final latents to images using the VAE decoder. All components are optimized for + Qualcomm AI hardware. + + Args: + prompt (Union[str, List[str]], optional): The text prompt(s) to guide image generation. + negative_prompt (Union[str, List[str]], optional): Negative prompt(s) for true CFG. + true_cfg_scale (float, defaults to 4.0): Scale for true classifier-free guidance. + height (Optional[int], optional): Height of the generated image in pixels. + width (Optional[int], optional): Width of the generated image in pixels. + num_inference_steps (int, defaults to 50): Number of denoising steps. + sigmas (Optional[List[float]], optional): Custom sigmas for the denoising process. + guidance_scale (float, defaults to 1.0): Guidance scale (for future guidance-distilled models). + num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt. + generator (Optional[Union[torch.Generator, List[torch.Generator]]], optional): Random generator(s). + latents (Optional[torch.Tensor], optional): Pre-generated noisy latents. + prompt_embeds (Optional[torch.Tensor], optional): Pre-generated text embeddings. 
+ prompt_embeds_mask (Optional[torch.Tensor], optional): Pre-generated text embeddings mask. + negative_prompt_embeds (Optional[torch.Tensor], optional): Pre-generated negative text embeddings. + negative_prompt_embeds_mask (Optional[torch.Tensor], optional): Pre-generated negative text embeddings mask. + output_type (Optional[str], defaults to "pil"): Output format ("pil", "np", "pt", or "latent"). + return_dict (bool, defaults to True): Whether to return a QwenImagePipelineOutput. + attention_kwargs (Optional[Dict[str, Any]], optional): Additional attention kwargs. + callback_on_step_end (Optional[Callable], optional): Callback function at end of each step. + callback_on_step_end_tensor_inputs (List[str], defaults to ["latents"]): Tensor inputs for callback. + max_sequence_length (int, defaults to 512): Maximum sequence length for text encoder. + custom_config_path (`str`, *optional*): JSON config path used by compile/config manager. + parallel_compile (`bool`, *optional*, defaults to `False`): Whether to compile modules in parallel. + use_onnx_subfunctions (`bool`, *optional*, defaults to `False`): Whether to enable ONNX subfunction export. + use_magcache (bool, optional): Enable Qwen Image MagCache skip/reuse logic. Default: False. + magcache_thresh (float, optional): MagCache accumulated error threshold. Default: 0.06. + magcache_K (int, optional): Maximum consecutive skipped calls per stream. Default: 2. + magcache_retention_ratio (float, optional): Retention ratio in [0, 1]. Default: 0.2. + magcache_ratios (Optional[List[float]], optional): Optional custom MagCache ratio profile. + magcache_verbose (bool, optional): Emit per-call MagCache decisions. Default: False. + + Returns: + Union[QwenImagePipelineOutput, Tuple]: + Generated images and per-module performance metrics. + + Examples: + >>> from QEfficient import QEffQwenImagePipeline + >>> pipeline = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + >>> output = pipeline( + ... 
prompt="A cat holding a sign that says hello world", + ... negative_prompt="", + ... width=1664, + ... height=928, + ... num_inference_steps=50, + ... true_cfg_scale=4.0, + ... generator=torch.Generator(device="cpu").manual_seed(42), + ... parallel_compile=True, + ... max_sequence_length=128, + ... use_onnx_subfunctions=True, + ... ) + >>> image = output.images[0] + >>> image.save("output.png") + """ + device = self.model._execution_device + + height = height + width = width + + # Compile models with custom configuration if needed + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + use_onnx_subfunctions=use_onnx_subfunctions, + height=height, + width=width, + ) + + # 1. Check inputs + self.model.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + if isinstance(use_magcache, str): + lowered = use_magcache.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + use_magcache = True + elif lowered in {"0", "false", "no", "off"}: + use_magcache = False + else: + raise ValueError( + f"Invalid string value for `use_magcache`: {use_magcache!r}. " + "Use one of: true/false, 1/0, yes/no, on/off." 
+ ) + elif not isinstance(use_magcache, bool): + use_magcache = bool(use_magcache) + logger.warning(f"Coerced non-bool `use_magcache` to {use_magcache}.") + + if not use_magcache and ( + magcache_verbose + or magcache_ratios is not None + or magcache_thresh != 0.06 + or magcache_K != 2 + or magcache_retention_ratio != 0.2 + ): + logger.warning("Ignoring MagCache knobs because `use_magcache=False`.") + + # 3: Encode prompts with both text encoders + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Encode negative prompts if using true classifier-free guidance + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents = self.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + txt_seq_lens = [max_sequence_length] + + magcache_runtime = None + if use_magcache: + magcache_runtime = QwenImageMagCacheRuntime( + num_inference_steps=num_inference_steps, + do_classifier_free_guidance=do_true_cfg, + threshold=magcache_thresh, + max_skip_steps=magcache_K, + retention_ratio=magcache_retention_ratio, + ratios=magcache_ratios, + verbose=magcache_verbose, + ) + + # # Initialize transformer session + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + transformer_perf = [] + cfg_perf = [] + + # rotary emb + qaic_image_rotary_emb = self.transformer.model.pos_embed(img_shapes, txt_seq_lens, device="cpu") + qaic_img_freqs_cos, qaic_img_freqs_sin, qaic_txt_freqs_cos, qaic_txt_freqs_sin = qaic_image_rotary_emb + + img_rotary_emb = torch.cat([qaic_img_freqs_cos, qaic_img_freqs_sin], dim=-1) # [6032, 128] + txt_rotary_emb = torch.cat([qaic_txt_freqs_cos, qaic_txt_freqs_sin], dim=-1) # [126, 128] + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self._interrupt: + continue + self._current_timestep = t + timestep = (t.expand(latents.shape[0]) / 1000).detach().numpy().astype(np.float32) + + # Conditional pass + transformer_inputs = { + "hidden_states": latents.detach().numpy().astype(np.float32), + "encoder_hidden_states": prompt_embeds.detach().numpy().astype(np.float32), + "encoder_hidden_states_mask": prompt_embeds_mask.detach().numpy().astype(np.int64), + "img_rotary_emb": img_rotary_emb.detach().numpy().astype(np.float32), + "txt_rotary_emb": txt_rotary_emb.detach().numpy().astype(np.float32), + "timestep": timestep, + } + + def _run_qwen_step( + stream_name: str, + inputs: Dict[str, np.ndarray], + perf_bucket: List[float], + ) -> torch.Tensor: + if magcache_runtime is not None and magcache_runtime.should_skip(stream_name): + cached_residual = magcache_runtime.get_cached_residual(stream_name) + magcache_runtime.complete_skip(stream_name) + return latents.to(cached_residual.dtype) + cached_residual + + start_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs) + end_step_time = time.perf_counter() + perf_bucket.append(end_step_time - start_step_time) + noise_pred_step = torch.from_numpy(outputs["output"]) + + if magcache_runtime is not None: + residual = noise_pred_step - latents.to(noise_pred_step.dtype) + magcache_runtime.complete_call(stream_name, residual) + return 
noise_pred_step + + noise_pred = _run_qwen_step("cond", transformer_inputs, transformer_perf) + + if do_true_cfg: + # Unconditional pass + transformer_inputs_uncond = { + "hidden_states": latents.detach().numpy().astype(np.float32), + "encoder_hidden_states": negative_prompt_embeds.detach().numpy().astype(np.float32), + "encoder_hidden_states_mask": negative_prompt_embeds_mask.detach().numpy().astype(np.int64), + "img_rotary_emb": img_rotary_emb.detach().numpy().astype(np.float32), + "txt_rotary_emb": txt_rotary_emb.detach().numpy().astype(np.float32), + "timestep": timestep, + } + neg_noise_pred = _run_qwen_step("uncond", transformer_inputs_uncond, cfg_perf) + + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + vae_decoder_perf = 0.0 + if output_type == "latent": + image = latents + else: + latents = self.model._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae_decoder.model.dtype) + latents_mean = ( + torch.tensor(self.vae_decoder.model.config.latents_mean) 
+ .view(1, self.vae_decoder.model.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_decoder.model.config.latents_std).view( + 1, self.vae_decoder.model.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + ########## QAIC + # Initialize VAE decoder inference session + if self.vae_decoder.qpc_session is None: + self.vae_decoder.qpc_session = QAICInferenceSession( + str(self.vae_decoder.qpc_path), device_ids=self.vae_decoder.device_ids + ) + + # Allocate output buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, 1, height, width).astype(np.int32)} + self.vae_decoder.qpc_session.set_buffers(output_buffer) + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.perf_counter() + image = self.vae_decoder.qpc_session.run(inputs) + end_decode_time = time.perf_counter() + vae_decoder_perf = end_decode_time - start_decode_time + + image_tensor = torch.from_numpy(image["sample"]) + image_tensor = image_tensor[:, :, 0] + image = self.image_processor.postprocess(image_tensor, output_type=output_type) + + if not return_dict: + return (image,) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decoder_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..426e928f7d 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -191,6 +191,10 @@ def get_models_dir(): # WAN I2V WAN_DIT_I2V_IMG_LATENT_CHANNELS = 32 +# QWEN Image +QWEN_IMAGE_SL = 128 +QWEN_IMAGE_CL = 256 + # For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on 
context length CCL_START_MAP = { 32768: (4096, 4000), diff --git a/examples/diffusers/qwen_image/README.md b/examples/diffusers/qwen_image/README.md new file mode 100644 index 0000000000..11f8f506fa --- /dev/null +++ b/examples/diffusers/qwen_image/README.md @@ -0,0 +1,202 @@ +# Qwen Image Generation Examples + +This directory contains examples showing how to use `QEffQwenImagePipeline` to generate images with Qwen Image models. + +## Overview + +Qwen Image is a text-to-image diffusion model. These examples demonstrate end-to-end image generation with QEfficient export, compile, and execution flow for Qualcomm Cloud AI 100. + +## Files + +- **`qwen_image_example.py`** - Basic image generation example +- **`qwen_image_custom.py`** - Advanced example with customization options +- **`qwen_image_magcache.py`** - Image generation with optional MagCache runtime +- **`qwen_config.json`** - Configuration file for transformer and VAE modules + +## Quick Start + +### Basic Usage + +```python +from QEfficient import QEffQwenImagePipeline +import torch + +pipeline = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + +output = pipeline( + prompt="A cinematic photo of a coffee shop street in rain", + negative_prompt="low quality, blurry", + width=1664, + height=928, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cpu").manual_seed(42), + parallel_compile=True, + max_sequence_length=128, +) + +output.images[0].save("qwen_output.png") +``` + +Run the basic example: + +```bash +python qwen_image_example.py +``` + +Run with MagCache: + +```bash +python qwen_image_magcache.py +``` + +Run with MagCache disabled: + +```bash +# Edit qwen_image_magcache.py and set: +# use_magcache = False +python qwen_image_magcache.py +``` + +If you already have compiled QPCs, set their paths in `examples/diffusers/qwen_image/qwen_config.json` +and keep `custom_config_path="examples/diffusers/qwen_image/qwen_config.json"` in the script; +the pipeline will
skip export/compile. + +## Advanced Customization + +`qwen_image_custom.py` includes common customizations. + +### 1. Custom model components + +```python +pipeline = QEffQwenImagePipeline.from_pretrained( + "Qwen/Qwen-Image", + transformer=custom_transformer, + vae=custom_vae, + text_encoder=custom_text_encoder, + tokenizer=custom_tokenizer, +) +``` + +### 2. Custom scheduler + +```python +pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) +``` + +### 3. Reduce layers for faster iteration + +```python +original_blocks = pipeline.transformer.model.transformer_blocks +pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0], original_blocks[1]]) +pipeline.transformer.model.config.num_layers = 2 +``` + +### 4. Compile with custom configuration + +```python +pipeline.compile( + compile_config="examples/diffusers/qwen_image/qwen_config.json", + parallel=True, + height=928, + width=1664, + use_onnx_subfunctions=False, +) +``` + +### 5. Skip export/compile if precompiled QPC exists + +Update `qwen_config.json` with prebuilt `qpc_path` in `execute` for each module: + +```json +"execute": { + "device_ids": null, + "qpc_path": "" +} +``` + +### 6. 
Runtime custom config + +```python +output = pipeline( + prompt="A modern storefront at golden hour", + negative_prompt="low quality", + width=1664, + height=928, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cpu").manual_seed(42), + custom_config_path="examples/diffusers/qwen_image/qwen_config.json", + parallel_compile=True, + max_sequence_length=128, + use_onnx_subfunctions=True, +) +``` + +Run the advanced example: + +```bash +python qwen_image_custom.py +``` + +## Configuration File + +`qwen_config.json` controls specialization, compile, and execute settings for: + +- `transformer` +- `vae_decoder` + +### Common parameter groups + +#### Specializations +- `batch_size` +- `cl` (image token length for transformer) +- `seq_length` (text sequence length) +- `latent_height`, `latent_width` (VAE decode shape) + +#### Compilation +- `onnx_path`: Path to pre-exported ONNX model (null for auto-export) +- `compile_dir`: Directory for compiled artifacts (null for auto-generation) +- `mdp_ts_num_devices`: Number of devices for model data parallelism +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use +- `mos`: Multi-output streaming +- `mdts_mos`: Multi-device tensor slicing with MOS (transformer only) +- `aic-enable-depth-first`: Enable depth-first compilation + +#### Execute +- `device_ids`: List of device IDs to use (null for auto-selection) +- `qpc_path`: Path to a precompiled QPC; when set, export/compile is skipped (null by default) + +## Key Generation Parameters + +- **`prompt`**: Positive prompt string +- **`negative_prompt`**: Negative prompt string +- **`width`**, **`height`**: Output image size in pixels +- **`num_inference_steps`**: Number of denoising steps +- **`true_cfg_scale`**: True classifier-free guidance scale (applied only when greater than 1 and a negative prompt is provided) +- **`max_sequence_length`**: Text sequence length limit +- **`generator`**: Seeded torch generator for reproducibility +-
**`parallel_compile`**: Compile multiple modules in parallel +- **`use_onnx_subfunctions`**: Enable ONNX modular export (experimental) +- **`use_magcache`**: Enable runtime MagCache skip/reuse logic +- **`magcache_thresh`**: Accumulated skip-error threshold +- **`magcache_K`**: Maximum consecutive skipped calls per stream +- **`magcache_retention_ratio`**: Warmup/retention ratio before skipping is considered +- **`magcache_verbose`**: Print per-call diff/decision logs (`skipping this step for now` when skipped) + +## Output + +Pipeline output contains generated images and performance metadata. + +```python +print(output) +image = output.images[0] +image.save("output.png") +``` + +## References + +- [Qwen Image Model Card](https://huggingface.co/Qwen/Qwen-Image) +- [QEfficient Documentation](../../../README.md) diff --git a/examples/diffusers/qwen_image/qwen_config.json b/examples/diffusers/qwen_image/qwen_config.json new file mode 100644 index 0000000000..1512586a66 --- /dev/null +++ b/examples/diffusers/qwen_image/qwen_config.json @@ -0,0 +1,49 @@ +{ + "description": "Default configuration for QWEN image", + "modules": { + "transformer": { + "specializations": { + "batch_size":"1", + "seq_length":"128" + }, + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "convert_to_fp16": true, + "compile_only":true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1, + "node_precision_info": "QEfficient/diffusers/pipelines/configs/qwen_image.yaml" + }, + "execute": { + "device_ids": null, + "qpc_path" : null + } + }, + "vae_decoder":{ + "specializations":{ + "batch_size": 1, + "num_channels": 16 + }, + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true, + "mos": 1, + "mdts_mos": 1 + }, + "execute": + { + "device_ids": null, + "qpc_path" : null + } + } + } +} \ No newline 
at end of file diff --git a/examples/diffusers/qwen_image/qwen_image_custom.py b/examples/diffusers/qwen_image/qwen_image_custom.py new file mode 100644 index 0000000000..ef0a70cb5d --- /dev/null +++ b/examples/diffusers/qwen_image/qwen_image_custom.py @@ -0,0 +1,113 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Qwen Image Custom Configuration Example + +This example demonstrates how to customize the Qwen Image model with various options: +1. Custom image dimensions and aspect ratios +2. Optional scheduler customization +3. Optional reduced transformer layers for faster iteration +4. Optional explicit compilation with custom config +5. Runtime config usage via JSON config file + +Use this example as a starting point for your own Qwen Image workflow. 
+""" + +import torch + +from QEfficient import QEffQwenImagePipeline + +# ============================================================================ +# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS +# ============================================================================ + +# Option 1: Basic initialization +pipeline = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + +# Option 2: Advanced initialization with custom modules (example) +# pipeline = QEffQwenImagePipeline.from_pretrained( +# "Qwen/Qwen-Image", +# transformer=custom_transformer, +# vae=custom_vae, +# text_encoder=custom_text_encoder, +# tokenizer=custom_tokenizer, +# ) + +# ============================================================================ +# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION +# ============================================================================ +# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Trade-off: faster generation with possible quality drop +original_blocks = pipeline.transformer.model.transformer_blocks +pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0], original_blocks[1]]) +pipeline.transformer.model.config["num_layers"] = 2 + +# ============================================================================ +# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION +# ============================================================================ +# NOTE-1: If compile_config is not specified, default qwen config is used. 
+# pipeline.compile( +# compile_config="examples/diffusers/qwen_image/qwen_config.json", +# parallel=True, +# height=928, +# width=1664, +# use_onnx_subfunctions=True, +# ) + +# ============================================================================ +# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION +# ============================================================================ +# Generate an image using the configured pipeline. +# +# Note: Use of custom_config_path provides flexibility to set device_ids for each +# module, so you can skip the separate pipeline.compile() step. + +positive_magic = { + "en": ", Ultra HD, 4K, cinematic composition.", +} + +prompt = ( + "A coffee shop entrance with a chalkboard sign reading 'Qwen Coffee $2 per cup', " + "warm ambient lighting, realistic details" +) +negative_prompt = "low quality, blurry, distorted" + +# Common Qwen image aspect ratios +aspect_ratios = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472), + "3:2": (1584, 1056), + "2:3": (1056, 1584), +} + +width, height = aspect_ratios["16:9"] + +output = pipeline( + prompt=prompt + positive_magic["en"], + negative_prompt=negative_prompt, + width=width, + height=height, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cpu").manual_seed(42), + custom_config_path="examples/diffusers/qwen_image/qwen_config.json", + parallel_compile=True, + max_sequence_length=128, + use_onnx_subfunctions=False, +) + +image = output.images[0] +image.save("qwen_image_custom.png") +print(output) diff --git a/examples/diffusers/qwen_image/qwen_image_example.py b/examples/diffusers/qwen_image/qwen_image_example.py new file mode 100644 index 0000000000..75286d3768 --- /dev/null +++ b/examples/diffusers/qwen_image/qwen_image_example.py @@ -0,0 +1,58 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +Qwen-Image Image Generation Example + +This example demonstrates how to use the QEffQwenImagePipeline to generate images using Qwen-Image model. +""" + +import torch + +from QEfficient import QEffQwenImagePipeline + +# Initialize the Qwen Image pipeline from pretrained weights +pipe = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + +positive_magic = { + "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt +} + +prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197".""" +negative_prompt = "" + +# Generate with different aspect ratios +aspect_ratios = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472), + "3:2": (1584, 1056), + "2:3": (1056, 1584), +} + +width, height = aspect_ratios["16:9"] + +output = pipe( + prompt=prompt + positive_magic["en"], + negative_prompt=negative_prompt, + width=width, + height=height, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cpu").manual_seed(42), + parallel_compile=True, + max_sequence_length=128, + # use_onnx_subfunctions=True, +) + +# Extract the generated image from the output +image = output.images[0] + +# Save the generated image to disk +image.save("qwen_image_example.png") +print(output) diff --git a/examples/diffusers/qwen_image/qwen_image_magcache.py b/examples/diffusers/qwen_image/qwen_image_magcache.py new file mode 100644 index 0000000000..76a1b5d0e0 --- /dev/null +++ b/examples/diffusers/qwen_image/qwen_image_magcache.py @@ -0,0 +1,68 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +Qwen-Image MagCache Example + +This example demonstrates how to enable/disable runtime MagCache for Qwen-Image +while using a custom config file that can point to precompiled QPC paths. +""" + +import torch + +from QEfficient import QEffQwenImagePipeline + +# Initialize the Qwen Image pipeline from pretrained weights +pipe = QEffQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + +positive_magic = { + "en": ", Ultra HD, 4K, cinematic composition.", +} + +prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197".""" +negative_prompt = "" + +# Common Qwen image aspect ratios +aspect_ratios = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472), + "3:2": (1584, 1056), + "2:3": (1056, 1584), +} + +width, height = aspect_ratios["16:9"] + +# MagCache knobs +use_magcache = True +magcache_thresh = 0.06 +magcache_K = 2 +magcache_retention_ratio = 0.2 +magcache_verbose = True + +output = pipe( + prompt=prompt + positive_magic["en"], + negative_prompt=negative_prompt, + width=width, + height=height, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cpu").manual_seed(42), + custom_config_path="examples/diffusers/qwen_image/qwen_config.json", + parallel_compile=True, + max_sequence_length=128, + use_magcache=use_magcache, + magcache_thresh=magcache_thresh, + magcache_K=magcache_K, + magcache_retention_ratio=magcache_retention_ratio, + magcache_verbose=magcache_verbose, +) + 
+image = output.images[0] +image.save("qwen_image_magcache.png") +print(output) diff --git a/pyproject.toml b/pyproject.toml index 9a3a639381..2719547bf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ markers = [ "diffusion_models: marks tests for diffusion models", "wan: marks tests for WAN model", "flux: marks tests for Flux model", + "qwen_image: marks tests for Qwen Image model", "regular: marks regular tests", "nightly: marks nightly tests", "multimodal: marks multimodal tests", diff --git a/tests/diffusers/qwen_image_test_config.json b/tests/diffusers/qwen_image_test_config.json new file mode 100644 index 0000000000..8a58246ac2 --- /dev/null +++ b/tests/diffusers/qwen_image_test_config.json @@ -0,0 +1,76 @@ +{ + "model_setup": { + "height": 64, + "width": 64, + "num_transformer_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "max_sequence_length": 128 + }, + "mad_validation": { + "tolerances": { + "transformer": 3.5, + "vae_decoder": 0.7 + } + }, + "pipeline_params": { + "test_prompt": "A cat holding a sign that says hello world", + "num_inference_steps": 2, + "guidance_scale": 1.0, + "true_cfg_scale": 1.0, + "validate_gen_img": true, + "min_image_variance": 1.0, + "use_onnx_subfunctions": true + }, + "validation_checks": { + "image_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": { + "transformer": { + "specializations": { + "batch_size": "1", + "seq_length": "128" + }, + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "use_onnx_subfunctions": false, + "compile_only": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null, + "qpc_path": null + } + }, + "vae_decoder": { + "specializations": { + "batch_size": "1", + "num_channels": "16" + }, + "compilation": { + "onnx_path": null, + "compile_dir": null, + 
"mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "compile_only": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null, + "qpc_path": null + } + } + } +} diff --git a/tests/diffusers/test_qwen_image.py b/tests/diffusers/test_qwen_image.py new file mode 100644 index 0000000000..53da7cb73f --- /dev/null +++ b/tests/diffusers/test_qwen_image.py @@ -0,0 +1,455 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +import time +from typing import Dict, List, Optional, Union + +import numpy as np +import pytest +import torch +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + QwenImagePipeline, + QwenImageTransformer2DModel, +) +from diffusers.pipelines.qwenimage.pipeline_qwenimage import calculate_shift +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from QEfficient import QEffQwenImagePipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf, QEffPipelineOutput +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +CONFIG_PATH = "tests/diffusers/qwen_image_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +class DummyTokenizer: + model_max_length = 1024 + + +class DummyTextEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.float32 + + +def qwen_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 64, + width: int = 64, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float 
= 1.0, + num_inference_steps: int = 2, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + max_sequence_length: int = 128, + custom_config_path: Optional[str] = None, + parallel_compile: bool = True, + use_onnx_subfunctions: bool = False, + mad_tolerances: Dict[str, float] = None, +): + """ + Replicate QEffQwenImagePipeline.__call__ flow and validate MAD for transformer and VAE. + """ + mad_validator = MADValidator(tolerances=mad_tolerances) + device = "cpu" + + # Step 1: Compile/export + pipeline.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # Step 2: Validate input contract + pipeline.model.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=["latents"], + max_sequence_length=max_sequence_length, + ) + + pipeline._guidance_scale = guidance_scale + pipeline._current_timestep = None + pipeline._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = 1 + + has_neg_prompt = negative_prompt is not None + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + # Step 3: Use deterministic synthetic prompt embeddings for both QAIC and torch reference. 
+ torch.manual_seed(123) + embed_dim = pipeline.transformer.model.config.joint_attention_dim + qaic_prompt_embeds = torch.randn( + batch_size * num_images_per_prompt, max_sequence_length, embed_dim, dtype=torch.float32 + ) + qaic_prompt_embeds_mask = torch.ones(batch_size * num_images_per_prompt, max_sequence_length, dtype=torch.int64) + + torch_prompt_embeds = qaic_prompt_embeds.clone() + torch_prompt_embeds_mask = qaic_prompt_embeds_mask.clone() + + if do_true_cfg: + qaic_negative_prompt_embeds = torch.randn( + batch_size * num_images_per_prompt, + max_sequence_length, + embed_dim, + dtype=torch.float32, + ) + qaic_negative_prompt_embeds_mask = torch.ones( + batch_size * num_images_per_prompt, max_sequence_length, dtype=torch.int64 + ) + else: + qaic_negative_prompt_embeds = None + qaic_negative_prompt_embeds_mask = None + + # Step 4: Latents and timesteps + num_channels_latents = pipeline.transformer.model.config.in_channels // 4 + latents = pipeline.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + qaic_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # QEff path currently uses nested shape format for pos_embed utility. + img_shapes = [[(1, height // pipeline.vae_scale_factor // 2, width // pipeline.vae_scale_factor // 2)]] * batch_size + # Diffusers PyTorch transformer expects List[Tuple[int, int, int]]. 
+ torch_img_shapes = [ + (1, height // pipeline.vae_scale_factor // 2, width // pipeline.vae_scale_factor // 2) + ] * batch_size + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + pipeline.scheduler.config.get("base_image_seq_len", 256), + pipeline.scheduler.config.get("max_image_seq_len", 4096), + pipeline.scheduler.config.get("base_shift", 0.5), + pipeline.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + pipeline.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + pipeline._num_timesteps = len(timesteps) + + txt_seq_lens = [max_sequence_length] + + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession(str(pipeline.transformer.qpc_path)) + + pipeline.scheduler.set_begin_index(0) + transformer_perf = [] + + qaic_image_rotary_emb = pipeline.transformer.model.pos_embed(img_shapes, txt_seq_lens, device="cpu") + qaic_img_freqs_cos, qaic_img_freqs_sin, qaic_txt_freqs_cos, qaic_txt_freqs_sin = qaic_image_rotary_emb + + img_rotary_emb = torch.cat([qaic_img_freqs_cos, qaic_img_freqs_sin], dim=-1) + txt_rotary_emb = torch.cat([qaic_txt_freqs_cos, qaic_txt_freqs_sin], dim=-1) + + # Step 5: Denoising + transformer MAD + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + timestep = (t.expand(latents.shape[0]) / 1000).detach().numpy().astype(np.float32) + + transformer_inputs = { + "hidden_states": latents.detach().numpy().astype(np.float32), + "encoder_hidden_states": qaic_prompt_embeds.detach().numpy().astype(np.float32), + "encoder_hidden_states_mask": qaic_prompt_embeds_mask.detach().numpy().astype(np.int64), + 
"img_rotary_emb": img_rotary_emb.detach().numpy().astype(np.float32), + "txt_rotary_emb": txt_rotary_emb.detach().numpy().astype(np.float32), + "timestep": timestep, + } + + timestep_torch = torch.from_numpy(timestep).to(dtype=torch.float32) + noise_pred_torch = pytorch_pipeline.transformer( + hidden_states=latents, + encoder_hidden_states=torch_prompt_embeds, + encoder_hidden_states_mask=torch_prompt_embeds_mask, + timestep=timestep_torch, + img_shapes=torch_img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + )[0] + + start_transformer_step_time = time.perf_counter() + outputs = pipeline.transformer.qpc_session.run(transformer_inputs) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + outputs["output"], + module_name="transformer", + step_info=f"step {i} (t={t.item():.6f})", + ) + + noise_pred = torch.from_numpy(outputs["output"]) + + if do_true_cfg: + transformer_inputs_uncond = { + "hidden_states": latents.detach().numpy().astype(np.float32), + "encoder_hidden_states": qaic_negative_prompt_embeds.detach().numpy().astype(np.float32), + "encoder_hidden_states_mask": qaic_negative_prompt_embeds_mask.detach().numpy().astype(np.int64), + "img_rotary_emb": img_rotary_emb.detach().numpy().astype(np.float32), + "txt_rotary_emb": txt_rotary_emb.detach().numpy().astype(np.float32), + "timestep": timestep, + } + neg_noise_pred = pipeline.transformer.qpc_session.run(transformer_inputs_uncond) + neg_noise_pred = torch.from_numpy(neg_noise_pred["output"]) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, 
return_dict=False)[0] + + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + pipeline.transformer.qpc_session.deactivate() + + # Step 6: VAE decode + MAD + latents = pipeline.model._unpack_latents(latents, height, width, pipeline.vae_scale_factor) + latents = latents.to(pipeline.vae_decoder.model.dtype) + + latents_mean = ( + torch.tensor(pipeline.vae_decoder.model.config.latents_mean) + .view(1, pipeline.vae_decoder.model.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(pipeline.vae_decoder.model.config.latents_std).view( + 1, pipeline.vae_decoder.model.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + if pipeline.vae_decoder.qpc_session is None: + pipeline.vae_decoder.qpc_session = QAICInferenceSession( + str(pipeline.vae_decoder.qpc_path), device_ids=pipeline.vae_decoder.device_ids + ) + + output_buffer = {"sample": np.random.rand(batch_size, 3, 1, height, width).astype(np.int32)} + pipeline.vae_decoder.qpc_session.set_buffers(output_buffer) + + image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] + + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.perf_counter() + image = pipeline.vae_decoder.qpc_session.run(inputs) + end_decode_time = time.perf_counter() + vae_decoder_perf = end_decode_time - start_decode_time + + mad_validator.validate_module_mad( + image_torch.detach().cpu().numpy(), + image["sample"], + module_name="vae_decoder", + ) + + pipeline.vae_decoder.qpc_session.deactivate() + + image_tensor = torch.from_numpy(image["sample"]) + image_tensor = image_tensor[:, :, 0] + image = pipeline.image_processor.postprocess(image_tensor, output_type=output_type) + + perf_metrics = [ + 
ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decoder_perf), + ] + + return QEffPipelineOutput(pipeline_module=perf_metrics, images=image) + + +@pytest.fixture(scope="session") +def qwen_image_pipeline(): + """Setup tiny random-init Qwen Image pipelines for QAIC vs PyTorch validation.""" + torch.manual_seed(42) + np.random.seed(42) + config = INITIAL_TEST_CONFIG["model_setup"] + + transformer = QwenImageTransformer2DModel( + patch_size=2, + in_channels=64, + out_channels=16, + num_layers=config["num_transformer_layers"], + attention_head_dim=config["attention_head_dim"], + num_attention_heads=config["num_attention_heads"], + joint_attention_dim=config["joint_attention_dim"], + axes_dims_rope=(2, 2, 4), + ) + + vae = AutoencoderKLQwenImage( + base_dim=16, + z_dim=16, + dim_mult=[1, 2], + num_res_blocks=1, + temperal_downsample=[False], + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + pytorch_pipeline = QwenImagePipeline( + scheduler=scheduler, + vae=vae, + text_encoder=DummyTextEncoder(), + tokenizer=DummyTokenizer(), + transformer=transformer, + ) + + pipeline = QEffQwenImagePipeline(copy.deepcopy(pytorch_pipeline)) + + pytorch_pipeline.transformer.eval() + pytorch_pipeline.vae.eval() + pipeline.transformer.model.eval() + pipeline.vae_decoder.model.eval() + + # Align export inputs with tiny test model dims; the production helper uses + # hardcoded Qwen-Image sizes (e.g., encoder_hidden_dim=3584), which break + # for this reduced config. 
+ def _test_get_onnx_params(self): + bs = 1 + cl = 256 + seq_length = INITIAL_TEST_CONFIG["model_setup"]["max_sequence_length"] + hidden_dim = self.model.config.in_channels + encoder_hidden_dim = self.model.config.joint_attention_dim + rot_dim = sum(self.model.config.axes_dims_rope) + + example_inputs = { + "hidden_states": torch.randn(bs, cl, hidden_dim, dtype=torch.float32), + "encoder_hidden_states": torch.randn(bs, seq_length, encoder_hidden_dim, dtype=torch.float32), + "encoder_hidden_states_mask": torch.ones(bs, seq_length, dtype=torch.int64), + "txt_seq_lens": torch.tensor([seq_length], dtype=torch.int64), + "img_rotary_emb": torch.randn(cl, rot_dim, dtype=torch.float32), + "txt_rotary_emb": torch.randn(seq_length, rot_dim, dtype=torch.float32), + "timestep": torch.tensor([1.0], dtype=torch.float32), + } + + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "cl"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_length"}, + "encoder_hidden_states_mask": {0: "batch_size", 1: "seq_length"}, + "img_rotary_emb": {0: "cl"}, + "txt_rotary_emb": {0: "seq_length"}, + } + + return example_inputs, dynamic_axes, ["output"] + + pipeline.transformer.get_onnx_params = _test_get_onnx_params.__get__( + pipeline.transformer, type(pipeline.transformer) + ) + return pipeline, pytorch_pipeline + + +@pytest.mark.qwen_image +@pytest.mark.diffusion_models +@pytest.mark.on_qaic +def test_qwen_image_pipeline(qwen_image_pipeline): + """Qwen Image pipeline test with transformer and VAE MAD validation.""" + pipeline, pytorch_pipeline = qwen_image_pipeline + config = INITIAL_TEST_CONFIG + + DiffusersTestUtils.print_test_header( + f"QWEN IMAGE PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution", + config, + ) + + generator = torch.manual_seed(42) + start_time = time.time() + + result = qwen_pipeline_call_with_mad_validation( + pipeline=pipeline, + pytorch_pipeline=pytorch_pipeline, + height=config["model_setup"]["height"], + 
width=config["model_setup"]["width"], + prompt=config["pipeline_params"]["test_prompt"], + guidance_scale=config["pipeline_params"]["guidance_scale"], + true_cfg_scale=config["pipeline_params"]["true_cfg_scale"], + num_inference_steps=config["pipeline_params"]["num_inference_steps"], + max_sequence_length=config["model_setup"]["max_sequence_length"], + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + use_onnx_subfunctions=config["pipeline_params"]["use_onnx_subfunctions"], + parallel_compile=True, + ) + + execution_time = time.time() - start_time + + if config["validation_checks"]["image_generation"]: + assert result is not None, "Pipeline returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No images generated" + + generated_image = result.images[0] + expected_size = (config["model_setup"]["width"], config["model_setup"]["height"]) + image_validation = DiffusersTestUtils.validate_image_generation( + generated_image, expected_size, config["pipeline_params"]["min_image_variance"] + ) + + print("\n IMAGE VALIDATION PASSED") + print(f" - Size: {image_validation['size']}") + print(f" - Mode: {image_validation['mode']}") + print(f" - Variance: {image_validation['variance']:.2f}") + print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}") + + if config["validation_checks"]["onnx_export"]: + print("\n ONNX Export Validation:") + for module_name in ["transformer", "vae_decoder"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path: + DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX") + + if config["validation_checks"]["compilation"]: + print("\n Compilation Validation:") + for module_name in ["transformer", "vae_decoder"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, 
"qpc_path") and module_obj.qpc_path: + DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC") + + print(f"\nTotal execution time: {execution_time:.4f}s") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s", "-m", "qwen_image"]) diff --git a/tests/diffusers/test_qwen_image_magcache.py b/tests/diffusers/test_qwen_image_magcache.py new file mode 100644 index 0000000000..71de2a7f1c --- /dev/null +++ b/tests/diffusers/test_qwen_image_magcache.py @@ -0,0 +1,92 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import numpy as np +import pytest +import torch + +from QEfficient.diffusers.pipelines.qwen_image.magcache import QwenImageMagCacheRuntime, nearest_interp + + +@pytest.mark.diffusers +def test_nearest_interp_target_length_one_uses_last_value(): + src = np.asarray([0.1, 0.2, 0.9], dtype=np.float32) + out = nearest_interp(src, 1) + assert out.shape == (1,) + assert np.isclose(out[0], src[-1]) + + +@pytest.mark.diffusers +def test_prepare_ratios_cfg_and_non_cfg_lengths(): + ratios = [0.99, 0.98, 0.97, 0.96] + + cfg_runtime = QwenImageMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=True, + threshold=0.1, + max_skip_steps=2, + retention_ratio=0.0, + ratios=ratios, + ) + assert len(cfg_runtime._prepared_ratios) == 10 + + non_cfg_runtime = QwenImageMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=False, + threshold=0.1, + max_skip_steps=2, + retention_ratio=0.0, + ratios=ratios, + ) + assert len(non_cfg_runtime._prepared_ratios) == 5 + + +@pytest.mark.diffusers +def test_retention_window_behavior(): + runtime = QwenImageMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=False, + threshold=0.1, + max_skip_steps=2, + 
retention_ratio=0.4, + ratios=[1.0] * 5, + ) + + allowed = [runtime._cache_allowed_for_call(i) for i in range(5)] + assert allowed == [False, False, True, True, True] + + +@pytest.mark.diffusers +def test_skip_path_advances_call_index_and_respects_k_limit(): + runtime = QwenImageMagCacheRuntime( + num_inference_steps=4, + do_classifier_free_guidance=False, + threshold=1.0, + max_skip_steps=2, + retention_ratio=0.0, + ratios=[1.0] * 4, + ) + + assert runtime.should_skip("cond") is False + runtime.complete_call("cond", torch.zeros(1)) + assert runtime.call_index == 1 + + assert runtime.should_skip("cond") is True + runtime.complete_skip("cond") + assert runtime.call_index == 2 + + assert runtime.should_skip("cond") is True + runtime.complete_skip("cond") + assert runtime.call_index == 3 + + # Third consecutive skip exceeds K=2 and should force execution. + assert runtime.should_skip("cond") is False + runtime.complete_call("cond", torch.zeros(1)) + + # End of denoise sequence resets runtime state for next image. + assert runtime.call_index == 0 + assert runtime.stream_states["cond"].cached_residual is None