diff --git a/README.md b/README.md index 2445947..7752931 100644 --- a/README.md +++ b/README.md @@ -1 +1,49 @@ # Sparse-Activations + +## Modifiers + +The `modifiers` package is a collection of drop-in replacements for standard PyTorch activation functions, normalization layers, and linear/convolution layers, with a focus on sparsity-inducing variants. + +### Activations + +Most of these are decorated with `@analytical_activation_module`, which optionally stores input/output tensors on the forward pass. + +- **ReLUSquared** - Just ReLU followed by squaring: $f(x) = (\max(0, x))^2$. +- **BSiLU** - A shifted SiLU variant: $f(x) = (x + \alpha) \cdot \sigma(x) - \alpha/2$. Comes from [this paper](https://arxiv.org/html/2505.22074v1). Smoother gradients than ReLU, and `alpha` is configurable. +- **SUGARBSiLU** - Uses ReLU in the forward pass but BSiLU's gradient in the backward pass (surrogate gradient trick). Same paper as above. +- **NoisyReLU** - Adds noise during training based on the negative part of the input. The noise scale is controlled by a learnable parameter `p` and a constant `c`. Based on [this paper](https://arxiv.org/pdf/1603.00391). +- **QuantileReLU** - Zeros out activations below a given quantile threshold instead of just below zero. Supports several modes: shifted sparsity, unsigned, continuous, etc. +- **TopKSparseGELU** - GELU but only the top-k% of activations survive (the rest get zeroed). Uses the `@topk_sparse_module` decorator under the hood. + +All activations are accessible by string name (e.g. `'ReLUSquared'`, `'TopKSparseGELU-50'`) through a built-in name map, so you don't have to import classes manually if you don't want to. Presets with common sparsity levels (10%, 25%, 50%, 75%, 90%) are included. + +### Normalizations + +Custom normalization layers that shift the centering point from the mean to a quantile - this effectively biases the normalization toward sparser outputs. 
+ +- **QuantileBatchNorm2d** - Like BatchNorm2d, but uses a quantile of the activations as the mean estimate. Supports global, batchwise, or channelwise quantile computation. +- **QuantileMeanBatchNorm2d** - Same quantile-based mean, but keeps the standard variance calculation. A middle ground. +- **QuantileLayerNorm** - LayerNorm variant with quantile-based centering. Also supports running stats tracking and configurable quantile search modes. + +Same deal as activations - string names with sparsity presets are available (e.g. `'QuantileBatchNorm2d-50'`). + +### Decorators + +Module decorators that can wrap any `nn.Module`: + +- `@analytical_activation_module` (and `@analytical_linear_module` for linear/convolution layers) - Adds `in_activation` / `out_activation` attributes that capture tensors during forward. Enable by constructing the module with `debug_info=True`. +- `@topk_sparse_module` - Adds top-k sparsity to any activation. Set `sparsity_level` (strictly between 0 and 1) and choose whether sparsity is applied before or after the base activation with `post_sparsity`. + +### Replacing layers in an existing model + +```python +from modifiers import replace_activation, replace_normalization + +# Swap all GELU activations for ReLUSquared +replace_activation(model, original_activation='GELU', replaced_activation='ReLUSquared') + +# Swap BatchNorm2d for quantile-based variant at 50% sparsity +replace_normalization(model, original_normalization='BatchNorm2d', replaced_normalization='QuantileBatchNorm2d-50') +``` + +Both functions walk the module tree recursively and return a list of the newly created layers, in case you need to track them. 
diff --git a/activations.py b/activations.py deleted file mode 100644 index fd14202..0000000 --- a/activations.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Module, ReLU -from torch.autograd import Function -# relu squared with optional signing - -class ReLUSquared(nn.Module): - def forward(self, x): - return F.relu(x).square() - - -class NoisyReLU(nn.Module): - def __init__(self, alpha=1.0, c=1.0, noise_type='half-normal'): - super().__init__() - self.alpha = alpha - self.c = c - self.noise_type = noise_type - self.p = nn.Parameter(torch.randn(1)) - - def forward(self, x): - if not self.training: - # Pure ReLU at test time - maintains sparsity - return F.relu(x) - - # Training time with noise - h_x = F.relu(x) - u_x = x - delta = torch.where(x < 0, -x, torch.zeros_like(x)) - - sigma = self.c * torch.square(torch.sigmoid(self.p * delta) - 0.5) - direction = torch.where(x < 0, torch.ones_like(x), torch.zeros_like(x)) - - if self.noise_type == 'half-normal': - epsilon = torch.abs(torch.randn_like(x)) - noise = direction * sigma * epsilon - else: - epsilon = torch.randn_like(x) - noise = direction * sigma * epsilon - - return self.alpha * h_x + (1 - self.alpha) * u_x + noise - -ALPHA = 1.67 - -def b_silu_forward(x): - sigma_x = torch.sigmoid(x) - return (x + ALPHA) * sigma_x - ALPHA / 2.0 - -def b_silu_backward(x): - # Derivative of B-SiLU: sigma(x) + (x + alpha) * sigma(x) * (1 - sigma(x)) - sigma_x = torch.sigmoid(x) - return sigma_x + (x + ALPHA) * sigma_x * (1.0 - sigma_x) - -# Define the custom autograd function for SUGAR (Surrogate Gradient for ReLU) -class SUGARBSiLUFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - # Save input 'x' for the backward pass - ctx.save_for_backward(x) - # Forward pass is simply ReLU(x) - return F.relu(x) - - @staticmethod - def backward(ctx, grad_output): - # Backward pass uses the B-SiLU derivative as the surrogate gradient - 
x, = ctx.saved_tensors - # Calculate the B-SiLU derivative at input 'x' - b_silu_grad = b_silu_backward(x) - # Multiply with the incoming gradient - return grad_output * b_silu_grad - -# Define the nn.Module wrapper -class SUGARBSiLU(nn.Module): - def forward(self, x): - return SUGARBSiLUFunction.apply(x) diff --git a/modifiers/__init__.py b/modifiers/__init__.py new file mode 100644 index 0000000..80d04fa --- /dev/null +++ b/modifiers/__init__.py @@ -0,0 +1,90 @@ +from .decorators import analytical_activation_module, analytical_linear_module, topk_sparse_module +from .activations import ( + ReLUSquared, + ReLUSquaredClipped, + GELUSquared, + GELUSquaredClipped, + + QuantileReLU, + NoisyReLU, + + BSiLU, + SUGARBSiLU, + + TopKSparseGELU, + + ActivationClass, +) +from .normalizations import ( + BatchNorm2dPreStop, + LayerNormPreStop, + + QuantileBatchNorm2d, + QuantileLayerNorm, + QuantileMeanBatchNorm2d, + + NormalizationClass, +) +from .linears import ( + TopKSparseLinear, + TopKSparseConv2d, + TopKSparseConv1d, + + LinearClass, +) +from .modify import ( + replace_activation, + replace_normalization, + replace_linear, + + make_analytical_activation, + make_analytical_linear +) + +__all__ = [ + # Decorators + 'topk_sparse_module', + 'analytical_activation_module', + 'analytical_linear_module', + + # Activations + 'ReLUSquared', + "ReLUSquaredClipped", + 'GELUSquared', + 'GELUSquaredClipped', + + 'QuantileReLU', + 'NoisyReLU', + + 'BSiLU', + 'SUGARBSiLU', + + 'TopKSparseGELU', + + 'ActivationClass', + + # Normalizations + 'BatchNorm2dPreStop', + 'LayerNormPreStop', + + 'QuantileBatchNorm2d', + 'QuantileLayerNorm', + 'QuantileMeanBatchNorm2d', + + 'NormalizationClass', + + # Linears + 'TopKSparseLinear', + 'TopKSparseConv2d', + 'TopKSparseConv1d', + + 'LinearClass', + + # Modifiers + 'replace_activation', + 'replace_normalization', + 'replace_linear', + + 'make_analytical_activation', + 'make_analytical_linear', +] diff --git a/modifiers/activations.py 
@analytical_activation_module
class ReLUSquared(nn.ReLU):
    """
    Squared ReLU activation: f(x) = (max(0, x))^2.

    The rectification is delegated to ``nn.ReLU.forward``; the result is then
    squared element-wise into a new tensor.
    """

    def forward(self, input):
        # FIXME: an in-place square_() would avoid one allocation, but it
        # caused autograd issues in some cases, so the out-of-place form stays.
        return super().forward(input).square()


