From ebe03020e2b2b7b49fb3af3af4a6da980f805807 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 27 May 2026 08:12:15 +0000 Subject: [PATCH 1/9] wip: draft of DeOscillationOptimizer --- alto/components/__init__.py | 8 +- alto/components/optimizer.py | 444 +++++++++++++++++++++++++++++++++++ 2 files changed, 451 insertions(+), 1 deletion(-) create mode 100644 alto/components/optimizer.py diff --git a/alto/components/__init__.py b/alto/components/__init__.py index 5631d8a..19ba777 100644 --- a/alto/components/__init__.py +++ b/alto/components/__init__.py @@ -3,6 +3,12 @@ # SPDX-License-Identifier: MIT from .converter import ModelOptConverter +from .optimizer import DeOscillationConfig, DeOscillationOptimizersContainer from .state_dict_adapter_mixin import StateDictAdapterMixin -__all__ = ["ModelOptConverter", "StateDictAdapterMixin"] +__all__ = [ + "DeOscillationConfig", + "DeOscillationOptimizersContainer", + "ModelOptConverter", + "StateDictAdapterMixin", +] diff --git a/alto/components/optimizer.py b/alto/components/optimizer.py new file mode 100644 index 0000000..d1199bf --- /dev/null +++ b/alto/components/optimizer.py @@ -0,0 +1,444 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Weight de-oscillation hook for FP4-aware AdamW. + +This module provides :class:`DeOscillationOptimizersContainer`, an +:class:`OptimizersContainer` subclass that installs an optimizer +``step_post_hook`` implementing the weight de-oscillation strategy. + +Motivation +---------- + +When training with simulated FP4 GEMMs, the FP32 master weight is +re-quantized to FP4 just before every matmul. An element whose master +value sits near a quantization-bin boundary can *oscillate*: tiny FP +gradient steps flip the bin assignment on every iteration, so although +the master weight moves only slightly the *quantized* weight seen by the +GEMM hops between two adjacent bins every step. This hurts convergence. + +Detection criterion (DistRatio) +------------------------------- + +For each weight element, over a window of ``period`` optimizer steps, +accumulate two L1 distances of consecutive iterations: + +* ``dist_w = sum_t |w_t - w_{t-1}|`` (raw FP movement) +* ``dist_w_qdq = sum_t |Q(w_t) - Q(w_{t-1})|`` (quantized movement) + +where ``Q`` is the project's FP4 QDQ round trip. An element is treated +as oscillating when ``dist_w_qdq / dist_w >= ratio_threshold``: small +FP movement producing large quantized movement only happens at a bin +boundary. + +Reset +----- + +On the last step of each period, oscillating elements are *snapped* to +their current bin center, i.e. ``w <- Q(w)``. Counters and snapshots +are then re-initialized for the next period. + +Scope: which parameters are eligible +------------------------------------ + +The hook only acts on parameters whose storage is wrapped by either +:class:`~alto.kernels.dispatch.tensor.MXFP4TrainingWeightWrapperTensor` or +:class:`~alto.kernels.dispatch.tensor.NVFP4TrainingWeightWrapperTensor` +(directly, or as ``DTensor._local_tensor`` under FSDP2). Other +parameters (embeddings, norm scales, biases, MXFP8 weights, ...) are +silently skipped. The FP4 format and the per-block scaling layout used +for ``Q`` come from the wrapper's +:class:`~alto.kernels.dispatch.config.TrainingOpConfig`, and the +reduction axis along which QDQ is performed comes from +``wrapper.weight_reduction_axis`` -- the same field consumed by the FP4 +grouped-MM dispatch. Reading both from the wrapper guarantees that the +de-oscillation hook tracks oscillation in exactly the same FP4 grid the +forward / backward GEMMs see, with no axis inference duplicated here. + +""" + + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn as nn +from torch.distributed.tensor import DTensor +from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.tools.logging import logger + +from alto.kernels.dispatch.config import TrainingOpConfig +from alto.kernels.dispatch.tensor import ( + MXFP4TrainingWeightWrapperTensor, + NVFP4TrainingWeightWrapperTensor, + TrainingWeightWrapperBaseTensor, +) + +__all__ = [ + "DeOscillationConfig", + "DeOscillationOptimizersContainer", +] + + +# Wrappers eligible for de-oscillation. Only the two FP4 variants are +# included on purpose: MXFP8 has a much wider grid where oscillation is +# rare, and the bin-center snap we use only really makes sense for FP4. +_FP4_WRAPPER_TYPES: tuple[type, ...] = ( + MXFP4TrainingWeightWrapperTensor, + NVFP4TrainingWeightWrapperTensor, +) + + +def _peel_to_fp4_wrapper( + p: torch.Tensor, +) -> TrainingWeightWrapperBaseTensor | None: + """Return the FP4 wrapper tensor backing *p*, or ``None``. + + Handles three layouts: + + * plain ``nn.Parameter(FP4WrapperTensor(...))`` + * FSDP2 ``nn.Parameter(DTensor(local=FP4WrapperTensor(...)))`` + * any of the above un-parameterised (e.g. inside checkpointing utils) + """ + t = p.data if isinstance(p, nn.Parameter) else p + if isinstance(t, DTensor): + t = t._local_tensor + if isinstance(t, _FP4_WRAPPER_TYPES): + return t # type: ignore[return-value] + return None + + +QdqFn = Callable[[torch.Tensor, int], torch.Tensor] + + +def _make_qdq_fn_for(cfg: TrainingOpConfig) -> QdqFn: + """Build a ``dequant(quant(., axis))`` callable matching the wrapper's + forward-path QDQ. + + The closure runs on the *plain* underlying FP32/BF16 tensor (i.e. + ``wrapper._data``), with the same ``is_2d_block`` choice the FP4 + GEMM applies to the weight operand. The reduction axis is passed in + by the caller (read from ``wrapper.weight_reduction_axis``) so the + same cache entry serves both ``nn.Linear`` weights (axis ``-1``) and + grouped-MM expert weights (axis ``-2``) that share a config. + """ + from alto.kernels.fp4 import ( + convert_from_mxfp4, + convert_from_nvfp4, + convert_to_mxfp4, + convert_to_nvfp4, + ) + + is_2d_block = cfg.use_2dblock_w + + if cfg.precision == "mxfp4": + + def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: + data_lp, scales = convert_to_mxfp4( + w, axis=axis, is_2d_block=is_2d_block, + ) + return convert_from_mxfp4( + data_lp, + scales, + output_dtype=w.dtype, + axis=axis, + is_2d_block=is_2d_block, + ) + + elif cfg.precision == "nvfp4": + + def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: + data_lp, scales = convert_to_nvfp4( + w, axis=axis, is_2d_block=is_2d_block, + ) + return convert_from_nvfp4( + data_lp, + scales, + output_dtype=w.dtype, + axis=axis, + is_2d_block=is_2d_block, + ) + + else: + raise ValueError( + f"de-oscillation only supports FP4 wrappers, " + f"got precision={cfg.precision!r}" + ) + + return qdq + + +@dataclass(kw_only=True, slots=True) +class DeOscillationConfig: + """Per-period weight de-oscillation knobs. + + The FP4 format / block layout used for the QDQ round trip is *not* + in this config -- it is read directly from each parameter's + :class:`TrainingOpConfig` so the hook always operates on the same + quantization grid that the forward pass uses. + + Attributes: + enabled: master switch. + period: number of ``optimizer.step()`` calls per de-oscillation + window. Reset decisions are taken on the last step of each + window. + ratio_threshold: an element is reset when + ``dist_w_qdq / dist_w >= ratio_threshold`` over the window. + Reasonable values are in [4, 32]; the TetraJet reference + uses ``16``. + log_freq: log reset statistics every ``log_freq`` triggered + periods. ``0`` disables logging. + """ + + enabled: bool = False + period: int = 25 + ratio_threshold: float = 16.0 + log_freq: int = 0 + + +class _DeOscillationHook: + """AdamW ``step_post_hook`` implementing the DistRatio reset. + + The hook itself is nearly stateless (only an internal counter used + for log throttling and a per-:class:`TrainingOpConfig` QDQ cache). + All per-parameter state lives in ``optimizer.state[p]`` under keys + prefixed with ``osci_``, so it is serialised through the standard + optimizer state dict path. + """ + + KEY_STEP = "osci_step_in_period" + KEY_DIST_FP = "osci_dist_w" + KEY_DIST_QDQ = "osci_dist_w_qdq" + KEY_PREV = "osci_prev_weight" + + # Floor for the per-element ratio denominator. Picks a value well + # above FP32 denormals so the divide stays in the normal range when + # the corresponding element happened to be frozen for the whole + # period. The explicit ``dist_w > 0`` mask below makes this a pure + # numeric safety net. + _EPS = 1.0e-30 + + def __init__(self, cfg: DeOscillationConfig) -> None: + self._cfg = cfg + # TrainingOpConfig is @dataclass(unsafe_hash=True), so the same + # config object (or two equal ones) maps to a single cached QDQ. + # The cached closure takes ``(tensor, axis)``; the axis is read + # per-call from ``wrapper.weight_reduction_axis`` and is not + # baked into the cache key. + self._qdq_cache: dict[TrainingOpConfig, QdqFn] = {} + # Counts how many periods have completed since the hook was + # installed; used only for log throttling. + self._periods_completed = 0 + + # ---- eligibility -------------------------------------------------- + + def _wrapper_if_eligible( + self, + p: nn.Parameter, + ) -> TrainingWeightWrapperBaseTensor | None: + """Return the FP4 wrapper if *p* is eligible for de-oscillation. + + Block-size divisibility is *not* re-checked here -- the FP4 + ``Linear`` / grouped-MM dispatch enforces it on every forward + already, so any param that survives one training step is + guaranteed to be compatible with the same QDQ kernels we are + about to call. + """ + if not p.requires_grad: + return None + wrapper = _peel_to_fp4_wrapper(p) + if wrapper is None: + return None + cfg = wrapper.config + if cfg is None or cfg.precision not in ("mxfp4", "nvfp4"): + return None + return wrapper + + # ---- per-config QDQ ---------------------------------------------- + + def _qdq_for(self, wrapper: TrainingWeightWrapperBaseTensor) -> QdqFn: + cfg = wrapper.config + fn = self._qdq_cache.get(cfg) + if fn is None: + fn = _make_qdq_fn_for(cfg) + self._qdq_cache[cfg] = fn + return fn + + # ---- hook entry point -------------------------------------------- + + def __call__(self, optimizer, args, kwargs) -> None: # noqa: ARG002 + triggered_params = 0 + n_reset = 0 + n_total = 0 + for group in optimizer.param_groups: + for p in group["params"]: + wrapper = self._wrapper_if_eligible(p) + if wrapper is None: + continue + state = optimizer.state[p] + reset_count, total_count, did_reset = self._step_param( + wrapper, state + ) + if did_reset: + triggered_params += 1 + n_reset += reset_count + n_total += total_count + + if triggered_params and self._cfg.log_freq: + self._periods_completed += 1 + if ( + self._periods_completed % self._cfg.log_freq == 0 + and n_total > 0 + ): + pct = n_reset / n_total * 100.0 + logger.info( + f"[de-osc] period#{self._periods_completed} " + f"triggered_params={triggered_params} " + f"reset_ratio={pct:.4f}% ({n_reset}/{n_total})" + ) + + # ---- per-parameter update ---------------------------------------- + + def _step_param( + self, + wrapper: TrainingWeightWrapperBaseTensor, + state: dict[str, Any], + ) -> tuple[int, int, bool]: + """Track movement, maybe reset, return ``(n_reset, n_total, did_reset)``.""" + cfg = self._cfg + # ``raw`` is the plain FP32/BF16 storage backing the wrapper. + # All math runs on it directly so we never accidentally go + # through __torch_function__ (which would re-enter the FP4 GEMM + # dispatch). + raw = wrapper._data + w_now = raw.detach() + + if self.KEY_PREV not in state or self.KEY_STEP not in state: + # First call (either ever, or after a state-dict mishap). + # Seed the period with the current weight and skip tracking + # for this iteration; the next post_step will see a real + # before/after pair. + self._init_period_state(state, w_now) + return 0, 0, False + + prev = state[self.KEY_PREV] + dist_w = state[self.KEY_DIST_FP] + dist_w_qdq = state[self.KEY_DIST_QDQ] + + qdq_fn = self._qdq_for(wrapper) + # Read the reduction axis from the wrapper so we QDQ along the + # exact same axis that the forward FP4 GEMM uses; no shape-based + # heuristic here. + axis = wrapper.weight_reduction_axis + w_qdq = qdq_fn(w_now, axis) + prev_qdq = qdq_fn(prev, axis) + dist_w.add_((w_now - prev).abs()) + dist_w_qdq.add_((w_qdq - prev_qdq).abs()) + prev.copy_(w_now) + state[self.KEY_STEP] += 1 + + if state[self.KEY_STEP] < cfg.period: + return 0, 0, False + + # End-of-period decision: snap oscillating elements to their + # current bin center, then start a fresh period. + ratio = dist_w_qdq / dist_w.clamp(min=self._EPS) + reset_mask = (dist_w > 0) & (ratio >= cfg.ratio_threshold) + + reset_count = 0 + total_count = dist_w.numel() + if reset_mask.any(): + reset_count = int(reset_mask.sum().item()) + # Snap into the wrapper's underlying storage in place; the + # FP4 GEMM dispatch on the next forward will read this new + # value. + raw.copy_( + torch.where(reset_mask, w_qdq.to(raw.dtype), raw) + ) + # Refresh the snapshot so the next period does not see the + # snap-to-bin-center as a large FP movement on its first + # step. + prev.copy_(raw) + + dist_w.zero_() + dist_w_qdq.zero_() + state[self.KEY_STEP] = 0 + return reset_count, total_count, True + + def _init_period_state( + self, + state: dict[str, Any], + w: torch.Tensor, + ) -> None: + """Allocate (or zero) per-parameter de-oscillation state. + + State tensors are *plain* tensors that share dtype, shape and + device with the wrapper's underlying storage. They are stored + in ``optimizer.state[p]`` directly so they ride along with the + AdamW moments through the standard optimizer state dict path. + """ + state[self.KEY_STEP] = 0 + if self.KEY_DIST_FP not in state: + state[self.KEY_DIST_FP] = torch.zeros_like(w) + state[self.KEY_DIST_QDQ] = torch.zeros_like(w) + state[self.KEY_PREV] = w.clone() + else: + state[self.KEY_DIST_FP].zero_() + state[self.KEY_DIST_QDQ].zero_() + state[self.KEY_PREV].copy_(w) + + +class DeOscillationOptimizersContainer(OptimizersContainer): + """:class:`OptimizersContainer` that installs the de-oscillation hook. + + All standard AdamW / Adam knobs from :class:`OptimizersContainer.Config` + keep their usual meaning; the de-oscillation behaviour is controlled by + the nested :class:`DeOscillationConfig`. + + The hook only fires on parameters wrapped by + :class:`MXFP4TrainingWeightWrapperTensor` or + :class:`NVFP4TrainingWeightWrapperTensor` (the FP4 format and + block-scaling layout are inferred from each wrapper's + :class:`TrainingOpConfig`). Everything else passes through untouched. + + Example:: + + config.optimizer = DeOscillationOptimizersContainer.Config( + name="AdamW", + lr=8e-4, + de_oscillation=DeOscillationConfig( + enabled=True, + period=25, + ratio_threshold=16.0, + log_freq=10, + ), + ) + """ + + @dataclass(kw_only=True, slots=True) + class Config(OptimizersContainer.Config): + de_oscillation: DeOscillationConfig = field( + default_factory=DeOscillationConfig + ) + + def __init__( + self, + config: Config, + *, + model_parts: list[nn.Module], + ) -> None: + super().__init__(config, model_parts=model_parts) + self._de_osci_cfg: DeOscillationConfig = config.de_oscillation + if self._de_osci_cfg.enabled: + hook = _DeOscillationHook(self._de_osci_cfg) + for opt in self.optimizers: + opt.register_step_post_hook(hook) + logger.info( + "[de-osc] enabled " + f"period={self._de_osci_cfg.period} " + f"ratio_threshold={self._de_osci_cfg.ratio_threshold} " + f"log_freq={self._de_osci_cfg.log_freq}; " + "scope=MXFP4/NVFP4 wrapped weights" + ) From 7ccfec2927db4bd5ffc31259d5b5b5964001449e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 28 May 2026 08:31:21 +0000 Subject: [PATCH 2/9] [PoC] de-oscillation optimizer --- alto/components/__init__.py | 4 +- alto/components/optimizer.py | 91 +++++++++-------------------- alto/kernels/dispatch/conversion.py | 6 +- alto/train.py | 9 +++ 4 files changed, 44 insertions(+), 66 deletions(-) diff --git a/alto/components/__init__.py b/alto/components/__init__.py index 19ba777..eb0ad18 100644 --- a/alto/components/__init__.py +++ b/alto/components/__init__.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: MIT from .converter import ModelOptConverter -from .optimizer import DeOscillationConfig, DeOscillationOptimizersContainer +from .optimizer import DeOscillationConfig, enable_de_oscillation from .state_dict_adapter_mixin import StateDictAdapterMixin __all__ = [ "DeOscillationConfig", - "DeOscillationOptimizersContainer", + "enable_de_oscillation", "ModelOptConverter", "StateDictAdapterMixin", ] diff --git a/alto/components/optimizer.py b/alto/components/optimizer.py index d1199bf..8da5f2b 100644 --- a/alto/components/optimizer.py +++ b/alto/components/optimizer.py @@ -51,8 +51,7 @@ for ``Q`` come from the wrapper's :class:`~alto.kernels.dispatch.config.TrainingOpConfig`, and the reduction axis along which QDQ is performed comes from -``wrapper.weight_reduction_axis`` -- the same field consumed by the FP4 -grouped-MM dispatch. Reading both from the wrapper guarantees that the +``_infer_reduction_axis``. Reading both from the wrapper guarantees that the de-oscillation hook tracks oscillation in exactly the same FP4 grid the forward / backward GEMMs see, with no axis inference duplicated here. @@ -91,6 +90,15 @@ ) +def _infer_reduction_axis(w: torch.Tensor): + if w.dim() == 2: + return -1 + elif w.dim() == 3: + return -2 + else: + raise ValueError(f"Unexpected weight shape: {w.shape}") + + def _peel_to_fp4_wrapper( p: torch.Tensor, ) -> TrainingWeightWrapperBaseTensor | None: @@ -120,7 +128,7 @@ def _make_qdq_fn_for(cfg: TrainingOpConfig) -> QdqFn: The closure runs on the *plain* underlying FP32/BF16 tensor (i.e. ``wrapper._data``), with the same ``is_2d_block`` choice the FP4 GEMM applies to the weight operand. The reduction axis is passed in - by the caller (read from ``wrapper.weight_reduction_axis``) so the + by the caller (read from ``_infer_reduction_axis``) so the same cache entry serves both ``nn.Linear`` weights (axis ``-1``) and grouped-MM expert weights (axis ``-2``) that share a config. """ @@ -160,7 +168,6 @@ def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: axis=axis, is_2d_block=is_2d_block, ) - else: raise ValueError( f"de-oscillation only supports FP4 wrappers, " @@ -180,7 +187,7 @@ class DeOscillationConfig: quantization grid that the forward pass uses. Attributes: - enabled: master switch. + enable: master switch. period: number of ``optimizer.step()`` calls per de-oscillation window. Reset decisions are taken on the last step of each window. @@ -192,7 +199,7 @@ class DeOscillationConfig: periods. ``0`` disables logging. """ - enabled: bool = False + enable: bool = False period: int = 25 ratio_threshold: float = 16.0 log_freq: int = 0 @@ -225,7 +232,7 @@ def __init__(self, cfg: DeOscillationConfig) -> None: # TrainingOpConfig is @dataclass(unsafe_hash=True), so the same # config object (or two equal ones) maps to a single cached QDQ. # The cached closure takes ``(tensor, axis)``; the axis is read - # per-call from ``wrapper.weight_reduction_axis`` and is not + # per-call from ``_infer_reduction_axis`` and is not # baked into the cache key. self._qdq_cache: dict[TrainingOpConfig, QdqFn] = {} # Counts how many periods have completed since the hook was @@ -328,10 +335,7 @@ def _step_param( dist_w_qdq = state[self.KEY_DIST_QDQ] qdq_fn = self._qdq_for(wrapper) - # Read the reduction axis from the wrapper so we QDQ along the - # exact same axis that the forward FP4 GEMM uses; no shape-based - # heuristic here. - axis = wrapper.weight_reduction_axis + axis = _infer_reduction_axis(w_now) w_qdq = qdq_fn(w_now, axis) prev_qdq = qdq_fn(prev, axis) dist_w.add_((w_now - prev).abs()) @@ -390,55 +394,16 @@ def _init_period_state( state[self.KEY_PREV].copy_(w) -class DeOscillationOptimizersContainer(OptimizersContainer): - """:class:`OptimizersContainer` that installs the de-oscillation hook. - - All standard AdamW / Adam knobs from :class:`OptimizersContainer.Config` - keep their usual meaning; the de-oscillation behaviour is controlled by - the nested :class:`DeOscillationConfig`. - - The hook only fires on parameters wrapped by - :class:`MXFP4TrainingWeightWrapperTensor` or - :class:`NVFP4TrainingWeightWrapperTensor` (the FP4 format and - block-scaling layout are inferred from each wrapper's - :class:`TrainingOpConfig`). Everything else passes through untouched. - - Example:: - - config.optimizer = DeOscillationOptimizersContainer.Config( - name="AdamW", - lr=8e-4, - de_oscillation=DeOscillationConfig( - enabled=True, - period=25, - ratio_threshold=16.0, - log_freq=10, - ), - ) - """ - - @dataclass(kw_only=True, slots=True) - class Config(OptimizersContainer.Config): - de_oscillation: DeOscillationConfig = field( - default_factory=DeOscillationConfig - ) - - def __init__( - self, - config: Config, - *, - model_parts: list[nn.Module], - ) -> None: - super().__init__(config, model_parts=model_parts) - self._de_osci_cfg: DeOscillationConfig = config.de_oscillation - if self._de_osci_cfg.enabled: - hook = _DeOscillationHook(self._de_osci_cfg) - for opt in self.optimizers: - opt.register_step_post_hook(hook) - logger.info( - "[de-osc] enabled " - f"period={self._de_osci_cfg.period} " - f"ratio_threshold={self._de_osci_cfg.ratio_threshold} " - f"log_freq={self._de_osci_cfg.log_freq}; " - "scope=MXFP4/NVFP4 wrapped weights" - ) +def enable_de_oscillation(optimizers: OptimizersContainer, config: DeOscillationConfig) -> None: + if not config.enable: + return + hook = _DeOscillationHook(config) + for opt in optimizers.optimizers: + opt.register_step_post_hook(hook) + logger.info( + "[de-osc] enabled " + f"period={config.period} " + f"ratio_threshold={config.ratio_threshold} " + f"log_freq={config.log_freq}; " + "scope=MXFP4/NVFP4 wrapped weights" + ) diff --git a/alto/kernels/dispatch/conversion.py b/alto/kernels/dispatch/conversion.py index 6127ce8..c10d64a 100644 --- a/alto/kernels/dispatch/conversion.py +++ b/alto/kernels/dispatch/conversion.py @@ -97,6 +97,10 @@ def post_order_traversal( for param_name, param in module.named_parameters(recurse=False): if (target_parameter_name is not None and param_name != target_parameter_name): continue + full_param_name = f"{module_name}{'.' if module_name else ''}{cur_fqn}{'.' if cur_fqn else ''}{param_name}" + if (target_parameter_name is None and "bias" in param_name): + logger.warn(f"Skipped {full_param_name} because it is a bias parameter") + continue if not isinstance(param.data, TrainingWeightWrapperBaseTensor): new_param = nn.Parameter( tensor_cls(param.data, config), @@ -104,7 +108,7 @@ def post_order_traversal( ) setattr(module, param_name, new_param) logger.info( - f"Swapped {module_name}{'.' if module_name else ''}{cur_fqn}{'.' if cur_fqn else ''}{param_name} to {tensor_cls.__name__}" + f"Swapped {full_param_name} to {tensor_cls.__name__}" ) def get_name_func_new(): diff --git a/alto/train.py b/alto/train.py index cec58ae..4dbab96 100644 --- a/alto/train.py +++ b/alto/train.py @@ -12,6 +12,7 @@ from torchtitan.experiments.forge.example_train import Trainer as ForgeTrainer, main as forge_main from torchtitan.components.metrics import MetricsProcessor from alto.components.converter import ModelOptConverter +from alto.components.optimizer import DeOscillationConfig, enable_de_oscillation from torchtitan.tools.logging import logger @@ -109,6 +110,14 @@ def __init__(self, config: TitanTrainer.Config): logger.info("data replay buffer disabled") self.enable_data_cache = False + deosc_config = DeOscillationConfig( + enable=True, + period=4, + ratio_threshold=16.0, + log_freq=1, + ) + enable_de_oscillation(self.optimizers, deosc_config) + def cache_input(self, microbatches: list[tuple[dict[str, torch.Tensor], torch.Tensor]]): if self.enable_data_cache: self._input_cache = microbatches From 579e32bf4b7bae4a502c79aa0fc13cd547ed2fda Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 28 May 2026 09:10:55 +0000 Subject: [PATCH 3/9] feat: enable de-osc by env var --- alto/train.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/alto/train.py b/alto/train.py index 4dbab96..df8ff66 100644 --- a/alto/train.py +++ b/alto/train.py @@ -4,6 +4,7 @@ from typing import Iterable, Any from contextlib import contextmanager +import os import time import torch from torchtitan.components.loss import IGNORE_INDEX @@ -110,14 +111,6 @@ def __init__(self, config: TitanTrainer.Config): logger.info("data replay buffer disabled") self.enable_data_cache = False - deosc_config = DeOscillationConfig( - enable=True, - period=4, - ratio_threshold=16.0, - log_freq=1, - ) - enable_de_oscillation(self.optimizers, deosc_config) - def cache_input(self, microbatches: list[tuple[dict[str, torch.Tensor], torch.Tensor]]): if self.enable_data_cache: self._input_cache = microbatches @@ -199,6 +192,17 @@ def train_step( data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]], ): if self.training_mode: + # FIXME: This is a hack to enable de-oscillation at a specific step. + deosc_step = int(os.environ.get("DEOSC_STEP", "0")) + if deosc_step > 0 and self.step == deosc_step: + deosc_config = DeOscillationConfig( + enable=True, + period=200, + ratio_threshold=8.0, + log_freq=1, + ) + enable_de_oscillation(self.optimizers, deosc_config) + return super().train_step(data_iterator) # Keep these variables local to shorten the code as these are From acf9b8e471bacec427336c2e1c765da61516d061 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 28 May 2026 10:02:21 +0000 Subject: [PATCH 4/9] hotfix: de-osc with FSDP for GPT-OSS 20B model --- alto/components/optimizer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/alto/components/optimizer.py b/alto/components/optimizer.py index 8da5f2b..c575818 100644 --- a/alto/components/optimizer.py +++ b/alto/components/optimizer.py @@ -144,16 +144,27 @@ def _make_qdq_fn_for(cfg: TrainingOpConfig) -> QdqFn: if cfg.precision == "mxfp4": def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: + # hotfix for GPT-OSS 20B model + original_rows = -1 + if w.dim() == 2 and w.shape[0] % 32 != 0: + # for 2-D w, if the first dim is not divisible by 32, pad it to be divisible by 32 + original_rows = w.shape[0] + w = torch.nn.functional.pad(w, (0, 0, 0, 32 - original_rows % 32)) + data_lp, scales = convert_to_mxfp4( w, axis=axis, is_2d_block=is_2d_block, ) - return convert_from_mxfp4( + dequantized = convert_from_mxfp4( data_lp, scales, output_dtype=w.dtype, axis=axis, is_2d_block=is_2d_block, ) + # hotfix for GPT-OSS 20B model + if original_rows != -1: + dequantized = dequantized[:original_rows, :] + return dequantized elif cfg.precision == "nvfp4": From fa6c29db7ab752e08c401cc156857822a26d097e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 2 Jun 2026 02:27:06 +0000 Subject: [PATCH 5/9] feat: add DEOSC_RATIO env --- alto/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/alto/train.py b/alto/train.py index df8ff66..223a755 100644 --- a/alto/train.py +++ b/alto/train.py @@ -194,11 +194,12 @@ def train_step( if self.training_mode: # FIXME: This is a hack to enable de-oscillation at a specific step. deosc_step = int(os.environ.get("DEOSC_STEP", "0")) + ratio_threshold = float(os.environ.get("DEOSC_RATIO", "8.0")) if deosc_step > 0 and self.step == deosc_step: deosc_config = DeOscillationConfig( enable=True, period=200, - ratio_threshold=8.0, + ratio_threshold=ratio_threshold, log_freq=1, ) enable_de_oscillation(self.optimizers, deosc_config) From e3a551659f470206875dcd976497a2f67564afa9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 2 Jun 2026 03:06:25 +0000 Subject: [PATCH 6/9] expr: llama3 mxfp4 --- alto/models/llama3/config_registry.py | 28 +++++++++++++++------- alto/models/llama3/configs/lpt_recipe.yaml | 8 +++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index e499b0c..89b4dfb 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -123,24 +123,34 @@ def llama3_1b_lpt() -> Trainer.Config: def llama3_8b_pretrain() -> Trainer.Config: config = llama3_8b_orig() config.hf_assets_path = "/huggingface/hub/models--unsloth--Llama-3.1-8B/snapshots/3f0d51f8e5640f98f1a96ea9044a0e55c0a83814" + config.dump_folder = "llama3_8b-mi308-pretrain-subset-gbs384-lr1e-4-outputs" config.metrics.log_freq = 1 + config.metrics.enable_tensorboard = True config.profiling.enable_profiling = False - config.training.steps = 0 - config.training.local_batch_size = 1 + config.training.steps = 5000 + config.training.local_batch_size = 3 + config.training.global_batch_size = 384 config.training.seq_len = 8192 - config.dataloader.dataset = "c4_test" + config.optimizer.lr = 1e-4 + config.lr_scheduler.min_lr_factor = 0.0 + config.lr_scheduler.warmup_steps = 500 + config.lr_scheduler.decay_ratio = 0.9 + config.lr_scheduler.decay_type = "cosine" + config.dataloader.dataset = "megatron" + config.dataloader.dataset_path = "/workspace/workspace/megatron_dataset/data/c4-train.en_6_text_document.idx" config.parallelism.expert_parallel_degree = 1 config.parallelism.expert_tensor_parallel_degree = 1 - config.parallelism.tensor_parallel_degree = 8 + config.parallelism.tensor_parallel_degree = 2 config.activation_checkpoint.mode = "none" config.checkpoint.enable = False config.checkpoint.interval = 10 config.checkpoint.initial_load_path = "/huggingface/hub/models--unsloth--Llama-3.1-8B/snapshots/3f0d51f8e5640f98f1a96ea9044a0e55c0a83814" config.checkpoint.initial_load_in_hf = False - config.validator.enable = True - config.validator.dataloader.dataset = "wikitext_test" - config.validator.freq = 10 - config.validator.steps = 10 + config.validator.enable = False + config.validator.dataloader.dataset = "megatron" + config.validator.dataloader.dataset_path = "/workspace/workspace/megatron_dataset/data/c4-validation-91205-samples.en_text_document.idx" + config.validator.freq = 768 + config.validator.steps = 64 config.debug.seed = 1234 return config @@ -156,7 +166,7 @@ def llama3_8b_opt() -> Trainer.Config: def llama3_8b_lpt() -> Trainer.Config: config = llama3_8b_pretrain() - config.training.steps = 1000 + config.dump_folder = "llama3_8b-mi308-pretrain-subset-mxfp4gemm_1d2d-hadamard-sr-gbs384-lr1e-4-outputs" config.model_converters = ModelConvertersContainer.Config( converters=[ModelOptConverter.Config(recipe="./alto/models/llama3/configs/lpt_recipe.yaml",)],) return config diff --git a/alto/models/llama3/configs/lpt_recipe.yaml b/alto/models/llama3/configs/lpt_recipe.yaml index b1cfbdc..2a12a1d 100644 --- a/alto/models/llama3/configs/lpt_recipe.yaml +++ b/alto/models/llama3/configs/lpt_recipe.yaml @@ -4,3 +4,11 @@ training_stage: scheme: "mxfp4" targets: ["Linear"] ignore: ["output"] + use_2dblock_x: false + use_2dblock_w: true + use_hadamard: true + use_sr_grad: true + use_dge: false + clip_mode: none + two_level_scaling: none + lora_rank: 0 From a1961371039dfaf0194fd73d12fb38e61a2b61c3 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 3 Jun 2026 03:42:18 +0000 Subject: [PATCH 7/9] expr: decrease lbs of llama3-8b --- alto/models/llama3/config_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index 89b4dfb..d2df171 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -128,7 +128,7 @@ def llama3_8b_pretrain() -> Trainer.Config: config.metrics.enable_tensorboard = True config.profiling.enable_profiling = False config.training.steps = 5000 - config.training.local_batch_size = 3 + config.training.local_batch_size = 2 config.training.global_batch_size = 384 config.training.seq_len = 8192 config.optimizer.lr = 1e-4 @@ -140,7 +140,7 @@ def llama3_8b_pretrain() -> Trainer.Config: config.dataloader.dataset_path = "/workspace/workspace/megatron_dataset/data/c4-train.en_6_text_document.idx" config.parallelism.expert_parallel_degree = 1 config.parallelism.expert_tensor_parallel_degree = 1 - config.parallelism.tensor_parallel_degree = 2 + config.parallelism.tensor_parallel_degree = 1 config.activation_checkpoint.mode = "none" config.checkpoint.enable = False config.checkpoint.interval = 10 From 6ec0b63a6a50d3eabfb7fe3f0fadf4a369d5b50c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 8 Jun 2026 06:25:50 +0000 Subject: [PATCH 8/9] fix: nvfp4 de-osc padding --- alto/components/optimizer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/alto/components/optimizer.py b/alto/components/optimizer.py index c575818..3b602cb 100644 --- a/alto/components/optimizer.py +++ b/alto/components/optimizer.py @@ -169,16 +169,27 @@ def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: elif cfg.precision == "nvfp4": def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: + # hotfix for GPT-OSS 20B model + original_rows = -1 + if w.dim() == 2 and w.shape[0] % 16 != 0: + # for 2-D w, if the first dim is not divisible by 16, pad it to be divisible by 16 + original_rows = w.shape[0] + w = torch.nn.functional.pad(w, (0, 0, 0, 16 - original_rows % 16)) + data_lp, scales = convert_to_nvfp4( w, axis=axis, is_2d_block=is_2d_block, ) - return convert_from_nvfp4( + dequantized = convert_from_nvfp4( data_lp, scales, output_dtype=w.dtype, axis=axis, is_2d_block=is_2d_block, ) + # hotfix for GPT-OSS 20B model + if original_rows != -1: + dequantized = dequantized[:original_rows, :] + return dequantized else: raise ValueError( f"de-oscillation only supports FP4 wrappers, " From 614be7d5a103c8d270e459ce66b2b6a5b37841df Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 9 Jun 2026 05:02:53 +0000 Subject: [PATCH 9/9] feat: enable nvfp4 tensor-wise scaling in de-osc --- alto/components/optimizer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/alto/components/optimizer.py b/alto/components/optimizer.py index 3b602cb..3cbe240 100644 --- a/alto/components/optimizer.py +++ b/alto/components/optimizer.py @@ -167,6 +167,7 @@ def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: return dequantized elif cfg.precision == "nvfp4": + use_outer_scale = cfg.two_level_scaling == "tensorwise" def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: # hotfix for GPT-OSS 20B model @@ -176,8 +177,16 @@ def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: original_rows = w.shape[0] w = torch.nn.functional.pad(w, (0, 0, 0, 16 - original_rows % 16)) + outer_scale = ( + torch.empty(1, dtype=torch.float32, device=w.device) + if use_outer_scale + else None + ) + data_lp, scales = convert_to_nvfp4( w, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale, + update_outer_scale=use_outer_scale, ) dequantized = convert_from_nvfp4( data_lp, @@ -185,6 +194,7 @@ def qdq(w: torch.Tensor, axis: int) -> torch.Tensor: output_dtype=w.dtype, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale, ) # hotfix for GPT-OSS 20B model if original_rows != -1: