diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..4a15ba9 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +*/*/build* +*/*/__pycache__ +*/*/*.pyc + +.pytest_cache +.git +comm_traces +*outputs* diff --git a/3rdparty/torchtitan b/3rdparty/torchtitan index 98128cc..da6d9e5 160000 --- a/3rdparty/torchtitan +++ b/3rdparty/torchtitan @@ -1 +1 @@ -Subproject commit 98128cc298a899e65e2c2b3296b29f4a351a824f +Subproject commit da6d9e56abd20506475e7d6652a2a574e9abc286 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e17b7e4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +FROM rocm/pytorch-nightly:20260429082157-rocm7.2.2 + +RUN apt-get update && apt-get install -y \ + git-lfs \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +RUN update-pciids + +RUN pip install --no-cache-dir huggingface_hub "datasets>=3.6.0" \ + transformers tabulate wandb fsspec tyro "tokenizers>=0.15.0" safetensors \ + tensorboard pre-commit yapf pybind11 meson-python torchdata pytablewriter \ + "antlr4-python3-runtime==4.11.0" sympy math_verify more_itertools peft \ + accelerate pillow "numpy<2" opencv-python-headless scipy \ + numba huggingface-hub[cli,hf_transfer] "packaging>=24.2" \ + "setuptools>=77.0.3,<80.0.0" "setuptools-scm>=8" \ + protobuf-protoc-bin fmt && \ + pip install --no-cache-dir /opt/rocm/share/amd_smi + +RUN cd /var/lib/jenkins && \ + git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && \ + cd lm-evaluation-harness && \ + pip install -e . + +COPY . /var/lib/jenkins/alto + +RUN cd /var/lib/jenkins/alto/3rdparty/torchtitan && \ + pip install --no-build-isolation -e . && \ + cd /var/lib/jenkins/alto && \ + pip install -e . diff --git a/LICENSE b/LICENSE index ef6cbf6..d7d76f9 100644 --- a/LICENSE +++ b/LICENSE @@ -29,19 +29,15 @@ ALTO uses or references the following third-party projects: - License: BSD License - Repository: https://github.com/pytorch/torchtitan -2. Megatron-LM - - License: Apache License 2.0 - - Repository: https://github.com/nvidia/megatron-lm - -3. vLLM +2. vLLM - License: Apache License 2.0 - Repository: https://github.com/vllm-project/vllm -4. compressed-tensors +3. compressed-tensors - License: Apache License 2.0 - Repository: https://github.com/vllm-project/compressed-tensors -5. llm-compressor +4. llm-compressor - License: Apache License 2.0 - Repository: https://github.com/vllm-project/llm-compressor diff --git a/alto/kernels/dispatch/__init__.py b/alto/kernels/dispatch/__init__.py index 2e03bab..fdb6bea 100644 --- a/alto/kernels/dispatch/__init__.py +++ b/alto/kernels/dispatch/__init__.py @@ -4,10 +4,10 @@ from .config import TrainingOpConfig from .conversion import swap_params -from .attention import LPScaledDotProductAttentionWrapper +from .attention import LPScaledDotProductAttention __all__ = [ "TrainingOpConfig", "swap_params", - "LPScaledDotProductAttentionWrapper", + "LPScaledDotProductAttention", ] diff --git a/alto/kernels/dispatch/attention.py b/alto/kernels/dispatch/attention.py index 3372762..3274782 100644 --- a/alto/kernels/dispatch/attention.py +++ b/alto/kernels/dispatch/attention.py @@ -3,15 +3,15 @@ # SPDX-License-Identifier: MIT import torch -from torchtitan.models.common.attention import (ScaledDotProductAttentionWrapper) +from torchtitan.models.common.attention import ScaledDotProductAttention from alto.kernels.fp4.mxfp4.triton_flash_attention_mxfp4 import triton_attention_mxfp4 from .config import TrainingOpConfig -__all__ = ["LPScaledDotProductAttentionWrapper"] +__all__ = ["LPScaledDotProductAttention"] -class LPScaledDotProductAttentionWrapper(ScaledDotProductAttentionWrapper): +class LPScaledDotProductAttention(ScaledDotProductAttention): def __init__(self, config: TrainingOpConfig): super().__init__() diff --git a/alto/models/deepseek_v3/config_registry.py b/alto/models/deepseek_v3/config_registry.py index dc0f0ec..7735974 100644 --- a/alto/models/deepseek_v3/config_registry.py +++ b/alto/models/deepseek_v3/config_registry.py @@ -21,7 +21,7 @@ def deepseek_v3_debugmodel() -> Trainer.Config: config = deepseek_v3_debugmodel_orig() - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 10 config.training.local_batch_size = 4 config.training.global_batch_size = 16 @@ -43,7 +43,7 @@ def deepseek_v3_16b() -> Trainer.Config: config = deepseek_v3_16b_orig() config.hf_assets_path = "/huggingface/hub/models--deepseek-ai--deepseek-moe-16b-base/snapshots/521d2bc4fb69a3f3ae565310fcc3b65f97af2580" config.dump_folder = "deepseek_v3_16b-outputs" - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.seq_len = 4096 @@ -62,7 +62,6 @@ def deepseek_v3_16b() -> Trainer.Config: config.validator.freq = 10 config.validator.steps = 10 config.activation_checkpoint.mode = "none" - config.activation_checkpoint.selective_ac_option = "1" config.debug.seed = 1234 return config diff --git a/alto/models/gpt_oss/config_registry.py b/alto/models/gpt_oss/config_registry.py index ec1e5da..5f85a5e 100644 --- a/alto/models/gpt_oss/config_registry.py +++ b/alto/models/gpt_oss/config_registry.py @@ -22,7 +22,7 @@ def gpt_oss_debugmodel() -> Trainer.Config: config = gpt_oss_debugmodel_orig() - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 10 config.training.local_batch_size = 4 config.training.global_batch_size = 16 @@ -44,7 +44,7 @@ def gpt_oss_20b() -> Trainer.Config: config = gpt_oss_20b_orig() config.hf_assets_path = "/huggingface/hub/models--openai--gpt-oss-20b/snapshots/6cee5e81ee83917806bbde320786a8fb61efebee/" config.dump_folder = "gpt_oss_20b-outputs" - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.seq_len = 8192 @@ -64,7 +64,6 @@ def gpt_oss_20b() -> Trainer.Config: config.validator.freq = 10 config.validator.steps = 10 config.activation_checkpoint.mode = "none" - config.activation_checkpoint.selective_ac_option = "1" config.debug.seed = 1234 return config @@ -73,7 +72,7 @@ def gpt_oss_20b_pretrain() -> Trainer.Config: config = gpt_oss_20b_orig() config.hf_assets_path = "/huggingface/hub/models--openai--gpt-oss-20b/snapshots/6cee5e81ee83917806bbde320786a8fb61efebee/" config.dump_folder = "gpt_oss_20b-pretrain-subset-lr4e-4-outputs" - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 1200000 config.training.local_batch_size = 1 config.training.global_batch_size = 16 @@ -103,7 +102,6 @@ def gpt_oss_20b_pretrain() -> Trainer.Config: config.validator.freq = 768 config.validator.steps = 64 config.activation_checkpoint.mode = "selective" - config.activation_checkpoint.selective_ac_option = "1" config.debug.seed = 1234 return config diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index 17dd4b7..cf4a126 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -51,7 +51,7 @@ def llama3_debugmodel() -> Trainer.Config: config = llama3_debugmodel_orig() - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 4 config.training.global_batch_size = 16 @@ -82,7 +82,7 @@ def llama3_1b() -> Trainer.Config: config = llama3_1b_orig() config.hf_assets_path = "/group/archive_dataset_6_nobkup/archive_modelzoo/sequence_learning/weights/nlp-pretrained-model/meta-llama/Llama-3.2-1B" config.metrics.log_freq = 1 - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.global_batch_size = 10 @@ -123,7 +123,7 @@ def llama3_8b() -> Trainer.Config: config = llama3_8b_orig() config.hf_assets_path = "/huggingface/hub/models--unsloth--Llama-3.1-8B/snapshots/3f0d51f8e5640f98f1a96ea9044a0e55c0a83814" config.metrics.log_freq = 1 - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.seq_len = 8192 @@ -188,7 +188,7 @@ def llama3_8b() -> Trainer.Config: config = llama3_8b_orig() config.hf_assets_path = LLAMA3_8B_PATH config.metrics.log_freq = 1 - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.global_batch_size = 8 @@ -352,7 +352,7 @@ def instella_3b() -> Trainer.Config: config = instella_3b_orig() config.hf_assets_path = "/group/ossmodelzoo/hanwang2/huggingface/hub/models--amd--Instella-3B-Stage1/snapshots/cb33253ab0a5b9f2ea0b98f3edd818d46454580e" config.metrics.log_freq = 1 - config.profiling.enable_profiling = False + config.profiler.enable_profiling = False config.training.steps = 0 config.training.local_batch_size = 1 config.training.global_batch_size = 10 diff --git a/alto/models/llama3/configs/recipe.yaml b/alto/models/llama3/configs/recipe.yaml index bc4da06..43d796e 100644 --- a/alto/models/llama3/configs/recipe.yaml +++ b/alto/models/llama3/configs/recipe.yaml @@ -1,10 +1,10 @@ -# sparsity_stage: -# sparsity_modifiers: -# WandaModifier: -# sparsity: 0.5 -# mask_structure: "2:4" -# targets: ["Linear"] -# ignore: ["output"] +sparsity_stage: + sparsity_modifiers: + WandaModifier: + sparsity: 0.5 + mask_structure: "2:4" + targets: ["Linear"] + ignore: ["output"] quantization_stage: quantization_modifiers: QuantizationModifier: @@ -23,13 +23,13 @@ quantization_stage: strategy: "tensor" observer: "minmax" targets: ["Linear"] - SelfDistillationModifier: - criterion: "LogitsDistillationLoss" - targets: ["__all__"] - steps: 10 - warmup_steps: 0 - lr: 1e-4 - min_lr_factor: 1.0 - decay_ratio: null - decay_type: "linear" - optimizer: "AdamW" + # SelfDistillationModifier: + # criterion: "LogitsDistillationLoss" + # targets: ["__all__"] + # steps: 10 + # warmup_steps: 0 + # lr: 1e-4 + # min_lr_factor: 1.0 + # decay_ratio: null + # decay_type: "linear" + # optimizer: "AdamW" diff --git a/alto/modifiers/distillation/utils/losses.py b/alto/modifiers/distillation/utils/losses.py index 7039d91..79c2257 100644 --- a/alto/modifiers/distillation/utils/losses.py +++ b/alto/modifiers/distillation/utils/losses.py @@ -1,263 +1,56 @@ -# copied from https://github.com/NVIDIA/Model-Optimizer/blob/452c5a09b03f8db34a39151aae4f7a4b01efd9cd/modelopt/torch/distill/losses.py -# licensed under the Apache License 2.0 +# Copyright (c) 2026 Advanced Micro Devices, Inc. # -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -"""Different types of distillation losses.""" +# SPDX-License-Identifier: MIT import torch -import torch.nn as nn import torch.nn.functional as F -from torch.nn.modules.loss import _Loss as Loss +from torch.nn.modules.loss import _Loss -__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"] +__all__ = ["LogitsDistillationLoss"] -class LogitsDistillationLoss(Loss): - """KL-Divergence loss on output logits. - - This function implements the distillation loss found in the paper: https://arxiv.org/abs/1503.02531. +class LogitsDistillationLoss(_Loss): """ + KL-divergence distillation loss over logits. - def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"): - """Constructor. - - Args: - temperature: A value used to soften the logits_t and logits_s before computing loss on them. - reduction: How to reduce the final pointwise loss before returning. Pass ``"none"`` to - use your own reduction function afterwards, i.e. with loss masks. - """ - super().__init__() - self._temperature: float = temperature - self._reduction: str = reduction - - def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor) -> torch.Tensor: - """Compute KD loss on student and teacher logits. - - Args: - logits_s: Student's logits, treated as prediction. - logits_t: Teacher's logits, treated as label. - - .. note:: - - Assumes class logits dimension is last. - """ - soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1) - soft_targets = F.softmax(logits_t / self._temperature, dim=-1) - - soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1)) - soft_targets = soft_targets.view(-1, soft_targets.size(-1)) + This implements the standard knowledge distillation objective using: + - student log-probabilities from softened student logits + - teacher probabilities from softened teacher logits - kd_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction=self._reduction) - - # Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2), - # multiplying them by T^2 ensures that the relative contributions of the logits - # remain roughly unchanged while experimenting with meta-parameters. - kd_loss *= self._temperature**2 - - return kd_loss - - -class MFTLoss(Loss): - """KL-divergence loss with Minifinetuning threshold modification. - - This function implements the distillation loss found in the paper: https://arxiv.org/abs/2506.15702. + The result is multiplied by temperature^2, matching the common + distillation scaling convention. """ - def __init__(self, temperature: float = 1.0, threshold: float = 0.2, reduction: str = "batchmean"): - """Constructor. - - Args: - temperature: A value used to soften the logits_t and logits_s before computing the MFT loss on them. - reduction: How to reduce the final pointwise loss before returning. Pass ``"none"`` to - use your own reduction function afterwards, i.e. with loss masks. - threshold: A value used to correct the teacher's distribution. It is used to ensure that - the separation between the correct and incorrect argmax tokens is large enough. - The value should be in the range [0, 1]. Defaults to 0.2. - """ - super().__init__() - self._temperature: float = temperature - self._reduction: str = reduction - self._threshold: float = threshold - - def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - """Compute KD loss on student and teacher logits. - - Args: - logits_s: Student's logits, treated as prediction. - logits_t: Teacher's logits, treated as training target. - labels: Labels for the ground truth, used to prepare the corrected teacher distributions. - - .. note:: - - Assumes class logits dimension is last. - """ - soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1) # (B, ..., C) - soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1)) # (new B, C) - - target_logits: torch.Tensor = logits_t / self._temperature # (B, ..., C) - target_logits = target_logits.view(-1, target_logits.size(-1)) # (new B, C) - soft_targets = self._prepare_corrected_distributions(target_logits, - labels, - self._threshold, - apply_threshold_to_all=True) - - kd_loss = F.kl_div( - soft_log_probs, soft_targets.detach(), - reduction=self._reduction) # shape depends on reduction; "batchmean" would result in a scalar (1,) - - # Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2), - # multiplying them by T^2 ensures that the relative contributions of the logits - # remain roughly unchanged while experimenting with meta-parameters. - kd_loss *= self._temperature**2 - - return kd_loss - - def _prepare_corrected_distributions( - self, - logits: torch.Tensor, - labels: torch.Tensor, - threshold: float, - apply_threshold_to_all: bool = True, - ) -> torch.Tensor: - """Prepare the corrected distributions for MFT loss. - - Args: - logits: The logits from the teacher model, shape (batch, channels) # e.g. (batch_size * seq_len, vocab_size) - in case of LMs - labels: The ground truth labels, shape (batch) # e.g. (batch_size * seq_len) in case of LMs - threshold: The threshold value for the MFT correction. - apply_threshold_to_all: If True, apply the threshold correction to all tokens, - not just the incorrect argmax tokens. Defaults to True. - - Returns: - A tensor containing the corrected distributions, shape (batch_size * seq_len, vocab_size). - """ - # Ensure logits is a 2D tensor and labels is a 1D tensor - if logits.dim() != 2 or labels.dim() != 1: - raise ValueError("Logits must be a 2D tensor and labels must be a 1D tensor.") - # logits: (batch, channels) - # labels: (batch) - distribution = F.softmax(logits, dim=-1) # (batch, channels) - - argmax = distribution.argmax(dim=-1) # (batch,) - incorrect_argmax = argmax != labels # (batch,) - - p_argmax = torch.gather(distribution, 1, argmax.unsqueeze(1)).squeeze(1) # (batch,) - p_label = torch.gather(distribution, 1, labels.unsqueeze(1)).squeeze(1) # (batch,) - - # correction of the distribution at the tokens where the argmax is incorrect - mixin_factor = (p_argmax - p_label + threshold) / (1 + p_argmax - p_label + 1e-7) # (batch,) - adjusted_incorrect_distribution = distribution * (1 - mixin_factor.unsqueeze(1)) # (batch, channels) - _ = adjusted_incorrect_distribution.scatter_add_(1, labels.unsqueeze(1), - mixin_factor.unsqueeze(1)) # (batch, channels) - - if apply_threshold_to_all: - # correction of the distribution at the tokens where the argmax is correct but - # the separation may not be large enough - capped_targets = torch.where(p_label > 1 - threshold, 1, p_label + threshold) # (batch,) - mixin_factor = (capped_targets - p_argmax) / (1 - p_argmax + 1e-7) # (batch,) - adjusted_correct_distribution = distribution * (1 - mixin_factor.unsqueeze(1)) # (batch, channels) - _ = adjusted_correct_distribution.scatter_add_(1, labels.unsqueeze(1), mixin_factor.unsqueeze(1)) - else: - adjusted_correct_distribution = distribution - - return torch.where( - incorrect_argmax.unsqueeze(1), - adjusted_incorrect_distribution, - adjusted_correct_distribution, - ) # (batch, channels) - - -class MGDLoss(Loss): - """PyTorch version of Masked Generative Distillation. - - This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529. - """ - - def __init__( - self, - num_student_channels: int, - num_teacher_channels: int, - alpha_mgd: float = 1.0, - lambda_mgd: float = 0.65, - ): - """Constructor. - - Args: - num_student_channels: Number of channels in the student's feature map. - num_teacher_channels: Number of channels in the teacher's feature map. - alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0. - lambda_mgd: Masked ratio. Defaults to 0.65. - """ + def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"): super().__init__() - self._alpha_mgd: float = alpha_mgd - self._lambda_mgd: float = lambda_mgd - - if num_student_channels != num_teacher_channels: - self.align = nn.Conv2d( - num_student_channels, - num_teacher_channels, - kernel_size=1, - stride=1, - padding=0, + if temperature <= 0: + raise ValueError("temperature must be > 0") + if reduction not in {"none", "batchmean", "sum", "mean"}: + raise ValueError("reduction must be one of: none, batchmean, sum, mean") + + self.temperature = float(temperature) + self.reduction = reduction + + def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor: + if student_logits.shape != teacher_logits.shape: + raise ValueError( + f"student_logits and teacher_logits must have the same shape, " + f"got {student_logits.shape} and {teacher_logits.shape}" ) - else: - self.align = nn.Identity() - - self.generation = nn.Sequential( - nn.Conv2d( - num_teacher_channels, - num_teacher_channels, - kernel_size=3, - padding=1, - ), - nn.ReLU(inplace=True), - nn.Conv2d( - num_teacher_channels, - num_teacher_channels, - kernel_size=3, - padding=1, - ), - ) - - def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor): - n, _, h, w = out_t.shape - - mat = torch.rand((n, 1, h, w), device=out_s.device) - mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1) - masked_feats = torch.mul(out_s, mat) - new_feats = self.generation(masked_feats) + t = self.temperature - kd_loss = F.mse_loss(new_feats, out_t) + student_log_probs = F.log_softmax(student_logits / t, dim=-1) + teacher_probs = F.softmax(teacher_logits / t, dim=-1) - return kd_loss + num_classes = student_log_probs.size(-1) + student_log_probs = student_log_probs.reshape(-1, num_classes) + teacher_probs = teacher_probs.reshape(-1, num_classes) - def forward(self, out_s: torch.Tensor, out_t: torch.Tensor): - """Forward function. - - Args: - out_s: Student's feature map (shape BxCxHxW). - out_t: Teacher's feature map (shape BxCxHxW). - """ - assert out_s.shape[-2:] == out_t.shape[-2:] - - out_s = self.align(out_s) - loss = self._loss_fn(out_s, out_t) * self._alpha_mgd + loss = F.kl_div( + student_log_probs, + teacher_probs.detach(), + reduction=self.reduction, + ) - return loss + return loss * (t * t) diff --git a/alto/modifiers/lpt/base.py b/alto/modifiers/lpt/base.py index 9849bf1..e5e9d2e 100644 --- a/alto/modifiers/lpt/base.py +++ b/alto/modifiers/lpt/base.py @@ -7,14 +7,13 @@ from compressed_tensors.utils import match_named_modules from pydantic import PrivateAttr, Field, field_validator from torchtitan.models.common.attention import BaseAttention -from torchtitan.models.common.moe.utils import set_token_group_alignment_size_m from torchtitan.tools.logging import logger from alto.modifiers import Modifier from alto.kernels.dispatch import ( swap_params, TrainingOpConfig, - LPScaledDotProductAttentionWrapper, + LPScaledDotProductAttention, ) from alto.kernels.fp4.mxfp4.mxfp_grouped_gemm.autotune import ALIGN_SIZE_M @@ -35,18 +34,6 @@ class LowPrecisionTrainingModifier(Modifier): _resolved_config: dict[TrainingOpConfig, list[str]] | None = PrivateAttr(default=None) - def __init__(self, **kwargs): - super().__init__(**kwargs) - - set_token_group_alignment_size_m(ALIGN_SIZE_M) - # self._config = MXFP4TrainingOpConfig( - # use_2dblock_x=self.use_2dblock_x, - # use_2dblock_w=self.use_2dblock_w, - # use_hadamard=self.use_hadamard, - # use_sr_grad=self.use_sr_grad, - # use_dge=self.use_dge, - # ) - @field_validator("targets", mode="before") def validate_targets(cls, value: str | list[str]) -> list[str]: if isinstance(value, str): @@ -96,7 +83,7 @@ def on_convert(self, model: Module, **kwargs) -> bool: for name, module in match_named_modules(model, targets, self.ignore): if isinstance(module, BaseAttention): assert module.attn_backend == "sdpa", "Only SDPA attention is supported for now." - module.inner_attention = LPScaledDotProductAttentionWrapper(config=scheme_obj) + module.inner_attention = LPScaledDotProductAttention(config=scheme_obj) elif isinstance(module, torch.nn.Linear) or module.__class__.__name__.endswith("GroupedExperts"): swap_params(module, config=scheme_obj, module_name=name) else: diff --git a/alto/modifiers/quantization/base.py b/alto/modifiers/quantization/base.py index 4aba465..7c1551d 100644 --- a/alto/modifiers/quantization/base.py +++ b/alto/modifiers/quantization/base.py @@ -18,7 +18,7 @@ from compressed_tensors.quantization import disable_quantization, enable_quantization from compressed_tensors.utils import getattr_chain, match_named_modules from pydantic import Field, PrivateAttr -from torch.nn import Module +from torch.nn import Module, Linear from torchtitan.tools.logging import logger from alto.modifiers import Modifier @@ -66,11 +66,6 @@ class QuantizationModifier(Modifier, QuantizationMixin): # ---- lifecycle ---------------------------------------------------- def on_initialize(self, model_parts: list[Module], **kwargs) -> bool: - if not QuantizationMixin.has_config(self): - raise ValueError("QuantizationModifier requires that quantization fields be specified") - for m in model_parts: - QuantizationMixin.initialize_quantization(self, m) - if self.sequential: self._build_sequential_blocks(model_parts) return True @@ -106,6 +101,22 @@ def on_finalize(self, model_parts: list[Module], **kwargs) -> bool: return True def on_convert(self, model: Module, **kwargs) -> bool: + if not QuantizationMixin.has_config(self): + raise ValueError("QuantizationModifier requires that quantization fields be specified") + # Note: qparams will be registered in the initialize_quantization method, + # so it has to be done before applying parallelism + QuantizationMixin.initialize_quantization(self, model) + + # patch param_init dict because the qparams are registered on meta device + # and need to be initialized after to_empty copy. + for mod_name, mod in model.named_modules(): + qscheme = getattr(mod, "quantization_scheme", None) + if qscheme is not None: + for pname, p in mod.named_parameters(): + if pname.endswith("scale"): + mod._param_init[pname] = torch.nn.init.ones_ + elif pname.endswith("zero_point"): + mod._param_init[pname] = torch.nn.init.zeros_ return True # ---- sequential loop (template) ----------------------------------- diff --git a/alto/modifiers/sparsification/wanda.py b/alto/modifiers/sparsification/wanda.py index 24c62df..bf51ac7 100644 --- a/alto/modifiers/sparsification/wanda.py +++ b/alto/modifiers/sparsification/wanda.py @@ -11,6 +11,7 @@ import torch from torch.nn import Module +from torch.distributed.tensor import DTensor, Replicate from torchtitan.tools.logging import logger from alto.modifiers.sparsification.base import SparsityModifierBase @@ -79,6 +80,11 @@ def compress_modules(self): prune_n=self._prune_n, prune_m=self._prune_m, ) + if isinstance(module.weight, DTensor): + sparsified_weight = sparsified_weight.redistribute( + module.weight.device_mesh, + placements=module.weight.placements, + ) module.weight.data.copy_(sparsified_weight) def _sparsify_weight( @@ -112,6 +118,12 @@ def _sparsify_weight( W = W.to(dtype=PRECISION) S = row_scalar + device_mesh = None + if isinstance(W, DTensor): + device_mesh = W.device_mesh + W = W.redistribute(device_mesh, placements=(Replicate(),)).to_local() + S = S.redistribute(device_mesh, placements=(Replicate(),)).to_local() + W_metric = torch.abs(W) * torch.sqrt(S.reshape((1, -1))) # initialize a mask to be all False @@ -138,4 +150,8 @@ def _sparsify_weight( W = W.reshape(final_shape).to(final_dtype) + if device_mesh is not None: + W = DTensor.from_local(W, device_mesh=device_mesh, placements=(Replicate(),)) + W_mask = DTensor.from_local(W_mask, device_mesh=device_mesh, placements=(Replicate(),)) + return W, W_mask.reshape(final_shape) diff --git a/alto/observers/per_channel_norm.py b/alto/observers/per_channel_norm.py index 6a1cda7..7c49c12 100644 --- a/alto/observers/per_channel_norm.py +++ b/alto/observers/per_channel_norm.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import torch +from torch.distributed.tensor import DTensor, Partial from alto.utils.pytorch.module import TransformerConv1D from .base import Observer, register_observer @@ -25,7 +26,10 @@ def __init__(self, *args, **kwargs) -> None: def make_empty_row_scalars(self) -> torch.Tensor: weight = self.module().weight num_columns = weight.shape[1] - return torch.zeros(num_columns, device=self.device) + S = torch.zeros(num_columns, device=self.device) + if isinstance(weight, DTensor): + S = DTensor.from_local(S, device_mesh=weight.device_mesh, placements=(Partial("avg"),)) + return S def get_current_min_max(self, observed: torch.Tensor): pass @@ -38,6 +42,13 @@ def forward_inner(self, x_orig): return x_orig with torch.no_grad(): + S = self.stats + if isinstance(S, DTensor): + S = S.to_local() + + # TODO: support TP + assert not isinstance(x_orig, DTensor), "TP is not supported for per_channel_norm observer" + module = self.module() inp = x_orig.detach() if inp.dim() == 2: @@ -62,11 +73,11 @@ def forward_inner(self, x_orig): inp = inp.permute([1, 0, 2]) inp = inp.flatten(1) - self.stats *= self.num_samples / (self.num_samples + num_added) + S *= self.num_samples / (self.num_samples + num_added) self.num_samples += num_added inp = inp.type(PRECISION) - self.stats += torch.norm(inp, p=2, dim=1)**2 / self.num_samples + S += torch.norm(inp, p=2, dim=1)**2 / self.num_samples return x_orig diff --git a/alto/train.py b/alto/train.py index cec58ae..e5932cb 100644 --- a/alto/train.py +++ b/alto/train.py @@ -180,8 +180,7 @@ def forward_step( # Non-PP forward / backward with self.train_context(): assert len(model_parts) == 1 - with self.maybe_enable_amp: - result = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + result = model_parts[0](inputs, **extra_inputs, **extra_kwargs) return result @@ -195,7 +194,6 @@ def train_step( # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. parallel_dims = self.parallel_dims - assert not parallel_dims.dp_cp_enabled, "DP_CP is not supported in post-training" # Collect all microbatches on CPU and count total valid tokens microbatches = [] diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh new file mode 100755 index 0000000..8f221b8 --- /dev/null +++ b/scripts/build_docker.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +IMAGE=wanghanthu/alto:rocm7.2.2-nightly-20260429 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR/.. + +docker build -t $IMAGE . diff --git a/tests/integration/gpt_oss_debugmodel_lpt.sh b/tests/integration/gpt_oss_debugmodel_lpt.sh new file mode 100755 index 0000000..01294cc --- /dev/null +++ b/tests/integration/gpt_oss_debugmodel_lpt.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +SCRIPT_DIR=$(dirname "$0") +cd $SCRIPT_DIR/../.. + +NGPU=2 \ +MODULE=gpt_oss \ +CONFIG=gpt_oss_debugmodel_lpt \ +HSA_NO_SCRATCH_RECLAIM=1 \ +./examples/run.sh diff --git a/tests/integration/instella_3b_opt.sh b/tests/integration/instella_3b_opt.sh new file mode 100755 index 0000000..147aff1 --- /dev/null +++ b/tests/integration/instella_3b_opt.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +SCRIPT_DIR=$(dirname "$0") +cd $SCRIPT_DIR/../.. + +NGPU=2 \ +MODULE=llama3 \ +CONFIG=instella_3b_opt \ +HSA_NO_SCRATCH_RECLAIM=1 \ +./examples/run.sh \ + --hf_assets_path=/huggingface/hub/models--amd--Instella-3B-Stage1/snapshots/cb33253ab0a5b9f2ea0b98f3edd818d46454580e \ + --checkpoint.initial_load_path=/huggingface/hub/models--amd--Instella-3B-Stage1/snapshots/cb33253ab0a5b9f2ea0b98f3edd818d46454580e diff --git a/tests/integration/llama3_1b_opt.sh b/tests/integration/llama3_1b_opt.sh new file mode 100755 index 0000000..61c563c --- /dev/null +++ b/tests/integration/llama3_1b_opt.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +SCRIPT_DIR=$(dirname "$0") +cd $SCRIPT_DIR/../.. + +NGPU=1 \ +MODULE=llama3 \ +CONFIG=llama3_1b_opt \ +HSA_NO_SCRATCH_RECLAIM=1 \ +./examples/run.sh \ + --hf_assets_path=/huggingface/hub/models--meta-llama--Llama-3.2-1B/snapshots/4e20de362430cd3b72f300e6b0f18e50e7166e08 \ + --checkpoint.initial_load_path=/huggingface/hub/models--meta-llama--Llama-3.2-1B/snapshots/4e20de362430cd3b72f300e6b0f18e50e7166e08 diff --git a/tests/integration/llama3_debugmodel_lpt.sh b/tests/integration/llama3_debugmodel_lpt.sh new file mode 100755 index 0000000..ffb107b --- /dev/null +++ b/tests/integration/llama3_debugmodel_lpt.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +SCRIPT_DIR=$(dirname "$0") +cd $SCRIPT_DIR/../.. + +NGPU=2 \ +MODULE=llama3 \ +CONFIG=llama3_debugmodel_lpt \ +HSA_NO_SCRATCH_RECLAIM=1 \ +./examples/run.sh