@analytical_activation_module
class ReLUSquaredClipped(ReLUSquared):
    """
    Squared ReLU with an upper bound: f(x) = min((max(0, x))^2, clip_value).

    Args:
        clip_value: Largest value the activation may return (default 15.0).
        *args, **kwargs: Forwarded unchanged to ``ReLUSquared``.
    """

    def __init__(self, *args, clip_value: float = 15.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_value = clip_value

    def forward(self, input):
        squared = super().forward(input)
        # clamp_ clips the freshly created squared tensor in place and
        # returns that same tensor.
        return squared.clamp_(max=self.clip_value)
@analytical_activation_module
class GELUSquared(nn.GELU):
    """
    Squared GELU activation: f(x) = (GELU(x))^2.

    The GELU itself is computed by ``nn.GELU.forward``; its output tensor is
    then squared in place.
    """

    def forward(self, input):
        # square_ mutates the GELU output in place and returns it.
        return super().forward(input).square_()


@analytical_activation_module
class GELUSquaredClipped(GELUSquared):
    """
    Squared GELU with an upper bound: f(x) = min((GELU(x))^2, clip_value).

    Args:
        clip_value: Largest value the activation may return (default 15.0).
        *args, **kwargs: Forwarded unchanged to ``GELUSquared``.
    """

    def __init__(self, *args, clip_value: float = 15.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_value = clip_value

    def forward(self, input):
        squared = super().forward(input)
        # Clip in place; clamp_ returns the tensor it modified.
        return squared.clamp_(max=self.clip_value)
class _SUGARBSiLUFunction(torch.autograd.Function):
    """ReLU in the forward pass, B-SiLU derivative as the surrogate gradient."""

    @staticmethod
    def forward(ctx, input, alpha):
        # The surrogate derivative is evaluated at the input, so keep it.
        ctx.save_for_backward(input)
        ctx.alpha = alpha
        return torch.relu(input)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        sigma_x = torch.sigmoid(x)
        # d/dx [(x + alpha) * sigmoid(x) - alpha / 2]
        surrogate_grad = sigma_x + (x + ctx.alpha) * sigma_x * (1.0 - sigma_x)
        # One gradient per forward argument; ``alpha`` is a plain number.
        return grad_output * surrogate_grad, None


# B-SiLU

@analytical_activation_module
class BSiLU(nn.SiLU):
    """
    B-SiLU activation: f(x) = (x + alpha) * sigmoid(x) - alpha / 2.

    ``alpha`` is a hyperparameter that controls the shape of the function.
    The forward pass is built from differentiable torch operations, so
    autograd derives the gradient on its own.

    see: https://arxiv.org/html/2505.22074v1 for more details on B-SiLU and its properties.

    Args:
        alpha: Shift hyperparameter (default 1.67).
        *args, **kwargs: Forwarded to ``nn.SiLU``; ``inplace=True`` is rejected.
    """

    def __init__(self, *args, alpha=1.67, **kwargs):
        super().__init__(*args, **kwargs)

        assert getattr(self, 'inplace', False) == False, "Without triton kernel it is impossible to make inplace B-SiLU"

        self.alpha = alpha

    def forward(self, input):
        sigma_x = torch.sigmoid(input)
        return (input + self.alpha) * sigma_x - self.alpha / 2.0

    def backward(self, grad_output):
        """
        Manually apply the analytical B-SiLU derivative to ``grad_output``.

        This is a helper for analysis only: autograd never calls a method
        named ``backward`` on an ``nn.Module``. The derivative is evaluated at
        ``self.in_activation``, which is only populated when the module was
        built with ``debug_info=True`` and has run a forward pass.

        Raises:
            RuntimeError: If no input activation has been captured yet.
        """
        x = self.in_activation
        if x is None:
            raise RuntimeError(
                "No captured input activation: construct the module with "
                "debug_info=True and run a forward pass before calling backward()."
            )
        # Derivative of B-SiLU: sigma(x) + (x + alpha) * sigma(x) * (1 - sigma(x))
        sigma_x = torch.sigmoid(x)
        b_silu_grad = sigma_x + (x + self.alpha) * sigma_x * (1.0 - sigma_x)
        return grad_output * b_silu_grad


# Sugar B-SiLU

@analytical_activation_module
class SUGARBSiLU(nn.ReLU):
    """
    SUGAR activation with a B-SiLU surrogate: the forward pass is ReLU(x),
    while the backward pass multiplies the incoming gradient by the B-SiLU
    derivative sigma(x) + (x + alpha) * sigma(x) * (1 - sigma(x)).

    The surrogate gradient is wired in through a custom
    ``torch.autograd.Function``; defining ``backward`` on the module alone
    has no effect on autograd.

    see: https://arxiv.org/html/2505.22074v1 for more details on SUGAR-BSiLU and its properties.

    Args:
        alpha: B-SiLU shift hyperparameter used by the surrogate (default 1.67).
        *args, **kwargs: Forwarded to ``nn.ReLU``; ``inplace=True`` is rejected.
    """

    def __init__(self, *args, alpha=1.67, **kwargs):
        super().__init__(*args, **kwargs)

        assert getattr(self, 'inplace', False) == False, "Without triton kernel it is impossible to make inplace SUGAR-BSiLU"

        self.alpha = alpha

    def forward(self, input):
        return _SUGARBSiLUFunction.apply(input, self.alpha)

    def backward(self, grad_output):
        """Manual surrogate-gradient helper; see ``BSiLU.backward``."""
        return BSiLU.backward(self, grad_output)
# Noisy ReLU

@analytical_activation_module
class NoisyReLU(nn.ReLU):
    """
    ReLU variant that injects noise on the negative side while training.

    In eval mode the module is a plain ReLU. In training mode the output is
    ``leaky_relu(x, 1 - alpha) + noise`` where
    ``noise = (sigmoid(-p * min(x, 0)) - 0.5)^2 * c * eps``. ``eps`` is
    standard normal noise, made non-negative when ``noise_type`` is
    ``'half-normal'``. The noise sign is flipped when ``1 - alpha`` is negative.

    see: https://arxiv.org/pdf/1603.00391 for more details on NoisyReLU and its properties.

    Args:
        alpha: Mixing coefficient; ``1 - alpha`` is the negative slope.
        c: Constant scale applied to the sampled noise.
        noise_type: ``'half-normal'`` takes the absolute value of the noise;
            any other value keeps it signed.
        *args, **kwargs: Forwarded to ``nn.ReLU``; ``inplace=True`` is rejected.
    """

    def __init__(
        self,
        *args,
        alpha=1.0,
        c=1.0,
        noise_type='half-normal',
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        assert getattr(self, 'inplace', False) == False, "Without triton kernel it is impossible to make inplace NoisyReLU"

        self.alpha = alpha
        self.c = c
        self.noise_type = noise_type
        # Learnable sharpness of the noise-scale curve.
        self.p = nn.Parameter(torch.randn(1))

    def forward(self, x):
        # Inference: deterministic ReLU.
        if not self.training:
            return F.relu(x)

        # Keep only the negative part; non-negative entries become zero,
        # which makes their noise scale exactly zero as well.
        negative_part = torch.where(x < 0, x, 0.0)
        noise_scale = negative_part.mul(-self.p).sigmoid().sub(0.5).square()

        sample = torch.randn_like(x)
        if self.noise_type == 'half-normal':
            sample.abs_()
        noise = noise_scale.mul(sample.mul_(self.c))

        negative_slope = 1 - self.alpha
        if negative_slope < 0:
            noise = noise.neg_()

        return F.leaky_relu(x, negative_slope) + noise
# Quantile-based ReLU

@analytical_activation_module
class QuantileReLU(nn.ReLU):
    """
    ReLU whose cut-off is a quantile of the input rather than zero.

    The threshold is the ``int(sparsity_level * N) + 1``-th smallest value
    along dimension 0, where ``N = input.size(0)``. It is computed separately
    for every position of the remaining dimensions.

    Modes:
        * ``sparsity_level is None``: ordinary ReLU.
        * default: values below the threshold are zeroed, the rest are kept.
        * ``shifted_sparsity``: ``relu(x - t) + t``, i.e. values below the
          threshold are raised to the threshold.
        * ``continuous``: ``relu(x - t)``, so the output is continuous at the
          threshold.
        * ``signed=False``: thresholding runs on ``|x|`` and the original sign
          is re-applied to the result.

    Args:
        sparsity_level: Fraction used to pick the threshold, or ``None``.
            NOTE(review): not validated here; values >= 1 would request a k
            larger than the dimension size — confirm callers stay in [0, 1).
        shifted_sparsity: Enable the shifted mode described above.
        signed: When False, threshold magnitudes and restore signs afterwards.
        continuous: Enable the continuous mode described above.
        **kwargs: Forwarded to ``nn.ReLU``.
    """

    def __init__(
        self,
        sparsity_level: Optional[float] = None,
        shifted_sparsity: bool = False,
        signed = True,
        continuous = False,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.sparsity_level = sparsity_level
        self.shifted_sparsity = shifted_sparsity
        self.signed = signed
        self.continuous = continuous

    def forward(self, input):
        sign_mask = None
        if not self.signed:
            # Remember the sign (+1 / -1) and work on magnitudes.
            sign_mask = torch.ones_like(input)
            sign_mask.masked_fill_(input < 0, -1)
            input = input.abs()

        if self.sparsity_level is None:
            output = super().forward(input)
        else:
            # k-th smallest value along dim 0, kept broadcastable via keepdim.
            n_remove = int(self.sparsity_level * input.size(dim=0)) + 1
            kth_values = torch.kthvalue(input, n_remove, dim=0, keepdim=True).values

            if self.shifted_sparsity or self.continuous:
                output = super().forward(input - kth_values)
                if not self.continuous:
                    output = output + kth_values
            else:
                output = input * (input >= kth_values)

        if sign_mask is not None:
            # The sign must be restored on the result; the previous code
            # rebound ``input`` here, so the returned tensor lost its sign.
            output = output * sign_mask

        return output

    def extra_repr(self) -> str:
        return (f'sparsity_level={self.sparsity_level}, shifted_sparsity={self.shifted_sparsity}, '
                f'signed={self.signed}, continuous={self.continuous}, {super().extra_repr()}')


##########################################################################
#                           Sparse activations                           #
##########################################################################

@analytical_activation_module
@topk_sparse_module
class TopKSparseGELU(nn.GELU):
    """
    GELU combined with the threshold-based sparsification added by
    ``topk_sparse_module``: activations below the ``sparsity_level`` quantile
    are zeroed, so only the largest values survive.
    """
##########################################################################
#            Mapping from string names to activation classes             #
##########################################################################

# Sparsity presets, in percent, used to generate the "<Name>-<pct>" entries.
_QUANTILE_RELU_PERCENTS = (10, 25, 50, 75, 90)
_TOPK_GELU_PERCENTS = (10, 25, 50, 75, 80, 85, 90, 95, 99)

# Registry used to resolve activation names to classes or pre-configured
# factories. The "A*" entries are analytical wrappers of the stock torch
# activations.
# NOTE(review): the "-AS" presets only set max_tracked_cnt, which the sparse
# module consults solely when running_stats=True — confirm callers also pass
# running_stats / running_shape.
ACTIVATION_NAMES_MAP = {
    'ReLU': nn.ReLU,
    'PReLU': nn.PReLU,
    'GELU': nn.GELU,
    'SiLU': nn.SiLU,

    'AReLU': analytical_activation_module(nn.ReLU),
    'APReLU': analytical_activation_module(nn.PReLU),
    'AGELU': analytical_activation_module(nn.GELU),
    'ASiLU': analytical_activation_module(nn.SiLU),

    'ReLUSquared': ReLUSquared,
    'ReLUSquaredClipped': ReLUSquaredClipped,
    'GELUSquared': GELUSquared,
    'GELUSquaredClipped': GELUSquaredClipped,

    'BSiLU': BSiLU,
    'SUGARBSiLU': SUGARBSiLU,
    'NoisyReLU': NoisyReLU,

    'QuantileReLU': QuantileReLU,
    **{
        f'QuantileReLU-{pct}': partial(QuantileReLU, sparsity_level=pct / 100)
        for pct in _QUANTILE_RELU_PERCENTS
    },

    'TopKSparseGELU': TopKSparseGELU,
    **{
        f'TopKSparseGELU-{pct}': partial(TopKSparseGELU, sparsity_level=pct / 100)
        for pct in _TOPK_GELU_PERCENTS
    },

    'TopKSparseGELU-AS': partial(TopKSparseGELU, max_tracked_cnt=50_000),
    **{
        f'TopKSparseGELU-{pct}-AS': partial(TopKSparseGELU, max_tracked_cnt=50_000, sparsity_level=pct / 100)
        for pct in _TOPK_GELU_PERCENTS
    },
}

# Static list of every registry key, for type checkers. Keep it in sync with
# ACTIVATION_NAMES_MAP by hand: Literal cannot be built dynamically.
ActivationClass = Literal[
    'ReLU', 'PReLU', 'GELU', 'SiLU',
    'AReLU', 'APReLU', 'AGELU', 'ASiLU',
    'ReLUSquared', 'ReLUSquaredClipped',
    'GELUSquared', 'GELUSquaredClipped',
    'BSiLU', 'SUGARBSiLU', 'NoisyReLU',
    'QuantileReLU', 'QuantileReLU-10', 'QuantileReLU-25', 'QuantileReLU-50', 'QuantileReLU-75', 'QuantileReLU-90',
    'TopKSparseGELU', 'TopKSparseGELU-10', 'TopKSparseGELU-25', 'TopKSparseGELU-50', 'TopKSparseGELU-75', 'TopKSparseGELU-80',
    'TopKSparseGELU-85', 'TopKSparseGELU-90', 'TopKSparseGELU-95', 'TopKSparseGELU-99',
    'TopKSparseGELU-AS', 'TopKSparseGELU-10-AS', 'TopKSparseGELU-25-AS', 'TopKSparseGELU-50-AS', 'TopKSparseGELU-75-AS',
    'TopKSparseGELU-80-AS', 'TopKSparseGELU-85-AS', 'TopKSparseGELU-90-AS', 'TopKSparseGELU-95-AS', 'TopKSparseGELU-99-AS',
]
##########################################################################
#                    Decorator for analyzing modules                     #
##########################################################################

def analytical_activation_module(cls: Type[nn.Module]) -> Type[nn.Module]:
    """
    Create an "analytical" subclass of ``cls`` that can record activations.

    The returned class accepts an extra keyword-only ``debug_info`` flag.
    When it is True at construction time, a forward hook is registered that
    stores detached copies of the first positional input and of the output in
    ``in_activation`` / ``out_activation`` after every forward pass. With the
    default ``debug_info=False`` no hook is registered and both attributes
    stay ``None``.

    Args:
        cls: The ``nn.Module`` subclass to wrap.

    Returns:
        A subclass of ``cls`` named ``Analytical<cls.__name__>``.
    """

    def _capture_activations(module, input, output):
        # Write to the module the hook is invoked on rather than to a
        # closed-over instance: hook functions are shared by reference when a
        # module is deep-copied, so a closure over ``self`` would make every
        # copy record into the original module.
        if not module.debug_info:
            # Allows capture to be paused by setting debug_info = False.
            return
        module.in_activation = input[0].clone().detach()
        module.out_activation = output.clone().detach()

    class AnalyticalModule(cls):
        def __init__(
            self,
            *args,
            debug_info: bool = False,
            **kwargs,
        ):
            super().__init__(*args, **kwargs)

            self.debug_info = debug_info
            self.in_activation = None
            self.out_activation = None

            if self.debug_info:
                self.register_forward_hook(_capture_activations)

        def extra_repr(self) -> str:
            return f'debug_info={self.debug_info}, {super().extra_repr()}'

    AnalyticalModule.__name__ = f"Analytical{cls.__name__}"
    # Keep the wrapped class's documentation visible on the wrapper.
    AnalyticalModule.__doc__ = cls.__doc__

    return AnalyticalModule
def analytical_linear_module(cls: Type[nn.Module]) -> Type[nn.Module]:
    """
    Create an "analytical" subclass of a weighted layer class.

    On top of the wrapped layer the returned class provides:

    * ``debug_info`` (keyword-only constructor flag): when True at
      construction, a forward hook and a full backward hook are registered.
      They store detached copies of the first input / the output
      (``in_activation``, ``out_activation``) and of the first input / output
      gradients (``grad_in_activation``, ``grad_out_activation``).
    * ``record_initial_weights()`` and the ``z_score`` property to measure how
      far ``weight`` has drifted from a recorded snapshot.
    * ``weight_mean``, ``weight_norm``, ``weight_grad_mean`` and
      ``weight_grad_norm`` convenience properties.

    Args:
        cls: The ``nn.Module`` subclass to wrap; it must expose ``weight``.

    Returns:
        A subclass of ``cls`` named ``Analytical<cls.__name__>``.
    """

    # Hooks use the ``module`` argument instead of a closed-over ``self``:
    # hook functions are shared by reference when a module is deep-copied, so
    # a closure would make every copy record into the original instance.
    def _forward_hook(module, input, output):
        if not module.debug_info:
            return
        module.in_activation = input[0].clone().detach()
        module.out_activation = output.clone().detach()

    def _backward_hook(module, grad_input, grad_output):
        if not module.debug_info:
            return
        if grad_input[0] is not None:
            module.grad_in_activation = grad_input[0].clone().detach()
        if grad_output[0] is not None:
            module.grad_out_activation = grad_output[0].clone().detach()

    class AnalyticalModule(cls):
        def __init__(
            self,
            *args,
            debug_info: bool = False,
            **kwargs,
        ):
            super().__init__(*args, **kwargs)

            self.debug_info = debug_info

            self.in_activation = None
            self.out_activation = None

            self.grad_in_activation = None
            self.grad_out_activation = None

            self.initial_weight_copy = None

            # Registration stays opt-in: a full backward hook changes how
            # autograd treats the layer's inputs and outputs, so it is not
            # installed unless requested.
            if self.debug_info:
                self.register_forward_hook(_forward_hook)
                self.register_full_backward_hook(_backward_hook)

        def extra_repr(self) -> str:
            return f'debug_info={self.debug_info}, {super().extra_repr()}'

        def record_initial_weights(self):
            """Snapshot the current weights as the reference for ``z_score``."""
            self.initial_weight_copy = self.weight.detach().clone()

        def _weight_grad(self):
            # Raise ValueError rather than letting ``None.abs()`` raise
            # AttributeError: inside a property, nn.Module.__getattr__ would
            # turn that into a misleading "object has no attribute" message.
            grad = self.weight.grad
            if grad is None:
                raise ValueError("Weight gradient is not available. Run a backward pass before accessing gradient statistics.")
            return grad

        @property
        def z_score(self):
            """Twice the mean weight change divided by the summed std-devs
            of the snapshot and current weights (plus 1e-8).

            Raises:
                ValueError: If ``record_initial_weights()`` was never called.
            """
            if self.initial_weight_copy is None:
                raise ValueError("Initial weights not recorded. Call record_initial_weights() before accessing z_score.")
            return (self.weight - self.initial_weight_copy).mean() / (self.initial_weight_copy.std() + self.weight.std() + 1e-8) * 2.0

        @property
        def weight_grad_mean(self):
            """Mean absolute value of the weight gradient."""
            return self._weight_grad().abs().mean()

        @property
        def weight_grad_norm(self):
            """Norm of the weight gradient."""
            return self._weight_grad().norm()

        @property
        def weight_mean(self):
            """Mean of the weights."""
            return self.weight.mean()

        @property
        def weight_norm(self):
            """Norm of the weights."""
            return self.weight.norm()

    AnalyticalModule.__name__ = f"Analytical{cls.__name__}"

    return AnalyticalModule
##########################################################################
#                    Decorator for sparse activations                    #
##########################################################################

def topk_sparse_module(cls: Type[nn.Module]) -> Type[nn.Module]:
    """
    Create a subclass of ``cls`` that zeroes values below a quantile threshold.

    The threshold is the ``int(sparsity_level * n) + 1``-th smallest value of
    the tensor, taken over the last dimension of the view selected by
    ``quantile_search_mode`` and then averaged over dimension 0. Everything
    strictly below the threshold is set to zero.

    Extra constructor arguments of the returned class (all keyword-only):
        sparsity_level: Fraction in (0, 1), or ``None`` to disable sparsity.
        post_sparsity: Sparsify the wrapped module's output (True, default)
            or its input (False).
        quantile_search_mode: ``'global'`` flattens everything,
            ``'batchwise'`` flattens per sample, ``'channelwise'`` flattens
            per sample and channel.
        running_stats: Track a running threshold in the ``running_treshold``
            buffer; it is used instead of the batch threshold in eval mode.
        running_shape: Shape of the running-threshold buffer; required when
            ``running_stats`` is True.
        momentum: Weight of the newest threshold in the running average.
        max_tracked_cnt: Once this many batches were tracked the running
            threshold is frozen and always used. Only consulted when
            ``running_stats`` is True.

    The buffer and variable names keep the historical ``treshold`` spelling
    because they are part of the ``state_dict`` keys.

    Raises:
        AssertionError: If ``sparsity_level`` is outside (0, 1).
        ValueError: If ``running_stats`` is True but ``running_shape`` is None.
    """

    class SparseModule(cls):
        def __init__(
            self,
            *args,
            sparsity_level: Optional[float] = None,
            post_sparsity: bool = True,
            quantile_search_mode: Literal['global', 'batchwise', 'channelwise'] = 'channelwise',

            running_stats: bool = False,
            running_shape: Optional[torch.Size] = None,
            momentum: float = 0.1,
            max_tracked_cnt: Optional[int] = None,
            **kwargs
        ):
            super().__init__(*args, **kwargs)

            assert sparsity_level is None or (0.0 < sparsity_level < 1.0), "sparsity_level must be in (0, 1)"

            self.sparsity_level = sparsity_level
            self.post_sparsity = post_sparsity
            self.running_stats = running_stats
            self.momentum = momentum

            self.quantile_search_mode = quantile_search_mode
            # reshape instead of view so non-contiguous tensors also work.
            self.quantile_view_fn = {
                'global': lambda x: x.reshape(-1),
                'batchwise': lambda x: x.reshape(x.size(0), -1),
                'channelwise': lambda x: x.reshape(x.size(0), x.size(1), -1),
            }[self.quantile_search_mode]

            if self.running_stats:
                if running_shape is None:
                    raise ValueError("running_shape must be provided when running_stats=True")
                self.register_buffer('running_treshold', torch.zeros(running_shape))
                self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
            else:
                # Register the names anyway so the attributes always exist.
                self.register_buffer('running_treshold', None)
                self.register_buffer('num_batches_tracked', None)

            self.max_tracked_cnt = max_tracked_cnt

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.post_sparsity:
                x = super().forward(x)

            if self.sparsity_level is not None:
                if self.running_stats and self.max_tracked_cnt is not None and self.max_tracked_cnt <= self.num_batches_tracked:
                    # Tracking budget exhausted: freeze the running threshold.
                    treshold = self.running_treshold
                elif self.training or self.running_treshold is None:
                    # Compute the quantile threshold from the current batch.
                    x_viewed = self.quantile_view_fn(x)
                    total_elements = x_viewed.size(dim=-1)
                    n_remove = int(self.sparsity_level * total_elements) + 1

                    treshold = torch.kthvalue(x_viewed, n_remove, dim=-1).values
                    treshold = treshold.mean(dim=0)  # Average over batch

                    if self.running_stats:
                        with torch.no_grad():
                            if self.num_batches_tracked != 0:
                                self.running_treshold = (1 - self.momentum) * self.running_treshold + self.momentum * treshold
                            else:
                                # Detach: the batch threshold was computed
                                # with autograd enabled and a buffer must not
                                # keep the graph alive.
                                self.running_treshold = treshold.detach().clone()
                            self.num_batches_tracked += 1
                else:
                    treshold = self.running_treshold

                # _review_as_with_batch comes from .utils; presumably it
                # reshapes the threshold to broadcast against x — TODO confirm.
                treshold = _review_as_with_batch(treshold, x.shape)
                mask = x < treshold
                # Out-of-place on purpose: with post_sparsity=False ``x`` is
                # the caller's tensor and must not be mutated.
                x = x.masked_fill(mask, 0.0)

            if not self.post_sparsity:
                x = super().forward(x)

            return x

        def extra_repr(self) -> str:
            return f'sparsity_level={self.sparsity_level}, {super().extra_repr()}'

    SparseModule.__name__ = f"TopK{cls.__name__}"

    return SparseModule
##########################################################################
#                             Sparse linears                             #
##########################################################################

@analytical_linear_module
@topk_sparse_module
class TopKSparseLinear(nn.Linear):
    """
    Linear layer combined with the threshold-based sparsification added by
    ``topk_sparse_module``: values below the ``sparsity_level`` quantile are
    zeroed, so only the largest activations are kept.

    NOTE(review): with 2-D ``(batch, features)`` tensors the default
    ``'channelwise'`` search mode views the data with a trailing dimension of
    size 1, so the k-th value is the element itself — confirm this mode is
    the intended default for plain linear layers.
    """


@analytical_linear_module
@topk_sparse_module
class TopKSparseConv2d(nn.Conv2d):
    """
    2D convolution combined with the threshold-based sparsification added by
    ``topk_sparse_module``: values below the ``sparsity_level`` quantile are
    zeroed, so only the largest activations are kept.
    """
@analytical_linear_module
@topk_sparse_module
class TopKSparseConv1d(nn.Conv1d):
    """
    1D convolution combined with the threshold-based sparsification added by
    ``topk_sparse_module``: values below the ``sparsity_level`` quantile are
    zeroed, so only the largest activations are kept.
    """


##########################################################################
#              Mapping from string names to linear classes               #
##########################################################################

# Sparsity presets, in percent, used to generate the "<Name>-<pct>" entries.
_LINEAR_SPARSITY_PERCENTS = (10, 25, 50, 75, 90)


def _sparse_presets(name, layer_cls):
    """Return ``{name: layer_cls}`` plus one ``name-<pct>`` factory per preset."""
    presets = {name: layer_cls}
    for pct in _LINEAR_SPARSITY_PERCENTS:
        presets[f'{name}-{pct}'] = partial(layer_cls, sparsity_level=pct / 100)
    return presets


# Registry used to resolve layer names to classes or pre-configured
# factories. The "A*" entries are analytical wrappers of the stock torch layers.
LINEAR_NAMES_MAP = {
    'Linear': nn.Linear,
    'Conv2d': nn.Conv2d,
    'Conv1d': nn.Conv1d,

    'ALinear': analytical_linear_module(nn.Linear),
    'AConv2d': analytical_linear_module(nn.Conv2d),
    'AConv1d': analytical_linear_module(nn.Conv1d),

    **_sparse_presets('TopKSparseLinear', TopKSparseLinear),
    **_sparse_presets('TopKSparseConv2d', TopKSparseConv2d),
    **_sparse_presets('TopKSparseConv1d', TopKSparseConv1d),
}
sparsity_level=0.25), + 'TopKSparseConv1d-50': partial(TopKSparseConv1d, sparsity_level=0.50), + 'TopKSparseConv1d-75': partial(TopKSparseConv1d, sparsity_level=0.75), + 'TopKSparseConv1d-90': partial(TopKSparseConv1d, sparsity_level=0.90), +} + +LinearClass = Literal[ + 'Linear', 'Conv2d', 'Conv1d', + 'ALinear', 'AConv2d', 'AConv1d', + + 'TopKSparseLinear', 'TopKSparseLinear-10', 'TopKSparseLinear-25', 'TopKSparseLinear-50', 'TopKSparseLinear-75', 'TopKSparseLinear-90', + 'TopKSparseConv2d', 'TopKSparseConv2d-10', 'TopKSparseConv2d-25', 'TopKSparseConv2d-50', 'TopKSparseConv2d-75', 'TopKSparseConv2d-90', + 'TopKSparseConv1d', 'TopKSparseConv1d-10', 'TopKSparseConv1d-25', 'TopKSparseConv1d-50', 'TopKSparseConv1d-75', 'TopKSparseConv1d-90', +] diff --git a/modifiers/modify.py b/modifiers/modify.py new file mode 100644 index 0000000..3d0b376 --- /dev/null +++ b/modifiers/modify.py @@ -0,0 +1,156 @@ +import inspect +from typing import Dict, List +from functools import partial + +import torch.nn as nn + +from .activations import ACTIVATION_NAMES_MAP, ActivationClass +from .normalizations import NORMALIZATION_NAMES_MAP, NormalizationClass +from .linears import LINEAR_NAMES_MAP, LinearClass + + +########################################################################## +# Function to replace layers in a module # +########################################################################## + +def create_layer_with_parameters(new_cls: nn.Module, old_layer: nn.Module, params_to_copy: List[str] = [], params_to_set: Dict[str, object] = {}) -> nn.Module: + new_layer_args = inspect.signature(new_cls.__init__).parameters + has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in new_layer_args.values()) + layer_params = { + k: v for k, v in old_layer.__dict__.items() + if ( + k not in params_to_copy + and not k.startswith('_') + and (k in new_layer_args or has_kwargs) + ) + } + + for p, v in params_to_set.items(): + if p in new_layer_args or has_kwargs: + layer_params[p] 
def _resolve_replacement(names_map, replacement, kind: str):
    """
    Turn ``replacement`` into something callable that builds the new layer.

    Accepts a registry name, a module class / factory such as a
    ``functools.partial``, or a module instance. An instance cannot be called
    as a constructor, so its class is used instead.
    """
    if isinstance(replacement, str):
        assert replacement in names_map, f"Replaced {kind} '{replacement}' is not supported."
        return names_map[replacement]
    if isinstance(replacement, nn.Module):
        return type(replacement)
    assert callable(replacement), f"Replaced {kind} '{replacement}' is not supported."
    return replacement


def _registered_class(entry):
    """Return the class behind a registry entry, unwrapping partial presets."""
    while isinstance(entry, partial):
        entry = entry.func
    return entry


def replace_activation(
    module: nn.Module,
    original_activation: ActivationClass = 'GELU',
    replaced_activation: ActivationClass | nn.Module = 'ReLU',
    **params_to_set,
) -> List[nn.Module]:
    """
    Replace every ``original_activation`` child inside ``module``.

    Children are matched on their exact type, so subclasses — including the
    analytical wrappers — are left untouched. The root module itself is
    never replaced.

    Args:
        module: Model that is modified in place.
        original_activation: Registry name of the activation to look for.
        replaced_activation: Registry name, module class/factory, or module
            instance (its class is used) of the replacement.
        **params_to_set: Extra settings forwarded to
            ``create_layer_with_parameters``.

    Returns:
        The newly created layers, in traversal order.
    """
    assert original_activation in ACTIVATION_NAMES_MAP, f"Original activation '{original_activation}' is not supported."

    original_cls = _registered_class(ACTIVATION_NAMES_MAP[original_activation])
    replaced_cls = _resolve_replacement(ACTIVATION_NAMES_MAP, replaced_activation, 'activation')

    resulting_layers: List[nn.Module] = []

    # 'weight' carries over learned parameters (e.g. PReLU) when the
    # replacement also has such an attribute.
    params_to_copy = ['training', 'num_batches_tracked', 'running_treshold', 'weight']

    for layer in module.modules():
        # Snapshot the children because they are reassigned inside the loop.
        for child_name, child in list(layer.named_children()):
            if type(child) is not original_cls:
                continue
            new_activation = create_layer_with_parameters(replaced_cls, child, params_to_copy=params_to_copy, params_to_set=params_to_set)

            setattr(layer, child_name, new_activation)
            resulting_layers.append(new_activation)

    return resulting_layers


def replace_normalization(
    module: nn.Module,
    original_normalization: NormalizationClass = 'BatchNorm2d',
    replaced_normalization: NormalizationClass | nn.Module = 'QuantileBatchNorm2d-50',
    **params_to_set,
) -> List[nn.Module]:
    """
    Replace every ``original_normalization`` child inside ``module``.

    Works like ``replace_activation``. Running statistics and, when both
    layers have them, the affine ``weight`` / ``bias`` and
    ``num_batches_tracked`` are carried over to the replacement.

    Args:
        module: Model that is modified in place.
        original_normalization: Registry name of the layer type to look for.
        replaced_normalization: Registry name, module class/factory, or module
            instance (its class is used) of the replacement.
        **params_to_set: Extra settings forwarded to
            ``create_layer_with_parameters``.

    Returns:
        The newly created layers, in traversal order.
    """
    assert original_normalization in NORMALIZATION_NAMES_MAP, f"Original normalization '{original_normalization}' is not supported."

    original_cls = _registered_class(NORMALIZATION_NAMES_MAP[original_normalization])
    replaced_cls = _resolve_replacement(NORMALIZATION_NAMES_MAP, replaced_normalization, 'normalization')

    resulting_layers: List[nn.Module] = []

    params_to_copy = ['training', 'running_mean', 'running_var', 'num_batches_tracked', 'weight', 'bias']

    for layer in module.modules():
        for child_name, child in list(layer.named_children()):
            if type(child) is not original_cls:
                continue
            new_normalization = create_layer_with_parameters(replaced_cls, child, params_to_copy=params_to_copy, params_to_set=params_to_set)

            setattr(layer, child_name, new_normalization)
            resulting_layers.append(new_normalization)

    return resulting_layers
def replace_linear(
    module: nn.Module,
    original_linear: LinearClass = 'Linear',
    replaced_linear: LinearClass | nn.Module = 'TopKSparseLinear-50',
    **params_to_set,
) -> List[nn.Module]:
    """Replace every ``original_linear`` child inside ``module``.

    The module tree is walked with ``module.modules()``; each direct child
    whose exact type matches the class registered under ``original_linear``
    is swapped for a layer built by ``create_layer_with_parameters``. The
    root ``module`` itself is never replaced, only its descendants.

    Args:
        module: Model that is modified in place.
        original_linear: Registry name of the layer type to look for.
        replaced_linear: Registry name of the replacement (or an
            ``nn.Module`` accepted by the assertion below).
        **params_to_set: Extra settings forwarded to the new layers.

    Returns:
        The newly created layers, in the order they were inserted.
    """
    assert original_linear in LINEAR_NAMES_MAP, f"Original linear '{original_linear}' is not supported."
    assert replaced_linear in LINEAR_NAMES_MAP or isinstance(replaced_linear, nn.Module), f"Replaced linear '{replaced_linear}' is not supported."

    original_cls = LINEAR_NAMES_MAP.get(original_linear)
    replaced_cls = LINEAR_NAMES_MAP.get(replaced_linear) if isinstance(replaced_linear, str) else replaced_linear

    resulting_layers: List[nn.Module] = []

    # 'weight' must travel together with 'bias'; otherwise the replacement
    # keeps the old bias but starts from freshly initialised weights.
    params_to_copy = ['training', 'transposed', 'output_padding', 'weight', 'bias']

    for layer in module.modules():
        # Snapshot the children: they are reassigned inside the loop.
        for child_name, child in list(layer.named_children()):
            if type(child) is original_cls:
                new_linear = create_layer_with_parameters(
                    replaced_cls,
                    child,
                    params_to_copy=params_to_copy,
                    params_to_set=params_to_set,
                )

                setattr(layer, child_name, new_linear)
                resulting_layers.append(new_linear)

    return resulting_layers


##########################################################################
#               Wrapper functions for common modifications               #
##########################################################################


# Shortcut for replacing activations with plain ReLU.
relufiaction = partial(replace_activation, replaced_activation='ReLU')


def make_analytical_activation(module: nn.Module) -> List[nn.Module]:
    """Swap GELU, ReLU and SiLU children for their 'A'-prefixed registry entries.

    Returns:
        All layers created by the three replacement passes, concatenated in
        the order GELU, ReLU, SiLU.
    """
    created: List[nn.Module] = []
    for original_activation, analytical_cls in (
        ('GELU', 'AGELU'),
        ('ReLU', 'AReLU'),
        ('SiLU', 'ASiLU'),
    ):
        created.extend(
            replace_activation(module, original_activation=original_activation, replaced_activation=analytical_cls)
        )
    return created


def make_analytical_linear(module: nn.Module) -> List[nn.Module]:
    """Swap Linear, Conv2d and Conv1d children for their 'A'-prefixed registry entries.

    Returns:
        All layers created by the three replacement passes, concatenated in
        the order Linear, Conv2d, Conv1d.
    """
    created: List[nn.Module] = []
    for original_linear, analytical_cls in (
        ('Linear', 'ALinear'),
        ('Conv2d', 'AConv2d'),
        ('Conv1d', 'AConv1d'),
    ):
        created.extend(
            replace_linear(module, original_linear=original_linear, replaced_linear=analytical_cls)
        )
    return created
class BatchNorm2d(nn.Module):
    """
    Hand-crafted implementation of 2D Batch Normalization.

    Normalizes a 4D input over the batch and spatial dimensions, independently
    for each channel. ``forward`` accepts optional precomputed statistics so
    that subclasses can substitute their own centering/scaling values.
    """
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True
    ):
        """
        Args:
            num_features (int): Number of feature channels in the input.
            eps (float): A small value to avoid division by zero.
            momentum (float): Momentum for running mean and variance.
            affine (bool): If True, learnable scale and shift parameters are used.
            track_running_stats (bool): If True, running mean and variance are tracked.
        """
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

            self.running_mean: torch.Tensor | None
            self.running_var: torch.Tensor | None
            self.num_batches_tracked: torch.Tensor | None
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(
        self,
        x: torch.Tensor,
        batch_mean: Optional[torch.Tensor] = None,
        batch_var: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass for batch normalization.

        Batch statistics are used in training mode, or whenever no running
        statistics exist; otherwise the stored running statistics are used and
        any ``batch_mean`` / ``batch_var`` passed in is ignored.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).
            batch_mean (Optional[torch.Tensor]): Optional precomputed batch mean.
            batch_var (Optional[torch.Tensor]): Optional precomputed batch variance.
        Returns:
            torch.Tensor: Normalized tensor of the same shape as input.
        Raises:
            ValueError: If ``x`` is not 4D or its channel count differs from
                ``num_features``.
        """
        if x.dim() != 4:
            raise ValueError("Expected input tensor to be 4D (N, C, H, W)")
        channels = x.shape[1]
        if channels != self.num_features:
            raise ValueError(f"Expected input with {self.num_features} channels, got {channels}")

        if self.training or (self.running_mean is None) or (self.running_var is None):
            if batch_mean is None:
                batch_mean = x.mean(dim=(0, 2, 3))
            if batch_var is None:
                # Biased variance, measured around whichever mean is in use.
                batch_var = (x - _review_as_with_batch(batch_mean, x.shape)).square().mean(dim=(0, 2, 3))

            if self.track_running_stats:
                # Running statistics are bookkeeping only: keep them out of the
                # autograd graph so the buffers never hold on to a batch graph.
                with torch.no_grad():
                    if self.momentum is None or self.momentum == 1.0:
                        self.running_mean = batch_mean.detach()
                        self.running_var = batch_var.detach()
                    else:
                        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
                    self.num_batches_tracked += 1
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var

        x_normalized = (x - _review_as_with_batch(batch_mean, x.shape)) / (_review_as_with_batch(batch_var, x.shape) + self.eps).sqrt()

        if self.affine:
            x_normalized = x_normalized * _review_as_with_batch(self.weight, x.shape) + _review_as_with_batch(self.bias, x.shape)

        return x_normalized

    def extra_repr(self) -> str:
        return (f'num_features={self.num_features}, eps={self.eps}, '
                f'momentum={self.momentum}, affine={self.affine}, '
                f'track_running_stats={self.track_running_stats}')
+ """ + def __init__( + self, + normalized_shape, + eps: float = 0.00001, + elementwise_affine: bool = True, + bias: bool = True + ): + """ + Args: + normalized_shape (int or list or torch.Size): Input shape from an expected input. + eps (float): A small value to avoid division by zero. + elementwise_affine (bool): If True, learnable scale and shift parameters are used. + bias (bool): If True, adds a learnable bias to the output. + """ + super(LayerNorm, self).__init__() + + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + + self.normalized_shape = normalized_shape + self.eps = eps + self.elementwise_affine = elementwise_affine + self.use_bias = bias + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(normalized_shape)) + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + else: + self.register_parameter('bias', None) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def forward( + self, + x: torch.Tensor, + layer_mean: Optional[torch.Tensor] = None, + layer_var: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Forward pass for layer normalization. + + Args: + x (torch.Tensor): Input tensor. + layer_mean (Optional[torch.Tensor]): Optional precomputed layer mean. + layer_var (Optional[torch.Tensor]): Optional precomputed layer variance. + Returns: + torch.Tensor: Normalized tensor of the same shape as input. 
class BatchNorm2dPreStop(BatchNorm2d):
    """
    BatchNorm2d that switches to its stored statistics after a batch budget.

    While fewer than ``max_tracked_cnt`` batches have been counted the layer
    behaves exactly like ``BatchNorm2d``. From then on the current
    ``running_mean`` / ``running_var`` are handed to the parent ``forward`` as
    the batch statistics. With ``max_tracked_cnt=None`` (or without running
    statistics) the switch never happens.
    """

    def __init__(self, *args, max_tracked_cnt: Optional[int] = None, **kwargs):
        super().__init__(*args, **kwargs)

        self.max_tracked_cnt = max_tracked_cnt

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        budget_exhausted = (
            self.track_running_stats
            and self.max_tracked_cnt is not None
            and self.max_tracked_cnt <= self.num_batches_tracked
        )

        if budget_exhausted:
            # Feed the stored statistics in place of freshly computed ones.
            return super().forward(x, batch_mean=self.running_mean, batch_var=self.running_var)

        return super().forward(x, batch_mean=None, batch_var=None)

    def extra_repr(self) -> str:
        return f'(pre-stop) max_tracked_cnt={self.max_tracked_cnt}, {super().extra_repr()}'
class QuantileBatchNorm2d(BatchNorm2d):
    """
    BatchNorm2d that centers on a quantile of the activations instead of the mean.

    During training the ``sparsity_level`` quantile of the input is computed
    (via ``torch.kthvalue``) and passed to ``BatchNorm2d.forward`` as the batch
    mean; the variance is then measured around that value by the parent class.
    With ``sparsity_level=None`` the layer is a plain ``BatchNorm2d``.

    Args:
        sparsity_level: Quantile used as the centering point, or ``None``.
        quantile_search_mode: Which elements form one quantile search:
            ``'global'`` (whole tensor), ``'batchwise'`` (per sample) or
            ``'channelwise'`` (per sample and channel).
        max_tracked_cnt: Once this many batches have been tracked, the stored
            ``running_mean`` is used as the centering point instead of a
            freshly computed quantile. ``None`` disables this.

    Raises:
        KeyError: If ``quantile_search_mode`` is not one of the three modes.
    """

    _QUANTILE_SEARCH_MODES = ('global', 'batchwise', 'channelwise')

    def __init__(
        self,
        *args,
        sparsity_level: Optional[float] = None,
        quantile_search_mode: Literal['global', 'batchwise', 'channelwise'] = 'channelwise',
        max_tracked_cnt: Optional[int] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.sparsity_level = sparsity_level
        self.quantile_search_mode = quantile_search_mode
        self.max_tracked_cnt = max_tracked_cnt

        if quantile_search_mode not in self._QUANTILE_SEARCH_MODES:
            raise KeyError(quantile_search_mode)

    def quantile_view_fn(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``x`` so that its last dimension spans one quantile search.

        Implemented as a method rather than a lambda stored on the instance so
        the module stays picklable and the helper does not show up in
        ``__dict__``. ``reshape`` is used so non-contiguous inputs also work.
        """
        if self.quantile_search_mode == 'global':
            return x.reshape(-1)
        if self.quantile_search_mode == 'batchwise':
            return x.reshape(x.size(0), -1)
        return x.reshape(x.size(0), x.size(1), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.sparsity_level is None:
            return super().forward(x)

        batch_mean = None

        if self.track_running_stats and self.max_tracked_cnt is not None and self.max_tracked_cnt <= self.num_batches_tracked:
            batch_mean = self.running_mean
        elif self.training:
            x_viewed = self.quantile_view_fn(x)

            # 1-based rank of the quantile element, clamped so that a
            # sparsity_level of 1.0 selects the maximum instead of overflowing.
            row_length = x_viewed.size(dim=-1)
            kth_element = min(int(self.sparsity_level * row_length) + 1, row_length)
            batch_mean = torch.kthvalue(x_viewed, kth_element, dim=-1).values
            batch_mean = batch_mean.mean(dim=0)  # Average over batch

        # In eval mode batch_mean stays None and the parent falls back to its
        # running statistics.
        return super().forward(x, batch_mean=batch_mean)

    def extra_repr(self) -> str:
        return f'quantile={self.sparsity_level}, {super().extra_repr()}'
class QuantileLayerNorm(LayerNorm):
    """
    LayerNorm that centers on a quantile of the activations instead of the mean.

    The ``sparsity_level`` quantile of the input is computed with
    ``torch.kthvalue``, averaged over the batch dimension and passed to
    ``LayerNorm.forward`` as the layer mean. With ``sparsity_level=None`` the
    layer is a plain ``LayerNorm``.

    Args:
        sparsity_level: Quantile in the open interval (0, 1), or ``None``.
        quantile_search_mode: Which elements form one quantile search:
            ``'global'``, ``'batchwise'`` or ``'channelwise'``.
        track_running_stats: If True, keep a running estimate of the quantile
            that is used outside training.
        running_shape: Shape of the initial running buffer. The buffer is
            overwritten by the first tracked batch; ``None`` creates a scalar
            placeholder.
        momentum: Update rate of the running estimate.
        max_tracked_cnt: Once this many batches have been tracked, the stored
            running estimate is used instead of a fresh quantile.

    Raises:
        AssertionError: If ``sparsity_level`` is outside (0, 1).
        KeyError: If ``quantile_search_mode`` is not one of the three modes.
    """

    _QUANTILE_SEARCH_MODES = ('global', 'batchwise', 'channelwise')

    def __init__(
        self,
        *args,
        sparsity_level: Optional[float] = None,
        quantile_search_mode: Literal['global', 'batchwise', 'channelwise'] = 'channelwise',

        track_running_stats: bool = False,
        running_shape: Optional[torch.Size] = None,
        momentum: float = 0.1,
        max_tracked_cnt: Optional[int] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        assert sparsity_level is None or (0.0 < sparsity_level < 1.0), \
            "sparsity_level must be in the range (0.0, 1.0)"

        self.sparsity_level = sparsity_level
        self.quantile_search_mode = quantile_search_mode
        self.track_running_stats = track_running_stats
        self.momentum = momentum
        self.max_tracked_cnt = max_tracked_cnt

        if quantile_search_mode not in self._QUANTILE_SEARCH_MODES:
            raise KeyError(quantile_search_mode)

        if self.track_running_stats:
            # torch.ones(None) raises, so fall back to a scalar placeholder.
            initial_shape = () if running_shape is None else running_shape
            self.register_buffer('running_layer_mean', torch.ones(initial_shape))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

            self.running_layer_mean: torch.Tensor | None
            self.num_batches_tracked: torch.Tensor | None
        else:
            self.register_buffer('running_layer_mean', None)
            self.register_buffer('num_batches_tracked', None)

    def quantile_view_fn(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``x`` so that its last dimension spans one quantile search.

        Implemented as a method rather than a lambda stored on the instance so
        the module stays picklable and the helper does not show up in
        ``__dict__``. ``reshape`` is used so non-contiguous inputs also work.
        """
        if self.quantile_search_mode == 'global':
            return x.reshape(-1)
        if self.quantile_search_mode == 'batchwise':
            return x.reshape(x.size(0), -1)
        return x.reshape(x.size(0), x.size(1), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for quantile-centered layer normalization.

        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Normalized tensor of the same shape as input.
        """
        if self.sparsity_level is None:
            return super().forward(x)

        layer_mean = None

        if self.track_running_stats and self.max_tracked_cnt is not None and self.max_tracked_cnt <= self.num_batches_tracked:
            layer_mean = self.running_layer_mean
        elif self.training or self.running_layer_mean is None:
            x_viewed = self.quantile_view_fn(x)

            # 1-based rank of the quantile element within each search row.
            row_length = x_viewed.size(dim=-1)
            kth_element = min(int(self.sparsity_level * row_length) + 1, row_length)
            layer_mean = torch.kthvalue(x_viewed, kth_element, dim=-1).values
            layer_mean = layer_mean.mean(dim=0)  # Average over batch

            if self.track_running_stats:
                with torch.no_grad():
                    if self.num_batches_tracked != 0:
                        self.running_layer_mean = (1 - self.momentum) * self.running_layer_mean + self.momentum * layer_mean
                    else:
                        # Detach: no_grad() does not strip the graph from an
                        # already computed tensor, and the buffer must not
                        # keep the first batch's graph alive.
                        self.running_layer_mean = layer_mean.detach()
                    self.num_batches_tracked += 1
        else:
            layer_mean = self.running_layer_mean

        layer_mean = _review_as_with_batch(layer_mean, x.shape)

        return super().forward(x, layer_mean=layer_mean)

    def extra_repr(self):
        return f'quantile={self.sparsity_level}, {super().extra_repr()}'
max_tracked_cnt=50_000), + 'QuantileBatchNorm2d-10-AS': partial(QuantileBatchNorm2d, max_tracked_cnt=50_000, sparsity_level=0.1), + 'QuantileBatchNorm2d-25-AS': partial(QuantileBatchNorm2d, max_tracked_cnt=50_000, sparsity_level=0.25), + 'QuantileBatchNorm2d-50-AS': partial(QuantileBatchNorm2d, max_tracked_cnt=50_000, sparsity_level=0.50), + 'QuantileBatchNorm2d-75-AS': partial(QuantileBatchNorm2d, max_tracked_cnt=50_000, sparsity_level=0.75), + 'QuantileBatchNorm2d-90-AS': partial(QuantileBatchNorm2d, max_tracked_cnt=50_000, sparsity_level=0.90), + + 'QuantileMeanBatchNorm2d': QuantileMeanBatchNorm2d, + 'QuantileMeanBatchNorm2d-10': partial(QuantileMeanBatchNorm2d, sparsity_level=0.1), + 'QuantileMeanBatchNorm2d-25': partial(QuantileMeanBatchNorm2d, sparsity_level=0.25), + 'QuantileMeanBatchNorm2d-50': partial(QuantileMeanBatchNorm2d, sparsity_level=0.50), + 'QuantileMeanBatchNorm2d-75': partial(QuantileMeanBatchNorm2d, sparsity_level=0.75), + 'QuantileMeanBatchNorm2d-90': partial(QuantileMeanBatchNorm2d, sparsity_level=0.90), + + 'QuantileLayerNorm': QuantileLayerNorm, + 'QuantileLayerNorm-10': partial(QuantileLayerNorm, sparsity_level=0.1), + 'QuantileLayerNorm-25': partial(QuantileLayerNorm, sparsity_level=0.25), + 'QuantileLayerNorm-50': partial(QuantileLayerNorm, sparsity_level=0.50), + 'QuantileLayerNorm-75': partial(QuantileLayerNorm, sparsity_level=0.75), + 'QuantileLayerNorm-90': partial(QuantileLayerNorm, sparsity_level=0.90), + + 'QuantileLayerNorm-AS': partial(QuantileLayerNorm, max_tracked_cnt=50_000), + 'QuantileLayerNorm-10-AS': partial(QuantileLayerNorm, max_tracked_cnt=50_000, sparsity_level=0.1), + 'QuantileLayerNorm-25-AS': partial(QuantileLayerNorm, max_tracked_cnt=50_000, sparsity_level=0.25), + 'QuantileLayerNorm-50-AS': partial(QuantileLayerNorm, max_tracked_cnt=50_000, sparsity_level=0.50), + 'QuantileLayerNorm-75-AS': partial(QuantileLayerNorm, max_tracked_cnt=50_000, sparsity_level=0.75), + 'QuantileLayerNorm-90-AS': 
partial(QuantileLayerNorm, max_tracked_cnt=50_000, sparsity_level=0.90), +} + +NormalizationClass = Literal[ + 'BatchNorm2d', + 'LayerNorm', + + 'BatchNorm2dPreStop', + 'LayerNormPreStop', + + 'QuantileBatchNorm2d', 'QuantileBatchNorm2d-10', 'QuantileBatchNorm2d-25', 'QuantileBatchNorm2d-50', 'QuantileBatchNorm2d-75', 'QuantileBatchNorm2d-90', + 'QuantileBatchNorm2d-AS', 'QuantileBatchNorm2d-10-AS', 'QuantileBatchNorm2d-25-AS', 'QuantileBatchNorm2d-50-AS', 'QuantileBatchNorm2d-75-AS', 'QuantileBatchNorm2d-90-AS', + + 'QuantileMeanBatchNorm2d', 'QuantileMeanBatchNorm2d-10', 'QuantileMeanBatchNorm2d-25', 'QuantileMeanBatchNorm2d-50', 'QuantileMeanBatchNorm2d-75', 'QuantileMeanBatchNorm2d-90', + + 'QuantileLayerNorm', 'QuantileLayerNorm-10', 'QuantileLayerNorm-25', 'QuantileLayerNorm-50', 'QuantileLayerNorm-75', 'QuantileLayerNorm-90', + 'QuantileLayerNorm-AS', 'QuantileLayerNorm-10-AS', 'QuantileLayerNorm-25-AS', 'QuantileLayerNorm-50-AS', 'QuantileLayerNorm-75-AS', 'QuantileLayerNorm-90-AS', +] diff --git a/modifiers/utils.py b/modifiers/utils.py new file mode 100644 index 0000000..0ba6ef5 --- /dev/null +++ b/modifiers/utils.py @@ -0,0 +1,10 @@ +import torch + + +########################################################################## +# Auxiliary functions for normalization layers # +########################################################################## + +def _review_as_with_batch(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor: + extra_dims = len(target_shape) - len(x.shape) - 1 + return x.view(1, *x.shape, *((1,) * extra_dims)) diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 0000000..72976e7 --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,90 @@ +import pytest +import torch +import torch.nn as nn + + +# ============================================================================ +# FIXTURES - Setup and teardown helpers +# ============================================================================ + +@pytest.fixture 
+def sample_input(): + """Standard input tensor (2, 3, 4, 4).""" + return torch.randn(2, 3, 4, 4) + + +@pytest.fixture +def small_input(): + """Small input tensor (1, 1, 4, 4).""" + return torch.randn(1, 1, 4, 4) + + +@pytest.fixture +def batch_input(): + """Large batch input tensor (8, 64, 16, 16).""" + return torch.randn(8, 64, 16, 16) + + +@pytest.fixture +def layer_norm_input(): + """Input suitable for LayerNorm (4, 100).""" + return torch.randn(4, 100) + + +@pytest.fixture +def batch_norm_input(): + """Input suitable for BatchNorm2d (8, 32, 16, 16).""" + return torch.randn(8, 32, 16, 16) + + +@pytest.fixture +def grad_input(sample_input): + """Input tensor that requires gradients.""" + sample_input.requires_grad_(True) + return sample_input + + +@pytest.fixture +def zero_input(): + """Zero-valued input tensor.""" + return torch.zeros(2, 3, 4, 4) + + +@pytest.fixture +def large_value_input(): + """Input with very large values (1e6).""" + return torch.ones(2, 3, 4, 4) * 1e6 + + +@pytest.fixture +def negative_input(): + """Input with negative values (-5).""" + return torch.ones(2, 3, 4, 4) * -5.0 + + +@pytest.fixture +def model_with_gelu(): + """Model with GELU activations.""" + return nn.Sequential( + nn.Linear(10, 10), + nn.GELU(approximate='tanh'), + nn.Linear(10, 10), + nn.GELU(approximate='none') + ) + + +@pytest.fixture +def model_with_batchnorm(): + """Model with BatchNorm2d layers.""" + return nn.Sequential( + nn.Conv2d(3, 64, 3), + nn.BatchNorm2d(64), + nn.Conv2d(64, 128, 3), + nn.BatchNorm2d(128), + ) + + +@pytest.fixture +def device(): + """Detect and return available device (cuda or cpu).""" + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') diff --git a/tests/test_activations.py b/tests/test_activations.py new file mode 100644 index 0000000..ba6319b --- /dev/null +++ b/tests/test_activations.py @@ -0,0 +1,389 @@ +""" +Comprehensive pytest test suite for Sparse-Activations library. 
class TestReLUSquared:
    """Checks for ReLUSquared and its clipped variant."""

    @pytest.mark.parametrize('activation_cls', [ReLUSquared, ReLUSquaredClipped])
    def test_applies_relu_then_square(self, activation_cls):
        """Negative inputs map to 0, positive inputs are squared."""
        inputs = torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]])
        outputs = activation_cls()(inputs)

        assert torch.allclose(outputs, torch.tensor([[0.0, 0.0, 0.0, 1.0, 4.0]]))

    @pytest.mark.parametrize('activation_cls', [ReLUSquared, ReLUSquaredClipped])
    def test_output_is_non_negative(self, activation_cls, sample_input):
        """No output element may be below zero."""
        activation = activation_cls()

        assert (activation(sample_input) >= 0).all()

    def test_clips_output_to_max_value(self):
        """ReLUSquaredClipped caps the squared value at clip_value."""
        clipped = ReLUSquaredClipped(clip_value=2.0)
        outputs = clipped(torch.tensor([[0.0, 1.0, 2.0, 3.0]]))

        # relu -> [0, 1, 2, 3]; squared -> [0, 1, 4, 9]; clipped at 2 -> [0, 1, 2, 2]
        assert torch.allclose(outputs, torch.tensor([[0.0, 1.0, 2.0, 2.0]]))
class TestSUGARBSiLU:
    """Tests for SUGARBSiLU activation."""

    def test_forward_pass_produces_expected_output(self):
        """Forward output should equal ReLU applied to the input.

        The expectation is derived from the input ``x``. Deriving it from the
        module's own output would only verify non-negativity.
        """
        module = SUGARBSiLU(alpha=1.0)
        relu_module = nn.ReLU()

        x = torch.tensor([[-1.0, 0.0, 1.0]])
        y = module(x)

        expected = relu_module(x)
        assert torch.allclose(y, expected, atol=1e-5)
for QuantileReLU activation.""" + + def test_without_sparsity_level_behaves_like_relu(self, sample_input): + """When sparsity_level=None, should act like standard ReLU.""" + module = QuantileReLU(sparsity_level=None) + expected = nn.ReLU()(sample_input) + + y = module(sample_input) + + assert torch.allclose(y, expected) + + @pytest.mark.parametrize('sparsity_level', [0.1, 0.25, 0.5, 0.75, 0.9]) + def test_accepts_valid_sparsity_levels(self, sparsity_level): + """Should accept sparsity levels in (0, 1) range.""" + module = QuantileReLU(sparsity_level=sparsity_level) + assert module.sparsity_level == sparsity_level + + +# ============================================================================ +# TESTS - Property preservation and edge cases +# ============================================================================ + +activations_parametrization = pytest.mark.parametrize('activation_cls', [ReLUSquared, ReLUSquaredClipped, GELUSquared, GELUSquaredClipped, BSiLU, SUGARBSiLU, NoisyReLU, QuantileReLU]) + +class TestPropertyPreservation: + """Tests that dtype, device and shape are preserved by every activation.""" + + @activations_parametrization + def test_preserves_input_dtype(self, sample_input, activation_cls): + """Output dtype should match input dtype for float32 and float64.""" + module = activation_cls() + + for dtype in [torch.float32, torch.float64]: + x = sample_input.to(dtype) + y = module(x) + + assert y.dtype == dtype + + @activations_parametrization + def test_device_transfer_works(self, sample_input, device, activation_cls): + """Module should work after device transfer.""" + module = activation_cls() + module = module.to(device) + + x = sample_input.to(device) + y = module(x) + + assert y.device.type == device.type + + @activations_parametrization + def test_shape_preservation(self, sample_input, activation_cls): + """Output shape should match input shape.""" + module = activation_cls() + y = module(sample_input) + + assert y.shape == sample_input.shape + + +class TestAutogradVersioning: + 
"""Tests to ensure autograd versioning is not broken.""" + + @activations_parametrization + def test_no_inplace_breaks_in_relu_squared(self, grad_input, activation_cls): + """Backward through each activation should populate the input gradient.""" + module = activation_cls() + y = module(grad_input) + loss = y.sum() + + loss.backward() + + assert grad_input.grad is not None + + @activations_parametrization + def test_gradient_accumulation_works(self, sample_input, activation_cls): + """Repeated forward/backward passes on fresh inputs should each yield a gradient.""" + module = activation_cls() + + for _ in range(3): + x = sample_input.clone().detach().requires_grad_(True) + y = module(x) + loss = y.sum() + loss.backward() + + assert x.grad is not None + + @activations_parametrization + def test_sequential_forward_backward_passes(self, batch_input, activation_cls): + """Should support repeated passes when followed by a Linear layer in nn.Sequential.""" + model = nn.Sequential( + activation_cls(), + nn.Linear(16, 10), + ) + + for _ in range(3): + x = batch_input.clone().detach().requires_grad_(True) + y = model(x) + loss = y.sum() + loss.backward() + + assert x.grad is not None + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + @activations_parametrization + def test_single_element_batch(self, activation_cls): + """Should work with batch size 1.""" + module = activation_cls() + x = torch.randn(1, 3, 4, 4) + y = module(x) + + assert y.shape == x.shape + + @activations_parametrization + def test_small_spatial_dimensions(self, activation_cls): + """Should work with 1×1 spatial dimensions.""" + module = activation_cls() + x = torch.randn(2, 3, 1, 1) + y = module(x) + + assert y.shape == x.shape + + @activations_parametrization + def test_zero_input_produces_zero_output(self, zero_input, activation_cls): + """Every activation should map the zero_input fixture to all zeros.""" + module = activation_cls() + y = module(zero_input) + + assert torch.allclose(y, torch.zeros_like(y)) + + @activations_parametrization + def test_large_values_are_handled(self, large_value_input, 
activation_cls): + """Should handle very large input values.""" + module = activation_cls() + y = module(large_value_input) + + assert not torch.isnan(y).any() + + @pytest.mark.parametrize('activation_cls', [ReLUSquared, ReLUSquaredClipped]) + def test_negative_values_produce_zeros(self, negative_input, activation_cls): + """ReLU should zero out all negative values.""" + module = activation_cls() + y = module(negative_input) + + assert torch.allclose(y, torch.zeros_like(y)) + + +# ============================================================================ +# TESTS - Registry coverage +# ============================================================================ + +class TestActivationRegistry: + """Tests for all activations in ACTIVATION_NAMES_MAP.""" + + @pytest.mark.parametrize('name', list(ACTIVATION_NAMES_MAP.keys())) + def test_all_activations_are_instantiable(self, name): + """All registered activations should be instantiable.""" + cls_or_partial = ACTIVATION_NAMES_MAP[name] + + if callable(cls_or_partial): + instance = cls_or_partial() + assert instance is not None + + @pytest.mark.parametrize('name', list(ACTIVATION_NAMES_MAP.keys())) + def test_all_activations_have_forward_pass(self, name, sample_input): + """All activations should support forward pass.""" + cls_or_partial = ACTIVATION_NAMES_MAP[name] + + if callable(cls_or_partial): + instance = cls_or_partial() + y = instance(sample_input) + + assert y.shape == sample_input.shape diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..3f38543 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,255 @@ + + +import os +import sys + +import pytest +import torch +import torch.nn as nn + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from modifiers.decorators import analytical_module, topk_sparse_module + +from tests.fixtures import * + + +activation_parametrization = pytest.mark.parametrize('activation_cls', [nn.ReLU, 
nn.GELU]) +sparsity_parametrization = pytest.mark.parametrize('sparsity', [0.1, 0.25, 0.5, 0.75, 0.9]) +debug_info_parametrization = pytest.mark.parametrize('debug_info', [True, False]) + + +# ============================================================================ +# TESTS - Decorator functionality +# ============================================================================ + +class TestAnalyticalModuleDecorator: + """Tests for analytical_module decorator.""" + + @activation_parametrization + @debug_info_parametrization + def test_creates_new_class_with_debug_attributes(self, activation_cls, debug_info): + """Should create new class with debug_info, in_activation, out_activation.""" + AnalyticalReLU = analytical_module(activation_cls) + module = AnalyticalReLU(debug_info=debug_info) + + assert hasattr(module, 'debug_info') + assert hasattr(module, 'in_activation') + assert hasattr(module, 'out_activation') + assert module.debug_info is debug_info + + def test_debug_info_disabled_does_not_store_activations(self, sample_input): + """When debug_info=False, activations should not be stored.""" + AnalyticalReLU = analytical_module(nn.ReLU) + module = AnalyticalReLU(debug_info=False) + + _ = module(sample_input) + + assert module.in_activation is None + assert module.out_activation is None + + @activation_parametrization + def test_debug_info_enabled_stores_activations(self, sample_input, activation_cls): + """When debug_info=True, activations should be stored.""" + AnalyticalReLU = analytical_module(activation_cls) + module = AnalyticalReLU(debug_info=True) + + y = module(sample_input) + + assert module.in_activation is not None + assert module.out_activation is not None + assert torch.equal(module.in_activation, sample_input) + assert torch.equal(module.out_activation, y) + + @activation_parametrization + def test_preserves_class_name_with_analytical_prefix(self, activation_cls): + """Class name should include 'Analytical' prefix.""" + AnalyticalReLU = 
analytical_module(activation_cls) + + assert AnalyticalReLU.__name__.startswith('Analytical') + assert AnalyticalReLU.__name__.endswith(activation_cls.__name__) + + +class TestTopKSparseModuleDecorator: + """Tests for topk_sparse_module decorator.""" + + @activation_parametrization + @sparsity_parametrization + def test_creates_new_class_with_sparsity_attributes(self, activation_cls, sparsity): + """Should create new class with sparsity_level and post_sparsity.""" + SparseReLU = topk_sparse_module(activation_cls) + module = SparseReLU(sparsity_level=sparsity, post_sparsity=True) + + assert hasattr(module, 'sparsity_level') + assert hasattr(module, 'post_sparsity') + assert module.sparsity_level == sparsity + assert module.post_sparsity is True + + @pytest.mark.parametrize('invalid_sparsity', [-0.5, 1.5, 2.0]) + def test_rejects_invalid_sparsity_levels(self, invalid_sparsity): + """Should reject sparsity levels outside [0, 1] range.""" + SparseReLU = topk_sparse_module(nn.ReLU) + + with pytest.raises(AssertionError): + SparseReLU(sparsity_level=invalid_sparsity) + + @activation_parametrization + def test_no_sparsity_when_level_is_none(self, sample_input, activation_cls): + """When sparsity_level=None, output should equal standard module.""" + SparseReLU = topk_sparse_module(activation_cls) + sparse_module = SparseReLU(sparsity_level=None) + normal_module = activation_cls() + + y_sparse = sparse_module(sample_input) + y_normal = normal_module(sample_input) + + assert torch.allclose(y_sparse, y_normal) + + @activation_parametrization + def test_full_sparsity_zeros_out_all_activations(self, sample_input, activation_cls): + """When sparsity_level=1.0, all activations should be zero.""" + SparseReLU = topk_sparse_module(activation_cls) + sparse_module = SparseReLU(sparsity_level=1.0, post_sparsity=True) + + y_sparse = sparse_module(sample_input) + y_normal = torch.zeros_like(sample_input) + + assert torch.allclose(y_sparse, y_normal) + + def 
test_full_sparsity_zeros_out_all_activations_sigmoid(self, sample_input): + """With sparsity_level=1.0, output is 0 for post_sparsity=True and 0.5 for post_sparsity=False.""" + SparseReLU = topk_sparse_module(nn.Sigmoid) + sparse_module = SparseReLU(sparsity_level=1.0, post_sparsity=True) + + y_sparse = sparse_module(sample_input) + y_normal = torch.zeros_like(sample_input) + + assert torch.allclose(y_sparse, y_normal) + + SparseReLU = topk_sparse_module(nn.Sigmoid) + sparse_module = SparseReLU(sparsity_level=1.0, post_sparsity=False) + + y_sparse = sparse_module(sample_input) + y_normal = torch.full_like(sample_input, 0.5) + + assert torch.allclose(y_sparse, y_normal) + + @activation_parametrization + @sparsity_parametrization + def test_sparsity_zeros_out_small_activations(self, batch_input, activation_cls, sparsity): + """Fraction of exact zeros in the output should approximately match sparsity_level.""" + SparseReLU = topk_sparse_module(activation_cls) + module = SparseReLU(sparsity_level=sparsity, post_sparsity=True) + + y = module(batch_input) + + zero_count = (y == 0).sum().item() + total_count = y.numel() + sparsity_ratio = zero_count / total_count + + # Allow an absolute tolerance of 0.1 due to discrete top-k behavior + if activation_cls != nn.ReLU: + assert abs(sparsity_ratio - sparsity) < 0.1 + else: + # For ReLU, we expect the sparsity to be higher due to zeroing out negative values + assert sparsity_ratio >= sparsity - 0.1 + + @activation_parametrization + def test_preserves_class_name_with_topk_prefix(self, activation_cls): + """Class name should start with 'TopK' and end with the wrapped class name.""" + SparseReLU = topk_sparse_module(activation_cls) + + assert SparseReLU.__name__.startswith('TopK') + assert SparseReLU.__name__.endswith(activation_cls.__name__) + +class TestCombinedDecorators: + """Tests for using both analytical_module and topk_sparse_module together.""" + + @activation_parametrization + @sparsity_parametrization + @debug_info_parametrization + def test_combined_decorators_work_together(self, batch_input, activation_cls, 
sparsity, debug_info): + """Should apply both modifications when analytical_module wraps topk_sparse_module.""" + AnalyticalSparseReLU = analytical_module(topk_sparse_module(activation_cls)) + module = AnalyticalSparseReLU(sparsity_level=sparsity, post_sparsity=True, debug_info=debug_info) + + y = module(batch_input.clone()) + + # Check debug info attributes + if debug_info: + assert module.in_activation is not None + assert module.out_activation is not None + assert torch.equal(module.in_activation, batch_input) + assert torch.equal(module.out_activation, y) + else: + assert module.in_activation is None + assert module.out_activation is None + + # Check that the sparsity level is approximately correct + zero_count = (y == 0).sum().item() + total_count = y.numel() + sparsity_ratio = zero_count / total_count + + # For ReLU, we expect the sparsity to be higher due to zeroing out negative values + if activation_cls != nn.ReLU: + assert abs(sparsity_ratio - sparsity) < 0.1 + else: + assert sparsity_ratio >= sparsity - 0.1 + + + @activation_parametrization + @sparsity_parametrization + @debug_info_parametrization + def test_inner_application_order(self, batch_input, activation_cls, sparsity, debug_info): + """Should work in the reverse order (topk_sparse_module wrapping analytical_module).""" + AnalyticalSparseReLU = topk_sparse_module(analytical_module(activation_cls)) + module = AnalyticalSparseReLU(sparsity_level=sparsity, post_sparsity=True, debug_info=debug_info) + + y = module(batch_input.clone()) + + # Check debug info attributes + if debug_info: + assert module.in_activation is not None + assert module.out_activation is not None + assert torch.equal(module.in_activation, batch_input) + # assert not torch.equal(module.out_activation, y) # TODO: Fix this check to work with low sparsity levels + else: + assert module.in_activation is None + assert module.out_activation is None + + # Check that the sparsity level is approximately correct + zero_count = (y == 0).sum().item() + total_count = y.numel() + sparsity_ratio 
= zero_count / total_count + + # For ReLU, we expect the sparsity to be higher due to zeroing out negative values + if activation_cls != nn.ReLU: + assert abs(sparsity_ratio - sparsity) < 0.1 + else: + assert sparsity_ratio >= sparsity - 0.1 + + AnalyticalSparseReLU = topk_sparse_module(analytical_module(activation_cls)) + module = AnalyticalSparseReLU(sparsity_level=sparsity, post_sparsity=False, debug_info=debug_info) + + y = module(batch_input.clone()) + + # Check debug info attributes + if debug_info: + assert module.in_activation is not None + assert module.out_activation is not None + assert not torch.equal(module.in_activation, batch_input) + assert torch.equal(module.out_activation, y) + else: + assert module.in_activation is None + assert module.out_activation is None + + # Check sparsity level approximately correct + zero_count = (y == 0).sum().item() + total_count = y.numel() + sparsity_ratio = zero_count / total_count + + # For ReLU, we expect the sparsity to be higher due to zeroing out negative values + if activation_cls != nn.ReLU: + assert abs(sparsity_ratio - sparsity) < 0.1 + else: + assert sparsity_ratio >= sparsity - 0.1 diff --git a/tests/test_modifiers.py b/tests/test_modifiers.py new file mode 100644 index 0000000..8dc0126 --- /dev/null +++ b/tests/test_modifiers.py @@ -0,0 +1,124 @@ + +import os +import sys + +import pytest +import torch +import torch.nn as nn + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from modifiers.normalizations import ( + QuantileBatchNorm2d, +) +from modifiers.activations import ( + TopKSparseGELU, +) +from modifiers.modify import ( + replace_activation, + replace_normalization, + relufiaction, +) + +from tests.fixtures import * + + +# ============================================================================ +# TESTS - Modifier functions +# ============================================================================ + +class TestReplaceActivation: + """Tests for 
replace_activation function.""" + + def test_replaces_gelu_with_relu(self, model_with_gelu): + """Should replace GELU with ReLU.""" + layers = replace_activation(model_with_gelu, 'GELU', 'ReLU') + + assert isinstance(layers, list) + assert len(layers) == 2 + assert all(isinstance(layer, nn.ReLU) for layer in layers) + + for layer in model_with_gelu.modules(): + assert not isinstance(layer, nn.GELU) + + def test_dont_replaces_silu_with_relu(self, model_with_gelu): + """Should not replace SiLU with ReLU.""" + layers = replace_activation(model_with_gelu, 'SiLU', 'ReLU') + + assert len(layers) == 0 + + for layer in model_with_gelu.modules(): + assert not isinstance(layer, nn.ReLU) + + def test_preserves_model_functionality(self, model_with_gelu): + """Model should still be functional after replacement.""" + replace_activation(model_with_gelu, 'GELU', 'ReLU') + + x = torch.randn(2, 10) + y = model_with_gelu(x) + + assert y.shape == (2, 10) + + def test_rejects_invalid_original_activation(self, model_with_gelu): + """Should raise AssertionError for invalid original activation.""" + with pytest.raises(AssertionError): + replace_activation(model_with_gelu, 'InvalidActivation', 'ReLU') + + def test_rejects_invalid_replaced_activation(self, model_with_gelu): + """Should raise AssertionError for invalid replacement activation.""" + with pytest.raises(AssertionError): + replace_activation(model_with_gelu, 'GELU', 'InvalidActivation') + + def test_preserves_class_parameters(self, model_with_gelu): + """Should preserve parameters of replaced activations.""" + initial_approximate_values = [layer.approximate for layer in model_with_gelu if isinstance(layer, nn.GELU)] + + layers = replace_activation(model_with_gelu, 'GELU', 'TopKSparseGELU-10') + + assert len(layers) == 2 + assert all(isinstance(layer, TopKSparseGELU) for layer in layers) + assert all(layer.approximate == approximate for layer, approximate in zip(layers, initial_approximate_values)) + + +class 
TestReplaceNormalization: + """Tests for replace_normalization function.""" + + def test_replaces_batchnorm_with_quantile(self, model_with_batchnorm): + """Should replace BatchNorm2d with QuantileBatchNorm2d.""" + layers = replace_normalization( + model_with_batchnorm, 'BatchNorm2d', 'QuantileBatchNorm2d' + ) + + assert len(layers) == 2 + assert all(isinstance(layer, QuantileBatchNorm2d) for layer in layers) + + def test_infers_num_features_correctly(self, model_with_batchnorm): + """Should preserve num_features during replacement.""" + params = [] + for layer in model_with_batchnorm.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.running_mean = torch.randn(layer.num_features) + layer.running_var = torch.rand(layer.num_features) + 0.1 # Avoid zero variance + params.append((layer.num_features, layer.running_mean.clone(), layer.running_var.clone())) + + layers = replace_normalization( + model_with_batchnorm, 'BatchNorm2d', 'QuantileBatchNorm2d' + ) + + assert len(layers) == 2 + for layer, (num_features, running_mean, running_var) in zip(layers, params): + assert isinstance(layer, QuantileBatchNorm2d) + assert layer.num_features == num_features + assert torch.allclose(layer.running_mean, running_mean) + assert torch.allclose(layer.running_var, running_var) + + +class TestRelufiaction: + """Tests for relufiaction helper function.""" + + def test_replaces_with_relu_by_default(self, model_with_gelu): + """Should replace activations with ReLU.""" + layers = relufiaction(model_with_gelu) + + assert len(layers) == 2 + assert all(isinstance(layer, nn.ReLU) for layer in layers) diff --git a/tests/test_normalizations.py b/tests/test_normalizations.py new file mode 100644 index 0000000..08d15a9 --- /dev/null +++ b/tests/test_normalizations.py @@ -0,0 +1,240 @@ + +import os +import sys + +import pytest +import torch +import torch.nn as nn + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from modifiers.normalizations import ( + 
BatchNorm2d, + LayerNorm, + QuantileBatchNorm2d, + QuantileLayerNorm, + QuantileMeanBatchNorm2d, + BatchNorm2dPreStop, + NORMALIZATION_NAMES_MAP, +) + + +from tests.fixtures import * + +# ============================================================================ +# TESTS - Normalization layers +# ============================================================================ + +class TestBatchNorm2d: + """Tests for custom BatchNorm2d implementation.""" + + def test_output_shape_matches_input(self, batch_norm_input): + """Output shape should match input shape.""" + module = BatchNorm2d(32) + original_module = nn.BatchNorm2d(32) + + y = module(batch_norm_input) + expected_y = original_module(batch_norm_input) + + assert y.shape == batch_norm_input.shape + assert torch.allclose(y, expected_y, atol=1e-5) + + assert torch.allclose(module.running_mean, original_module.running_mean, atol=1e-5) + assert torch.allclose(module.running_var, original_module.running_var, atol=1e-3) + + def test_running_stats_tracked_in_training(self, batch_norm_input): + """Running mean/var should be updated during training.""" + module = BatchNorm2d(32) + module.train() + + # Get initial running mean + initial_mean = module.running_mean.clone() + + # Forward pass + module(batch_norm_input) + + # Running mean should change + assert not torch.allclose(module.running_mean, initial_mean) + + def test_affine_parameters_are_trainable(self): + """Weight and bias parameters should be trainable.""" + module = BatchNorm2d(32, affine=True) + + assert module.weight is not None + assert module.bias is not None + assert module.weight.requires_grad + assert module.bias.requires_grad + + def test_no_affine_parameters_when_disabled(self): + """When affine=False, weight and bias should be None.""" + module = BatchNorm2d(32, affine=False) + + assert module.weight is None + assert module.bias is None + + +class TestLayerNorm: + """Tests for custom LayerNorm implementation.""" + + def 
test_output_shape_matches_input(self, layer_norm_input): + """Output shape should match input shape.""" + module = LayerNorm(100) + original_module = nn.LayerNorm(100) + + y = module(layer_norm_input) + expected_y = original_module(layer_norm_input) + + assert y.shape == layer_norm_input.shape + assert torch.allclose(y, expected_y, atol=1e-5) + + def test_normalized_mean_is_near_zero(self, layer_norm_input): + """Normalized output mean should be close to 0.""" + module = LayerNorm(100) + y = module(layer_norm_input) + + mean = y.mean(dim=-1) + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5) + + def test_normalized_std_is_near_one(self, layer_norm_input): + """Normalized output std should be close to 1.""" + module = LayerNorm(100) + y = module(layer_norm_input) + + std = y.std(dim=-1) + assert torch.allclose(std, torch.ones_like(std), atol=0.1) + + +class TestQuantileBatchNorm2d: + """Tests for QuantileBatchNorm2d normalization.""" + + def test_behaves_like_batchnorm_without_sparsity(self, batch_norm_input): + """Without sparsity_level, should behave normally.""" + module = QuantileBatchNorm2d(32, sparsity_level=None) + y = module(batch_norm_input) + + assert y.shape == batch_norm_input.shape + + @pytest.mark.parametrize('mode', ['global', 'batchwise', 'channelwise']) + def test_different_quantile_search_modes(self, batch_norm_input, mode): + """Should work with different search modes.""" + module = QuantileBatchNorm2d( + 32, sparsity_level=0.5, quantile_search_mode=mode + ) + module.train() + y = module(batch_norm_input) + + assert y.shape == batch_norm_input.shape + + def test_max_tracked_cnt_stops_updates(self, batch_norm_input): + """Updates should stop after max_tracked_cnt is reached.""" + module = QuantileBatchNorm2d(32, sparsity_level=0.5, max_tracked_cnt=2) + module.train() + + # Process first 2 batches + module(batch_norm_input) + module(batch_norm_input) + mean_after_2 = module.running_mean.clone() + + # Process 3rd batch + 
module(batch_norm_input) + mean_after_3 = module.running_mean.clone() + + assert torch.allclose(mean_after_2, mean_after_3) + + +class TestQuantileLayerNorm: + """Tests for QuantileLayerNorm normalization.""" + + def test_behaves_like_layernorm_without_sparsity(self, layer_norm_input): + """Without sparsity_level, should behave like LayerNorm.""" + module = QuantileLayerNorm(100, sparsity_level=None) + y = module(layer_norm_input) + + assert y.shape == layer_norm_input.shape + + @pytest.mark.parametrize('invalid_sparsity', [0.0, 1.0]) + def test_rejects_invalid_sparsity_levels(self, invalid_sparsity): + """Should reject sparsity_level=0.0 or 1.0.""" + with pytest.raises(AssertionError): + QuantileLayerNorm(100, sparsity_level=invalid_sparsity) + + @pytest.mark.parametrize('mode', ['global', 'batchwise', 'channelwise']) + def test_different_quantile_search_modes(self, layer_norm_input, mode): + """Should work with different search modes.""" + module = QuantileLayerNorm( + 100, sparsity_level=0.5, quantile_search_mode=mode + ) + module.train() + y = module(layer_norm_input) + + assert y.shape == layer_norm_input.shape + + +class TestQuantileMeanBatchNorm2d: + """Tests for QuantileMeanBatchNorm2d normalization.""" + + def test_output_shape_matches_input(self, batch_norm_input): + """Output shape should match input shape.""" + module = QuantileMeanBatchNorm2d(32, sparsity_level=0.5) + module.train() + y = module(batch_norm_input) + + assert y.shape == batch_norm_input.shape + + +class TestBatchNorm2dPreStop: + """Tests for BatchNorm2dPreStop normalization.""" + + def test_stops_updating_after_max_tracked_cnt(self, batch_norm_input): + """Updates should stop after max_tracked_cnt is reached.""" + module = BatchNorm2dPreStop(32, max_tracked_cnt=2) + module.train() + + # Process first 2 batches + module(batch_norm_input) + module(batch_norm_input) + mean_after_2 = module.running_mean.clone() + + # Process 3rd batch + module(batch_norm_input) + mean_after_3 = 
module.running_mean.clone() + + assert torch.allclose(mean_after_2, mean_after_3) + + +# ============================================================================ +# TESTS - Property preservation and edge cases +# ============================================================================ + + +class TestNormalizationRegistry: + """Tests for all normalizations in NORMALIZATION_NAMES_MAP.""" + + @pytest.mark.parametrize('name', list(NORMALIZATION_NAMES_MAP.keys())) + def test_all_normalizations_are_instantiable(self, name): + """All registered normalizations should be instantiable.""" + cls_or_partial = NORMALIZATION_NAMES_MAP[name] + + if callable(cls_or_partial): + if 'BatchNorm' in name or 'QuantileBatchNorm' in name: + instance = cls_or_partial(3) + else: + instance = cls_or_partial(100) + + assert instance is not None + + @pytest.mark.parametrize('name', list(NORMALIZATION_NAMES_MAP.keys())) + def test_all_normalizations_have_forward_pass(self, name): + """All normalizations should support forward pass.""" + cls_or_partial = NORMALIZATION_NAMES_MAP[name] + + if callable(cls_or_partial): + if 'BatchNorm' in name: + instance = cls_or_partial(3) + x = torch.randn(2, 3, 4, 4) + else: + instance = cls_or_partial(100) + x = torch.randn(4, 100) + + y = instance(x) + assert y.shape == x.shape