diff --git a/diffusion/callbacks/log_activation_norms.py b/diffusion/callbacks/log_activation_norms.py new file mode 100644 index 00000000..6009083b --- /dev/null +++ b/diffusion/callbacks/log_activation_norms.py @@ -0,0 +1,60 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Logger for transformer activation statistics.""" + +from collections import defaultdict + +import torch +from composer import Callback, Logger, State +from torch.nn.parallel import DistributedDataParallel + + +class LogActivationStatistics(Callback): + """Logging callback for activation statistics.""" + + def __init__(self): + self.hook_handles = [] + self.activations = {} + self.activation_norms = defaultdict(float) + self.batch_counter = 0 + + def activation_hook(self, name): + + def hook_fn(module, input, output): + self.activations[name] = output + + return hook_fn + + def register_hooks(self, model): + for name, layer in model.named_modules(): + if 'autoencoder' not in name and 'adaLN' not in name and ('.attention' in name or 'linear' in name): + handle = layer.register_forward_hook(self.activation_hook(name)) + self.hook_handles.append(handle) + + def remove_hooks(self): + for handle in self.hook_handles: + handle.remove() + + def eval_start(self, state: State, logger: Logger): + if isinstance(state.model, DistributedDataParallel): + model = state.model.module + else: + model = state.model + self.register_hooks(model) + + def eval_batch_end(self, state: State, logger: Logger): + for k, v in self.activations.items(): + self.activation_norms[k] = self.batch_counter * self.activation_norms[k] / (self.batch_counter + 1) + stats = sum(torch.abs(t).mean().item() for t in v) / len(v) + self.activation_norms[k] += stats / (self.batch_counter + 1) + self.batch_counter += 1 + + def eval_end(self, state: State, logger: Logger): + norms = {} + for k, v in self.activation_norms.items(): + norms[f'activation-statistics/{k}'] = v + logger.log_metrics(norms) + 
self.remove_hooks() + self.activations.clear() + self.batch_counter = 0 diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 6be52f5a..6055940d 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -47,6 +47,42 @@ def _parse_latent_statistics(latent_stat: Union[float, Tuple, str]) -> Union[flo return latent_stat +def make_autoencoder(model_name: Optional[str] = None, + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + precision: torch.dtype = torch.float16) -> Tuple[torch.nn.Module, int, int]: + """Create an autoencoder for latent diffusion. + + Args: + model_name: (optional, str): Name of huggingface model to load. Default: `None`. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + precision: (torch.dtype): Precision to load the autoencoder in. Default: `torch.float16` + + Returns: + autoencoder (torch.nn.Module): The loaded autoencoder module. 
+ autoencoder_channels (int): The number of channels in the autoencoder + downsample_factor (int): The autoencoder downsampling factor + """ + # Make the autoencoder + if autoencoder_path is None: + downsample_factor = 8 + autoencoder_channels = 4 + # Use the pretrained vae + try: + autoencoder = AutoencoderKL.from_pretrained(model_name, subfolder='vae', torch_dtype=precision) + except: # for handling SDXL vae fp16 fixed checkpoint + autoencoder = AutoencoderKL.from_pretrained(model_name, torch_dtype=precision) + else: + # Use a custom autoencoder + autoencoder, _ = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + downsample_factor = 2**(len(autoencoder.config['channel_multipliers']) - 1) + autoencoder_channels = autoencoder.config['latent_channels'] + assert isinstance(autoencoder, torch.nn.Module) + return autoencoder, autoencoder_channels, downsample_factor + + def stable_diffusion_2( model_name: str = 'stabilityai/stable-diffusion-2-base', pretrained: bool = True, @@ -885,18 +921,12 @@ def text_to_image_transformer( vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', - num_mmdit_layers: int = 28, - num_dit_layers: int = 0, - mmdit_block_group_size: int = 1, - dit_block_group_size: int = 1, - attention_implementation: Optional[str] = None, + transformer_config: Optional[dict] = None, + width_scale: float = 1.0, max_image_side: int = 1280, - conditioning_features: int = 768, - conditioning_max_sequence_length: int = 77, - num_register_tokens: int = 0, patch_size: int = 2, - latent_mean: Union[float, Tuple, str] = 0.0, - latent_std: Union[float, Tuple, str] = 7.67754318618, + latent_mean: Union[float, Tuple] = 0.0, + latent_std: Union[float, Tuple] = 7.67754318618, timestep_mean: float = 0.0, timestep_std: float = 1.0, timestep_shift: float = 1.0, @@ -915,96 +945,75 @@ def text_to_image_transformer( autoencoder_path (optional, 
str): Path to autoencoder weights if using custom autoencoder. If not specified, will use the vae from `model_name`. Default `None`. autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. - num_mmdit_layers (int): Number of mmdit_layers in the transformer. Number of heads and layer width are determined by - this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. - num_dit_layers (int): Number of mmdit_layers in the transformer. Number of heads and layer width are determined by - this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `0`. - mmdit_block_group_size (int): Size of MMDiT block groups. Must be a divisor of num_mmdit_layers. Default: `1`. - dit_block_group_size (int): Size of DiT block groups. Must be a divisor of num_mmdit_layers. Default: `1`. - attention_implementation (optional, str): Attention implementation. One of ('flash', 'mem_efficient', 'math'). - If not specified, will let SDPA decide. Default: 'None'. + transformer_config (optional, dict): Config for the transformer. If not specified, will default to a similar + config to SD3-medium. Default: `None`. + width_scale (float): Scaling factor for scaling with width in mu-parameterization. Ex: when scaling from `width=32` + to `width=256`, one should set `width_scale=256/32`. Default: `1.0` max_image_side (int): Maximum side length of the image. Default: `1280`. conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. num_register_tokens (int): Number of additional register tokens to use. Default: `0`. + attention_implementation (optional, str): Attention implementation. One of ('flash', 'mem_efficient', 'math'). + If not specified, will let SDPA decide. Default: 'None'. patch_size (int): Patch size for the transformer. 
Default: `2`. - latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, - a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `0.0`. - latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, - a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `1/0.13025`. + latent_mean (float, Tuple): The mean of the autoencoder latents. Either a float for a single value, + or a tuple of means. Defaults to `0.0`. + latent_std (float, Tuple): The std. dev. of the autoencoder latents. Either a float for a single value, + or a tuple of std_devs. Defaults to `1/0.13025`. timestep_mean (float): The mean of the timesteps. Default: `0.0`. timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. timestep_shift (float): The shift of the timesteps. Default: `1.0`. image_key (str): The key for the image in the batch. Default: `image`. - caption_key (str): The key for the captions in the batch. Default: `captions`. + caption_key (str): The key for the captions in the dataset. Default: `captions`. pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. 
""" - latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + precision = torch.float16 + # Make the autoencoder + autoencoder, autoencoder_channels, downsample_factor = make_autoencoder( + model_name=vae_model_name, + autoencoder_path=autoencoder_path, + autoencoder_local_path=autoencoder_local_path, + precision=precision) + if isinstance(latent_mean, float): + latent_mean = tuple([latent_mean] * autoencoder_channels) + if isinstance(latent_std, float): + latent_std = tuple([latent_std] * autoencoder_channels) + + # Figure out the maximum input sequence length + input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) if (isinstance(tokenizer_names, tuple) or isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names): raise ValueError('Number of tokenizer_names and text_encoder_names must be equal') - # Make the tokenizer and text encoder tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names) text_encoder = MultiTextEncoder(model_names=text_encoder_names, encode_latents_in_fp16=True, pretrained_sdxl=False) - precision = torch.float16 - # Make the autoencoder - if autoencoder_path is None: - if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': - raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') - downsample_factor = 8 - autoencoder_channels = 4 - # Use the pretrained vae - try: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) - except: # for handling SDXL vae fp16 fixed checkpoint - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) - else: - # Use a custom autoencoder - vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) - if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): - raise ValueError( - 'Must 
specify latent scale when using a custom autoencoder without tracking latent statistics.') - if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_mean = tuple(latent_statistics['latent_channel_means']) - if isinstance(latent_std, str) and latent_std == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_std = tuple(latent_statistics['latent_channel_stds']) - downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) - autoencoder_channels = vae.config['latent_channels'] - assert isinstance(vae, torch.nn.Module) - if isinstance(latent_mean, float): - latent_mean = (latent_mean,) * autoencoder_channels - if isinstance(latent_std, float): - latent_std = (latent_std,) * autoencoder_channels - assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) - # Figure out the maximum input sequence length - input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) # Make the transformer model - num_layers = num_mmdit_layers + num_dit_layers - transformer = DiffusionTransformer(num_features=64 * num_layers, - num_heads=num_layers, - num_mmdit_layers=num_mmdit_layers, - num_dit_layers=num_dit_layers, - attention_implementation=attention_implementation, - input_features=autoencoder_channels * (patch_size**2), - input_max_sequence_length=input_max_sequence_length, - input_dimension=2, - conditioning_features=conditioning_features, - conditioning_max_sequence_length=conditioning_max_sequence_length, - conditioning_dimension=1, - expansion_factor=4, - num_register_tokens=num_register_tokens, - mmdit_block_group_size=mmdit_block_group_size, - dit_block_group_size=dit_block_group_size) + transformer_config_dict = { + 'num_features': 64 * 24, + 'num_heads': 24, + 'num_mmdit_layers': 24, + 'num_dit_layers': 0, + 'attention_implementation': None, + 'input_features': autoencoder_channels * (patch_size**2), + 'input_max_sequence_length': 
input_max_sequence_length, + 'input_dimension': 2, + 'conditioning_features': 64 * 24, + 'conditioning_max_sequence_length': 77 + 512, + 'conditioning_dimension': 1, + 'expansion_factor': 4, + 'num_register_tokens': 0, + 'mmdit_block_group_size': 1, + 'dit_block_group_size': 1 + } + if transformer_config is not None: + transformer_config_dict.update(transformer_config) + transformer = DiffusionTransformer(**transformer_config_dict) + # Make the composer model model = ComposerTextToImageMMDiT(model=transformer, - autoencoder=vae, + autoencoder=autoencoder, text_encoder=text_encoder, tokenizer=tokenizer, latent_mean=latent_mean, @@ -1030,18 +1039,12 @@ def precomputed_text_latents_to_image_transformer( include_text_encoders: bool = False, text_encoder_dtype: str = 'bfloat16', cache_dir: str = '/tmp/hf_files', - num_mmdit_layers: int = 28, - num_dit_layers: int = 0, - mmdit_block_group_size: int = 1, - dit_block_group_size: int = 1, + transformer_config: Optional[dict] = None, + width_scale: float = 1.0, max_image_side: int = 1280, - conditioning_features: int = 768, - conditioning_max_sequence_length: int = 512 + 77, - num_register_tokens: int = 0, - attention_implementation: Optional[str] = None, patch_size: int = 2, - latent_mean: Union[float, Tuple, str] = 0.0, - latent_std: Union[float, Tuple, str] = 7.67754318618, + latent_mean: Union[float, Tuple] = 0.0, + latent_std: Union[float, Tuple] = 7.67754318618, timestep_mean: float = 0.0, timestep_std: float = 1.0, timestep_shift: float = 1.0, @@ -1063,12 +1066,10 @@ def precomputed_text_latents_to_image_transformer( Default: `bfloat16`. cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`. autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. - num_mmdit_layers (int): Number of mmdit_layers in the transformer. 
Number of heads and layer width are determined by - this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. - num_dit_layers (int): Number of mmdit_layers in the transformer. Number of heads and layer width are determined by - this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `0`. - mmdit_block_group_size (int): Size of MMDiT block groups. Must be a divisor of num_mmdit_layers. Default: `1`. - dit_block_group_size (int): Size of DiT block groups. Must be a divisor of num_mmdit_layers. Default: `1`. + transformer_config (optional, dict): Config for the transformer. If not specified, will default to a similar + config to SD3-medium. Default: `None`. + width_scale (float): Scaling factor for scaling with width in mu-parameterization. Ex: when scaling from `width=32` + to `width=256`, one should set `width_scale=256/32`. Default: `1.0` max_image_side (int): Maximum side length of the image. Default: `1280`. conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. @@ -1076,12 +1077,10 @@ def precomputed_text_latents_to_image_transformer( attention_implementation (optional, str): Attention implementation. One of ('flash', 'mem_efficient', 'math'). If not specified, will let SDPA decide. Default: 'None'. patch_size (int): Patch size for the transformer. Default: `2`. - latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, - a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `0.0`. - latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, - a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `1/0.13025`. 
+ latent_mean (float, Tuple): The mean of the autoencoder latents. Either a float for a single value, + or a tuple of means. Defaults to `0.0`. + latent_std (float, Tuple): The std. dev. of the autoencoder latents. Either a float for a single value, + or a tuple of std_devs. Defaults to `1/0.13025`. timestep_mean (float): The mean of the timesteps. Default: `0.0`. timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. timestep_shift (float): The shift of the timesteps. Default: `1.0`. @@ -1093,59 +1092,42 @@ def precomputed_text_latents_to_image_transformer( clip_pooled_key (str): The key to use for the CLIP pooled in the precomputed latents. Default: `'CLIP_POOLED'`. pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. """ - latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) - precision = torch.float16 # Make the autoencoder - if autoencoder_path is None: - if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': - raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') - downsample_factor = 8 - autoencoder_channels = 4 - # Use the pretrained vae - try: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) - except: # for handling SDXL vae fp16 fixed checkpoint - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) - else: - # Use a custom autoencoder - vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) - if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): - raise ValueError( - 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') - if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_mean = 
tuple(latent_statistics['latent_channel_means']) - if isinstance(latent_std, str) and latent_std == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_std = tuple(latent_statistics['latent_channel_stds']) - downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) - autoencoder_channels = vae.config['latent_channels'] - assert isinstance(vae, torch.nn.Module) + autoencoder, autoencoder_channels, downsample_factor = make_autoencoder( + model_name=vae_model_name, + autoencoder_path=autoencoder_path, + autoencoder_local_path=autoencoder_local_path, + precision=precision) if isinstance(latent_mean, float): - latent_mean = (latent_mean,) * autoencoder_channels + latent_mean = tuple([latent_mean] * autoencoder_channels) if isinstance(latent_std, float): - latent_std = (latent_std,) * autoencoder_channels - assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + latent_std = tuple([latent_std] * autoencoder_channels) + # Figure out the maximum input sequence length input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) + # Make the transformer model - num_layers = num_mmdit_layers + num_dit_layers - transformer = DiffusionTransformer(num_features=64 * num_layers, - num_heads=num_layers, - num_mmdit_layers=num_mmdit_layers, - num_dit_layers=num_dit_layers, - attention_implementation=attention_implementation, - input_features=autoencoder_channels * (patch_size**2), - input_max_sequence_length=input_max_sequence_length, - input_dimension=2, - conditioning_features=64 * num_layers, - conditioning_max_sequence_length=conditioning_max_sequence_length, - conditioning_dimension=1, - expansion_factor=4, - num_register_tokens=num_register_tokens, - mmdit_block_group_size=mmdit_block_group_size, - dit_block_group_size=dit_block_group_size) + transformer_config_dict = { + 'num_features': 64 * 24, + 'num_heads': 24, + 'num_mmdit_layers': 24, + 'num_dit_layers': 0, + 'attention_implementation': None, 
+ 'input_features': autoencoder_channels * (patch_size**2), + 'input_max_sequence_length': input_max_sequence_length, + 'input_dimension': 2, + 'conditioning_features': 64 * 24, + 'conditioning_max_sequence_length': 77 + 512, + 'conditioning_dimension': 1, + 'expansion_factor': 4, + 'num_register_tokens': 0, + 'mmdit_block_group_size': 1, + 'dit_block_group_size': 1 + } + if transformer_config is not None: + transformer_config_dict.update(transformer_config) + transformer = DiffusionTransformer(**transformer_config_dict) # Optionally load the tokenizers and text encoders t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None @@ -1169,7 +1151,7 @@ def precomputed_text_latents_to_image_transformer( # Make the composer model model = ComposerPrecomputedTextLatentsToImageMMDiT(model=transformer, - autoencoder=vae, + autoencoder=autoencoder, t5_tokenizer=t5_tokenizer, t5_encoder=t5_encoder, clip_tokenizer=clip_tokenizer, @@ -1185,7 +1167,8 @@ def precomputed_text_latents_to_image_transformer( image_key=image_key, t5_latent_key=t5_latent_key, clip_latent_key=clip_latent_key, - clip_pooled_key=clip_pooled_key) + clip_pooled_key=clip_pooled_key, + width_scale=width_scale) if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py index 309be415..40e2790b 100644 --- a/diffusion/models/t2i_transformer.py +++ b/diffusion/models/t2i_transformer.py @@ -13,7 +13,7 @@ from tqdm.auto import tqdm from transformers import PreTrainedTokenizer -from diffusion.models.transformer import DiffusionTransformer, VectorEmbedding +from diffusion.models.transformer import DiffusionTransformer, FP32LayerNorm, MuInputLinear, VectorEmbedding def _duplicate_tensor(tensor, num_images_per_prompt): @@ -219,7 +219,7 @@ def flops_per_batch(self, batch) -> int: def encode_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Encode an image tensor with the 
autoencoder and patchify the latents.""" with torch.amp.autocast('cuda', enabled=False): - latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + latents = self.autoencoder.encode(image.half())['latent_dist'].mean.data # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std latent_patches, latent_coords = patchify(latents, self.patch_size) @@ -491,6 +491,8 @@ class ComposerPrecomputedTextLatentsToImageMMDiT(ComposerModel): clip_latent_key (str): The key in the batch dict that contains the CLIP latents. Default: `'CLIP_LATENTS'`. clip_pooled_key (str): The key in the batch dict that contains the CLIP pooled embeddings. Default: `'CLIP_POOLED'`. pooled_embedding_features (int): The number of features in the pooled text embeddings. Default: `768`. + width_scale (float): Scaling factor for scaling with width in mu-parameterization. Ex: when scaling from `width=32` + to `width=256`, one should set `width_scale=256/32`. Default: `1.0` """ def __init__( @@ -515,6 +517,7 @@ def __init__( clip_latent_key: str = 'CLIP_LATENTS', clip_pooled_key: str = 'CLIP_POOLED', pooled_embedding_features: int = 768, + width_scale: float = 1.0, ): super().__init__() self.model = model @@ -541,12 +544,13 @@ def __init__( self.clip_latent_key = clip_latent_key self.clip_pooled_key = clip_pooled_key self.pooled_embedding_features = pooled_embedding_features + self.width_scale = width_scale # Embedding MLPs and norms for the pooled text embeddings - self.t5_proj = torch.nn.Linear(4096, model.num_features) - self.t5_ln = torch.nn.LayerNorm(model.num_features) - self.clip_proj = torch.nn.Linear(768, model.num_features) - self.clip_ln = torch.nn.LayerNorm(model.num_features) + self.t5_ln = FP32LayerNorm(4096) + self.t5_proj_linear = MuInputLinear(4096, model.num_features) + self.clip_ln = FP32LayerNorm(768) + self.clip_proj_linear = MuInputLinear(768, model.num_features) self.pooled_embedding_mlp = 
VectorEmbedding(pooled_embedding_features, model.num_features) # freeze text_encoder during diffusion training and use half precision self.autoencoder.requires_grad_(False) @@ -611,7 +615,7 @@ def flops_per_batch(self, batch) -> int: def encode_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Encode an image tensor with the autoencoder and patchify the latents.""" with torch.amp.autocast('cuda', enabled=False): - latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + latents = self.autoencoder.encode(image.half())['latent_dist'].mean.data # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std latent_patches, latent_coords = patchify(latents, self.patch_size) @@ -665,11 +669,12 @@ def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tens t5_embed = t5_embed[:, :self.max_seq_len] if clip_embed.shape[1] > self.max_seq_len: clip_embed = clip_embed[:, :self.max_seq_len] - t5_embed = self.t5_proj(t5_embed) - clip_embed = self.clip_proj(clip_embed) # Apply layernorms t5_embed = self.t5_ln(t5_embed) clip_embed = self.clip_ln(clip_embed) + # Embed to shared dimensionality + t5_embed = self.t5_proj_linear(t5_embed) + clip_embed = self.clip_proj_linear(clip_embed) # Concatenate the text embeddings text_embeds = torch.cat([t5_embed, clip_embed], dim=1) return text_embeds diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 87736de9..79c2fa93 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -46,6 +46,113 @@ def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor, return embeddings +class MuInputLinear(nn.Module): + """Linear input layer with the mu parameterization. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether or not to use a bias. Default: `True`. 
+ """ + + def __init__(self, in_features, out_features, bias=True): + super(MuInputLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.mu_input_linear = nn.Linear(in_features, out_features, bias) + self.mu_init() + + def mu_init(self): + """Initializes a linear layer according to mu-parameterization.""" + scale = 1 / math.sqrt(self.in_features) + if self.mu_input_linear.bias is not None: + nn.init.zeros_(self.mu_input_linear.bias) + nn.init.normal_(self.mu_input_linear.weight, std=scale) + + def forward(self, x): + return self.mu_input_linear(x) + + +class MuLinear(nn.Module): + """Linear layer with the mu parameterization. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether or not to use a bias. Default: `True`. + """ + + def __init__(self, in_features, out_features, bias=True): + super(MuLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.mu_linear = nn.Linear(in_features, out_features, bias) + self.mu_init() + + def mu_init(self): + """Initializes a linear layer according to mu-parameterization.""" + scale = 1 / math.sqrt(self.in_features) + if self.mu_linear.bias is not None: + nn.init.zeros_(self.mu_linear.bias) + nn.init.normal_(self.mu_linear.weight, std=scale) + + def forward(self, x): + return self.mu_linear(x) + + +class MuOutputLinear(nn.Module): + """Linear output layer with the mu parameterization. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether or not to use a bias. Default: `True`. 
+ """ + + def __init__(self, in_features, out_features, bias=True): + super(MuOutputLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.mu_output_linear = nn.Linear(in_features, out_features, bias) + self.mu_init() + + def mu_init(self): + """Initializes a linear layer according to mu-parameterization.""" + scale = 1 / self.in_features + if self.mu_output_linear.bias is not None: + nn.init.zeros_(self.mu_output_linear.bias) + nn.init.normal_(self.mu_output_linear.weight, std=scale) + + def rescale_init(self, scale): + rescale = math.sqrt(1 / (self.in_features * scale)) + nn.init.normal_(self.mu_output_linear.weight, std=rescale) + + def forward(self, x): + return self.mu_output_linear(x) + + +class FP32LayerNorm(nn.Module): + """LayerNorm in FP32. + + Args: + normalized_shape (int): input shape from an expected input of size (..., normalized_shape) + eps (float): a value added to the denominator for numerical stability. Default: `1e-5` + elementwise_affine (bool): a boolean value that when set to True, this module has learnable + per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: `True`. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-5, elementwise_affine: bool = True): + super().__init__() + self.layer_norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + def forward(self, x): + original_dtype = x.dtype + x = x.to(dtype=torch.float32) + x = self.layer_norm(x) + return x.to(dtype=original_dtype) + + class AdaptiveLayerNorm(nn.Module): """Adaptive LayerNorm. @@ -60,16 +167,16 @@ def __init__(self, num_features: int): self.num_features = num_features # MLP for computing modulations. # Initialized to zero so modulation acts as identity at initialization. 
- self.adaLN_mlp_linear_shift = nn.Linear(self.num_features, self.num_features, bias=True) - self.adaLN_mlp_linear_scale = nn.Linear(self.num_features, self.num_features, bias=True) - nn.init.zeros_(self.adaLN_mlp_linear_shift.weight) - nn.init.zeros_(self.adaLN_mlp_linear_scale.weight) - nn.init.zeros_(self.adaLN_mlp_linear_shift.bias) - nn.init.zeros_(self.adaLN_mlp_linear_scale.bias) + self.adaLN_mlp_linear_shift = MuLinear(self.num_features, self.num_features, bias=True) + self.adaLN_mlp_linear_scale = MuLinear(self.num_features, self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear_shift.mu_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear_scale.mu_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear_shift.mu_linear.bias) + nn.init.zeros_(self.adaLN_mlp_linear_scale.mu_linear.bias) self.adaLN_mlp_shift = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear_shift) self.adaLN_mlp_scale = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear_scale) # LayerNorm - self.layernorm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.layernorm = FP32LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) @torch.compile() def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: @@ -94,9 +201,9 @@ def __init__(self, num_features: int): self.num_features = num_features # MLP for computing modulation. # Initialized to zero so modulation starts off at zero. - self.adaLN_mlp_linear = nn.Linear(self.num_features, self.num_features, bias=True) - nn.init.zeros_(self.adaLN_mlp_linear.weight) - nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp_linear = MuLinear(self.num_features, self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.mu_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.mu_linear.bias) self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) @torch.compile() @@ -114,19 +221,19 @@ class ScalarEmbedding(nn.Module): Args: num_features (int): The size of the output vector. 
sinusoidal_embedding_dim (int): The size of the intermediate sinusoidal embedding. Default: `256`. - max_period (int): The maximum period of the sinusoidal embedding. Default: `10000`. + max_period (float): The maximum period of the sinusoidal embedding. Default: `10000.0`. Returns: torch.Tensor: The embedded scalar """ - def __init__(self, num_features: int, sinusoidal_embedding_dim: int = 256, max_period: int = 10000): + def __init__(self, num_features: int, sinusoidal_embedding_dim: int = 256, max_period: float = 10000.0): super().__init__() self.num_features = num_features self.sinusoidal_embedding_dim = sinusoidal_embedding_dim self.max_period = max_period - self.linear_1 = nn.Linear(self.sinusoidal_embedding_dim, self.num_features) - self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.linear_1 = MuInputLinear(self.sinusoidal_embedding_dim, self.num_features) + self.linear_2 = MuLinear(self.num_features, self.num_features) self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) # Make the freqs half_dim = self.sinusoidal_embedding_dim // 2 @@ -170,8 +277,8 @@ def __init__(self, input_features: int, num_features: int): super().__init__() self.input_features = input_features self.num_features = num_features - self.linear_1 = nn.Linear(self.input_features, self.num_features) - self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.linear_1 = MuInputLinear(self.input_features, self.num_features) + self.linear_2 = MuLinear(self.num_features, self.num_features) self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) @torch.compile() @@ -194,17 +301,12 @@ def __init__(self, num_features: int): # Adaptive layernorm self.adaptive_layernorm = AdaptiveLayerNorm(self.num_features) # Linear layer to get q, k, and v - self.q_proj = nn.Linear(self.num_features, self.num_features) - self.k_proj = nn.Linear(self.num_features, self.num_features) - self.v_proj = nn.Linear(self.num_features, self.num_features) + self.q_proj = 
MuLinear(self.num_features, self.num_features, bias=False) + self.k_proj = MuLinear(self.num_features, self.num_features, bias=False) + self.v_proj = MuLinear(self.num_features, self.num_features, bias=False) # QK layernorms. Original MMDiT used RMSNorm here. - self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) - self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) - for l in [self.q_proj, self.k_proj, self.v_proj]: - # Initialize all biases to zero - nn.init.zeros_(l.bias) - # Init the standard deviation of the weights to 0.02 as is tradition - nn.init.normal_(l.weight, std=0.02) + self.q_norm = FP32LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.k_norm = FP32LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) @torch.compile() def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -229,6 +331,7 @@ def __init__(self, num_features: int, num_heads: int, attention_implementation: self.num_features = num_features self.num_heads = num_heads self.head_dim = num_features // num_heads + self.attn_scale = 1 / self.head_dim assert self.num_features % self.num_heads == 0, 'num_features must be divisible by num_heads' if attention_implementation is not None: assert attention_implementation in ('flash', 'mem_efficient', 'math'), ( @@ -265,10 +368,19 @@ def forward(self, # Attention with selectable implementation if self.attention_implementation is None: - attention_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + attention_out = F.scaled_dot_product_attention(q.float(), + k.float(), + v.float(), + attn_mask=mask, + scale=self.attn_scale) else: with sdpa_kernel(self.sdp_backends): - attention_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + attention_out = F.scaled_dot_product_attention(q.float(), + k.float(), + v.float(), + attn_mask=mask, + scale=self.attn_scale) + attention_out = 
attention_out.to(dtype=v.dtype) # Reshape back to (B, T, C) attention_out = attention_out.transpose(1, 2).reshape(B, T, C) @@ -292,16 +404,13 @@ def __init__(self, num_features: int, expansion_factor: int = 4): # Input modulation self.modulate_v = ModulationLayer(self.num_features) # Linear layer to process v - self.linear_v = nn.Linear(self.num_features, self.num_features) + self.linear_v = MuLinear(self.num_features, self.num_features, bias=False) # Layernorm for the output self.output_norm = AdaptiveLayerNorm(self.num_features) # Transformer style MLP layers - self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features) + self.linear_1 = MuLinear(self.num_features, self.expansion_factor * self.num_features) self.nonlinearity = nn.GELU(approximate='tanh') - self.linear_2 = nn.Linear(self.expansion_factor * self.num_features, self.num_features) - # Initialize all biases to zero - nn.init.zeros_(self.linear_1.bias) - nn.init.zeros_(self.linear_2.bias) + self.linear_2 = MuLinear(self.expansion_factor * self.num_features, self.num_features) # Output MLP self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2) # Output modulation @@ -567,15 +676,15 @@ def __init__(self, self.conditioning_max_sequence_length = conditioning_max_sequence_length self.num_register_tokens = num_register_tokens # Embedding block for the timestep - self.timestep_embedding = ScalarEmbedding(self.num_features) + self.timestep_embedding = ScalarEmbedding(self.num_features, max_period=1 / (2 * math.pi)) # Projection layer for the input sequence - self.input_embedding = nn.Linear(self.input_features, self.num_features) + self.input_embedding = MuLinear(self.input_features, self.num_features) # Embedding layer for the input sequence input_position_embedding = torch.randn(self.input_dimension, self.input_max_sequence_length, self.num_features) input_position_embedding /= math.sqrt(self.num_features) self.input_position_embedding = 
torch.nn.Parameter(input_position_embedding, requires_grad=True) # Projection layer for the conditioning sequence - self.conditioning_embedding = nn.Linear(self.conditioning_features, self.num_features) + self.conditioning_embedding = MuLinear(self.conditioning_features, self.num_features) # Embedding layer for the conditioning sequence conditioning_position_embedding = torch.randn(self.conditioning_dimension, self.conditioning_max_sequence_length, self.num_features) @@ -618,10 +727,7 @@ def __init__(self, attention_implementation=self.attention_implementation)) # Output projection layer self.final_norm = AdaptiveLayerNorm(self.num_features) - self.final_linear = nn.Linear(self.num_features, self.input_features) - # Init the output layer to zero - nn.init.zeros_(self.final_linear.weight) - nn.init.zeros_(self.final_linear.bias) + self.final_linear = MuOutputLinear(self.num_features, self.input_features) def fsdp_wrap_fn(self, module: nn.Module) -> bool: if isinstance(module, (MMDiTGroup, DiTGroup)): diff --git a/diffusion/train.py b/diffusion/train.py index becff0f1..8026d032 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -20,7 +20,8 @@ from torch.optim import Optimizer from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder -from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT +from diffusion.models.t2i_transformer import ComposerPrecomputedTextLatentsToImageMMDiT, ComposerTextToImageMMDiT +from diffusion.models.transformer import MuOutputLinear def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: @@ -54,14 +55,29 @@ def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Opti def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: """Configures the optimizer for use with a transformer model.""" print('Configuring optimizer for transformer') - assert isinstance(model, ComposerTextToImageMMDiT) + assert 
isinstance(model, (ComposerTextToImageMMDiT, ComposerPrecomputedTextLatentsToImageMMDiT)) + # Grab the width scaling factor from the model if it's been given + if hasattr(model, 'width_scale'): + width_scale = model.width_scale + else: + width_scale = 1.0 # Turn off weight decay for biases, norms, and positional embeddings. + # Also set up learning rates for mu-parameterization no_decay = ['bias', 'norm', 'position_embedding'] params_with_no_decay = [] params_with_decay = [] + mu_input_params = [] + mu_hidden_params = [] + mu_output_params = [] for name, param in model.named_parameters(): - if any(nd in name for nd in no_decay): + if 'mu_input_linear.weight' in name: + mu_input_params.append(param) + elif 'mu_hidden_linear.weight' in name: + mu_hidden_params.append(param) + elif 'mu_output_linear.weight' in name: + mu_output_params.append(param) + elif any(nd in name for nd in no_decay): params_with_no_decay.append(param) else: params_with_decay.append(param) @@ -72,7 +88,24 @@ def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Opti decay_dict = dict(config.optimizer.items()) decay_dict['params'] = params_with_decay - optimizer = hydra.utils.instantiate(config.optimizer, [no_decay_dict, decay_dict]) + mu_input_dict = dict(config.optimizer.items()) + mu_input_dict['params'] = mu_input_params + + mu_hidden_dict = dict(config.optimizer.items()) + mu_hidden_dict['params'] = mu_hidden_params + mu_hidden_dict['lr'] *= 1 / width_scale + + mu_output_dict = dict(config.optimizer.items()) + mu_output_dict['params'] = mu_output_params + mu_output_dict['lr'] *= 1 / width_scale + + # Rescaling of output inits + for module in model.modules(): + if isinstance(module, MuOutputLinear): + module.rescale_init(width_scale) + + optimizer = hydra.utils.instantiate(config.optimizer, + [no_decay_dict, decay_dict, mu_input_dict, mu_hidden_dict, mu_output_dict]) return optimizer @@ -97,7 +130,7 @@ def train(config: DictConfig) -> None: if hasattr(model, 
'autoencoder_loss'): # Check if this is training an autoencoder. If so, the optimizer needs different param groups optimizer = make_autoencoder_optimizer(config, model) - elif isinstance(model, ComposerTextToImageMMDiT): + elif isinstance(model, (ComposerTextToImageMMDiT, ComposerPrecomputedTextLatentsToImageMMDiT)): # Check if this is training a transformer. If so, the optimizer needs different param groups optimizer = make_transformer_optimizer(config, model) else: