From 55a398fb26fa6462e3ff4e354f31b2af02b54b03 Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Thu, 24 Jul 2025 19:25:43 -0500 Subject: [PATCH 01/12] draft gru example --- examples/GRUs/gru_pruning_example.py | 170 +++++++++++++ examples/GRUs/gru_utils.py | 352 +++++++++++++++++++++++++++ examples/GRUs/readme.md | 43 ++++ examples/GRUs/test_gru.py | 308 +++++++++++++++++++++++ 4 files changed, 873 insertions(+) create mode 100644 examples/GRUs/gru_pruning_example.py create mode 100644 examples/GRUs/gru_utils.py create mode 100644 examples/GRUs/readme.md create mode 100644 examples/GRUs/test_gru.py diff --git a/examples/GRUs/gru_pruning_example.py b/examples/GRUs/gru_pruning_example.py new file mode 100644 index 0000000..02012ba --- /dev/null +++ b/examples/GRUs/gru_pruning_example.py @@ -0,0 +1,170 @@ +""" +GRU Pruning Example for Torch-Pruning + +This example demonstrates how to prune GRU layers in PyTorch models using torch-pruning. +The key challenges addressed here are the opaque implementation of standard GRU layers +and the circular dependency problem inherent to recurrent layers. + +Key Innovation: Custom GRU Implementation with Identity Layer Solution +===================================================================== + +Problem 1: PyTorch's torch.nn.GRU uses optimized C/CUDA implementations under the hood. +These low-level implementations are opaque to torch-pruning's dependency graph analysis, +making it impossible for the pruning framework to understand the internal structure +and dependencies of the GRU operations. + +Problem 2: GRU hidden states create circular dependencies that prevent torch-pruning +from modifying hidden dimensions: + + Hidden State (t-1) → GRU → Hidden State (t) → GRU → Hidden State (t+1) + ↑ ↓ + └──────────────── Circular dependency ─────────┘ + +torch-pruning sees the hidden state as both input AND output, so it refuses to +change the hidden dimension to avoid breaking this cycle. 
+ +Solution: Create a custom PrunableGRUEqualHiddenSize that: +1. Implements GRU operations in pure Python/PyTorch (transparent to torch-pruning) +2. Inserts identity layers (hidden_map) to break circular dependencies: + + hidden_state → GRU → hidden_map (identity) → pruned_hidden_state + +This provides "safe" pruning points where torch-pruning can modify dimensions +without worrying about the circular constraint, while using a transparent +implementation that the pruning framework can analyze. + +Workflow: +1. Replace torch.nn.GRU with PrunableGRU (includes identity layers) +2. Run torch-pruning (can now safely modify hidden dimensions) +3. Convert back to torch.nn.GRU (removes identity layers, keeps pruned structure) +""" + +import torch_pruning as tp +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Import your utility functions +from gru_utils import ( + replace_prunablegru_with_torchgru, + replace_torchgru_with_prunablegru, +) + +class GRUTestNet(torch.nn.Module): + """ + Simple test network demonstrating GRU pruning workflow. + + Architecture: Conv layers → FC layers → Multi-layer GRU → Output FC + This mimics common architectures where GRU processes encoded features. 
+ """ + def __init__(self, input_size=80, hidden_size=164): + super(GRUTestNet, self).__init__() + # Feature extraction layers + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(256, 196) + self.fc2 = nn.Linear(196, 80) + + # Multi-layer GRU (this is what we want to prune) + self.gru = nn.GRU(input_size, hidden_size, num_layers=2) + + # Output layer + self.fc3 = nn.Linear(164, 10) + + def forward(self, x, hx=None): + # Feature extraction + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(-1, int(x.nelement() / x.shape[0])) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + + # GRU processing (sequence length = 1 for this example) + x = self.gru(x, hx=hx)[0] + + # Final classification + x = self.fc3(x) + return x + + +def demonstrate_gru_pruning_workflow(): + """ + Complete workflow showing GRU pruning with the identity layer solution. + + This function demonstrates: + 1. The circular dependency problem + 2. How our solution works + 3. End-to-end pruning workflow + """ + print("=" * 60) + print("GRU Pruning Workflow Demonstration") + print("=" * 60) + + # Step 1: Create original model with standard torch.nn.GRU + print("\n1. Creating original model with torch.nn.GRU...") + model = GRUTestNet() + print(f" Original GRU hidden size: {model.gru.hidden_size}") + + # Step 2: Prepare inputs for torch-pruning dependency analysis + print("\n2. Preparing example inputs for dependency graph...") + example_inputs = torch.randn(1, 1, 28, 28) + input_data = {"x": example_inputs, "hx": None} + + # Verify model runs before pruning + original_output = model(**input_data) + print(f" Original model output shape: {original_output.shape}") + + # Step 3: Replace torch.nn.GRU with PrunableGRU (adds identity layers) + print("\n3. 
Converting to PrunableGRU (adds identity layers to break circular deps)...") + model = replace_torchgru_with_prunablegru(model) + print(" ✓ Identity layers inserted - torch-pruning can now modify hidden dims") + + # Step 4: Build dependency graph and create pruner + print("\n4. Building dependency graph and setting up pruner...") + DG = tp.DependencyGraph().build_dependency(model, example_inputs=input_data) + + imp = tp.importance.GroupMagnitudeImportance(p=2) + pruner = tp.pruner.MetaPruner( + model, + input_data, + importance=imp, + pruning_ratio=0.5, # Remove 50% of parameters + isomorphic=True, + global_pruning=True, + root_module_types=(nn.Linear, nn.LayerNorm, nn.Conv2d), + ) + + # Step 5: Execute pruning + print("\n5. Executing pruning...") + pruner.step() + + # Verify model still works after pruning + pruned_output = model(example_inputs) + print(f" Pruned model output shape: {pruned_output.shape}") + print(" ✓ Model still functional after pruning") + + # Step 6: Convert back to torch.nn.GRU (removes identity layers) + print("\n6. 
Converting back to torch.nn.GRU (removes identity layers)...") + final_model = replace_prunablegru_with_torchgru(model) + + # Show the results + print(f" Final GRU hidden size: {final_model.gru.hidden_size}") + print(f" Hidden size reduction: {model.gru.hidden_size} → {final_model.gru.hidden_size}") + + # Final verification + final_output = final_model(example_inputs) + print(f" Final model output shape: {final_output.shape}") + print(" ✓ Successfully pruned GRU while maintaining functionality!") + + print("\n" + "=" * 60) + print("Workflow Complete!") + print("=" * 60) + print("\nKey Insights:") + print("• Identity layers broke circular dependencies in hidden states") + print("• torch-pruning could safely modify GRU hidden dimensions") + print("• Final model uses standard torch.nn.GRU with reduced hidden size") + print("• All functionality preserved throughout the process") + + +if __name__ == "__main__": + demonstrate_gru_pruning_workflow() diff --git a/examples/GRUs/gru_utils.py b/examples/GRUs/gru_utils.py new file mode 100644 index 0000000..1418f70 --- /dev/null +++ b/examples/GRUs/gru_utils.py @@ -0,0 +1,352 @@ +""" +Utilities for GRU pruning with torch-pruning library. + +This module provides functionality to convert between PyTorch's built-in nn.GRU +and custom PrunableGRU modules that are compatible with torch-pruning. The key +innovation is breaking circular dependencies in recurrent layers by introducing +identity layers that provide safe pruning points. + +Key Components: +- PrunableGRU: Custom GRU implementation with identity layers for pruning +- Conversion functions between nn.GRU and PrunableGRU +- Model-wide replacement utilities for seamless integration +""" + +import torch +import torch.nn as nn +import copy + +class PrunableGRU(nn.Module): + """ + Custom GRU module designed for compatibility with torch-pruning. 
+ + This implementation replaces PyTorch's built-in nn.GRU with a custom version + that includes identity layers (hidden_map) to break circular dependencies + inherent in recurrent networks. This allows torch-pruning to safely modify + hidden dimensions without encountering circular dependency constraints. + + Architecture: + - Each layer contains: linear_ih, linear_hh, and hidden_map (identity layer) + - The hidden_map provides a pruning point where dimensions can be safely modified + - Maintains same hidden size across all layers (like torch.nn.GRU) + + Args: + input_size (int): Number of expected input features + hidden_size (int): Number of features in the hidden state + num_layers (int, optional): Number of recurrent layers. Default: 1 + batch_first (bool, optional): If True, input/output tensors are provided + as (batch, seq, feature). Default: False + + Input Shape: + - If batch_first=False: (seq_len, batch, input_size) or (seq_len, input_size) + - If batch_first=True: (batch, seq_len, input_size) + - hx: (num_layers, batch, hidden_size) or (num_layers, hidden_size) + + Output Shape: + - output: Same shape as input but with input_size replaced by hidden_size + - h_n: (num_layers, batch, hidden_size) or (num_layers, hidden_size) + + Note: + This version maintains equal hidden sizes across layers for compatibility + with torch.nn.GRU. For different hidden sizes per layer, use + PrunableGRUDifferentHiddenSize. 
+ """ + + def __init__(self, input_size, hidden_size, num_layers=1, batch_first=False): + super().__init__() + self.num_layers = num_layers + self.hidden_size = hidden_size + self.batch_first = batch_first + + # Create layers: each layer has linear_ih, linear_hh, hidden_map + self.layers = nn.ModuleList() + + for layer_idx in range(num_layers): + layer_input_size = input_size if layer_idx == 0 else hidden_size + gru_layer = nn.ModuleDict( + { + "linear_ih": nn.Linear(layer_input_size, 3 * hidden_size), + "linear_hh": nn.Linear(hidden_size, 3 * hidden_size), + "hidden_map": nn.Linear(hidden_size, hidden_size), + } + ) + self.layers.append(gru_layer) + + def forward(self, x, hx=None): + """ + Forward pass through the PrunableGRU. + + Implements the GRU equations with an additional identity mapping (hidden_map) + that provides a safe pruning point for torch-pruning to modify hidden dimensions. + + GRU Equations: + - r_t = σ(W_ir @ x_t + b_ir + W_hr @ h_(t-1) + b_hr) # reset gate + - z_t = σ(W_iz @ x_t + b_iz + W_hz @ h_(t-1) + b_hz) # update gate + - n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_(t-1) + b_hn)) # new gate + - h_t = (1 - z_t) * n_t + z_t * h_(t-1) # new hidden state + + Args: + x (torch.Tensor): Input tensor of shape: + - (seq_len, input_size) for single batch, batch_first=False + - (seq_len, batch_size, input_size) for batch, batch_first=False + - (batch_size, seq_len, input_size) for batch_first=True + hx (torch.Tensor, optional): Initial hidden state tensor of shape: + - (num_layers, hidden_size) for single batch + - (num_layers, batch_size, hidden_size) for batch + If None, defaults to zeros. 
+ + Returns: + tuple: (output, h_n) where: + - output: Tensor containing output features for each timestep + - h_n: Tensor containing final hidden state for each layer + """ + batch_input = True + if not self.batch_first: + seq_len = x.shape[0] + if len(x.shape) == 2: + num_batch = 1 + x = x.unsqueeze(1) # make it (seq_len, batch, input_size) + if hx is not None: + hx = hx.unsqueeze(1) + batch_input = False + else: + num_batch = x.shape[1] + else: + seq_len = x.shape[1] + num_batch = x.shape[0] + x = x.permute(1, 0, 2) + + # hidden state initialization + if hx is None: + hx = torch.zeros(self.num_layers, num_batch, self.hidden_size, device=x.device) + + # hidden state output + h_n = [] + + # out tracks the output of last layer, which is input to the next layer + out = x + + for layer_idx, layer in enumerate(self.layers): + h_prev = hx[layer_idx, 0, :].unsqueeze(0) # (batch, hidden_size) + outputs = [] # to contain the outputs for each time step for this layer + for t in range(seq_len): + + h = layer["hidden_map"](h_prev) + gates_hh = layer["linear_hh"](h) + gates_ih = layer["linear_ih"](out[t, 0, :].unsqueeze(0)) + + r_hh_lin_out, z_hh_lin_out, n_hh_lin_out = gates_hh.chunk(3, dim=1) + r_ih_lin_out, z_ih_lin_out, n_ih_lin_out = gates_ih.chunk(3, dim=1) + + r = torch.sigmoid(r_hh_lin_out + r_ih_lin_out) + z = torch.sigmoid(z_hh_lin_out + z_ih_lin_out) + n = torch.tanh(n_ih_lin_out + r * n_hh_lin_out) + + h_new = (1 - z) * n + z * h + + if layer_idx > 0: + outputs.append(h_new + 0 * out[t, 0, :].unsqueeze(0)) + else: + outputs.append(h_new) + + h_prev = h_new + + out = torch.stack(outputs, dim=0) # (seq_len, batch, hidden_size) + h_n.append(h_prev) # keep track of final hidden state output for each layer + + # Stack hidden states for all layers: (num_layers, batch, hidden_size) + h_n = torch.stack(h_n, dim=0) + if not batch_input: + h_n = h_n.squeeze(1) + out = out.squeeze(1) # (seq_len, hidden_size) + + if self.batch_first: + out = out.permute(1, 0, 2) + + return 
out, h_n + + +def torchgru_to_prunablegru(gru): + """ + Converts a standard nn.GRU layer to a PrunableGRU with identical behavior. + + This function creates a PrunableGRU instance and copies all weights and biases + from the original nn.GRU. The hidden_map layers are initialized as identity + transformations to ensure mathematical equivalence before pruning. + + Args: + gru (nn.GRU): A PyTorch GRU layer to convert + + Returns: + PrunableGRU: Equivalent PrunableGRU with copied weights and identity mappings + + Example: + >>> original_gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2) + >>> prunable_gru = torchgru_to_prunablegru(original_gru) + >>> # prunable_gru now behaves identically to original_gru + """ + input_size = gru.input_size + hidden_size = gru.hidden_size + num_layers = gru.num_layers + batch_first = gru.batch_first + + prunable_gru = PrunableGRU( + input_size, hidden_size, num_layers=num_layers, batch_first=batch_first + ) + + with torch.no_grad(): + for i in range(num_layers): + layer = prunable_gru.layers[i] + layer["linear_ih"].weight.copy_(getattr(gru, f"weight_ih_l{i}")) + layer["linear_ih"].bias.copy_(getattr(gru, f"bias_ih_l{i}")) + layer["linear_hh"].weight.copy_(getattr(gru, f"weight_hh_l{i}")) + layer["linear_hh"].bias.copy_(getattr(gru, f"bias_hh_l{i}")) + + # hidden_map init as identity matrix + zero bias + layer["hidden_map"].weight.copy_(torch.eye(hidden_size)) + layer["hidden_map"].bias.zero_() + + return prunable_gru + + +def replace_torchgru_with_prunablegru(original_model): + """ + Creates a deep copy of a model with all nn.GRU modules replaced by PrunableGRU. + + This function recursively traverses the model architecture and replaces every + nn.GRU instance with an equivalent PrunableGRU. This is the recommended way + to prepare a model for GRU pruning with torch-pruning. + + The conversion process: + 1. Creates a deep copy of the original model + 2. Recursively finds all nn.GRU modules + 3. 
Replaces each with a PrunableGRU using torchgru_to_prunablegru() + 4. Preserves the original model's device placement + + Args: + original_model (nn.Module): PyTorch model containing nn.GRU layers + + Returns: + nn.Module: Deep copy of the model with PrunableGRU layers, preserving + all other components and maintaining device placement + + Example: + >>> model = MyModel() # Contains nn.GRU layers + >>> prunable_model = replace_torchgru_with_prunablegru(model) + >>> # Now ready for pruning with torch-pruning + """ + model_copy = copy.deepcopy(original_model) + + # Get the device of the original model + device = next(original_model.parameters()).device + + def _replace_gru(module): + for name, child in module.named_children(): + if isinstance(child, nn.GRU): + prunable_gru = torchgru_to_prunablegru(child) + setattr(module, name, prunable_gru) + else: + _replace_gru(child) + + _replace_gru(model_copy) + + # Move the copied model to the same device as the original model + model_copy.to(device) + return model_copy + + +def prunablegru_to_torchgru(prunable_gru): + """ + Converts a PrunableGRU back to a standard nn.GRU layer. + + This function is typically used after pruning to convert the pruned PrunableGRU + back to a standard nn.GRU for deployment. The conversion discards the identity + hidden_map layers while preserving the pruned structure and learned weights. + + Important: The resulting nn.GRU will have dimensions matching the pruned + PrunableGRU, not the original pre-pruning dimensions. This allows you to + deploy pruned models using standard PyTorch components. + + Args: + prunable_gru (PrunableGRU): A PrunableGRU instance (potentially pruned) + + Returns: + nn.GRU: Standard PyTorch GRU with copied weights and biases. + Dimensions match the (potentially pruned) PrunableGRU. 
+ + Example: + >>> # After pruning + >>> pruned_prunable_gru = prune_model(prunable_gru) + >>> standard_gru = prunablegru_to_torchgru(pruned_prunable_gru) + >>> # standard_gru now has pruned dimensions but uses nn.GRU + """ + # Get parameters + num_layers = prunable_gru.num_layers + hidden_size = prunable_gru.layers[0]["linear_hh"].weight.shape[1] + # The input size is from first layer's linear_ih input features + input_size = prunable_gru.layers[0]["linear_ih"].weight.shape[1] + + # Create torch GRU with matching parameters + torch_gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True) + + with torch.no_grad(): + for i in range(num_layers): + layer = prunable_gru.layers[i] + # Copy weights and biases properly: + getattr(torch_gru, f"weight_ih_l{i}").data.copy_(layer["linear_ih"].weight.data) + getattr(torch_gru, f"bias_ih_l{i}").data.copy_(layer["linear_ih"].bias.data) + getattr(torch_gru, f"weight_hh_l{i}").data.copy_(layer["linear_hh"].weight.data) + getattr(torch_gru, f"bias_hh_l{i}").data.copy_(layer["linear_hh"].bias.data) + + return torch_gru + + +def replace_prunablegru_with_torchgru(original_model): + """ + Creates a deep copy of a model with all PrunableGRU modules replaced by nn.GRU. + + This function is the inverse of replace_torchgru_with_prunablegru() and is + typically used after pruning to convert back to standard PyTorch components + for deployment. The resulting model maintains the pruned structure but uses + standard nn.GRU layers for better compatibility and potentially improved performance. + + The conversion process: + 1. Creates a deep copy of the model containing PrunableGRU layers + 2. Recursively finds all PrunableGRU modules + 3. Replaces each with an nn.GRU using prunablegru_to_torchgru() + 4. Preserves device placement and all other model components + + Args: + original_model (nn.Module): PyTorch model containing PrunableGRU layers + + Returns: + nn.Module: Deep copy of the model with standard nn.GRU layers. 
+ Dimensions reflect any pruning that was applied to the + PrunableGRU layers. + + Example: + >>> # After pruning a model with PrunableGRU layers + >>> pruned_model = prune_with_torch_pruning(model_with_prunable_gru) + >>> final_model = replace_prunablegru_with_torchgru(pruned_model) + >>> # final_model now uses standard nn.GRU with pruned dimensions + """ + model_copy = copy.deepcopy(original_model) + + # Get the device of the original model + device = next(original_model.parameters()).device + + def _replace_prunablegru(module): + for name, child in module.named_children(): + # Check for multilayer prunable GRU + if isinstance(child, PrunableGRU): + + torch_gru = prunablegru_to_torchgru(child) + + setattr(module, name, torch_gru) + else: + _replace_prunablegru(child) + + _replace_prunablegru(model_copy) + # Move the copied model to the same device as the original model + model_copy.to(device) + return model_copy diff --git a/examples/GRUs/readme.md b/examples/GRUs/readme.md new file mode 100644 index 0000000..046d2fb --- /dev/null +++ b/examples/GRUs/readme.md @@ -0,0 +1,43 @@ +# GRU Pruning with Torch-Pruning + +This example demonstrates how to prune GRU (Gated Recurrent Unit) layers in PyTorch models using the torch-pruning library. The key challenge addressed here is making GRU layers compatible with torch-pruning through custom implementations that resolve fundamental architectural constraints. + +## Key Innovations + +### Problem 1: Opaque C++ Implementation +The standard `torch.nn.GRU` module uses an optimized C++ implementation under the hood that torch-pruning cannot analyze or modify. This black-box nature prevents the pruning library from understanding the internal structure needed for safe pruning operations. 
+ +### Problem 2: Circular Dependency in Hidden States +GRU layers create circular dependencies that prevent torch-pruning from modifying hidden dimensions: + +torch-pruning sees the hidden state as both input AND output, so it refuses to change the hidden dimension to avoid breaking this cycle. + +## Solution: Custom PrunableGRU with Identity Layers + +Our approach addresses both problems: + +1. **Replace opaque torch.nn.GRU** with a custom `PrunableGRU` implementation that exposes all internal operations as standard PyTorch layers +2. **Insert identity layers (`hidden_map`)** to break the circular dependency: + +hidden_state → GRU → hidden_map (identity) → pruned_hidden_state + + +This provides "safe" pruning points where torch-pruning can modify dimensions without worrying about the circular constraint. + +## Workflow + +1. **Convert**: Replace `torch.nn.GRU` with `PrunableGRU` (includes identity layers) +2. **Prune**: Run torch-pruning (can now safely modify hidden dimensions) +3. **Convert Back**: Convert back to `torch.nn.GRU` (removes identity layers, keeps pruned structure) + +## Usage + +### Basic Example + +```python +import torch +import torch.nn as nn +from gru_utils import replace_torchgru_with_prunablegru, replace_prunablegru_with_torchgru + + +#### Original model with torch.nn.GRU diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py new file mode 100644 index 0000000..2d717c1 --- /dev/null +++ b/examples/GRUs/test_gru.py @@ -0,0 +1,308 @@ +""" +Comprehensive tests for GRU pruning functionality. +Tests both the prunable GRU implementation and the pruning workflow. 
+""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_pruning as tp + +from df.torch_prune_utils import ( + replace_prunablegru_with_torchgru_equalhidden, + replace_torchgru_with_prunablegru_equalhidden, + torchgru_to_prunablegru_equalhidden, +) + + +class GRUTestNet(torch.nn.Module): + """Simple test network for GRU pruning (seq_first=True).""" + def __init__(self, input_size=80, hidden_size=164): + super(GRUTestNet, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(256, 196) + self.fc2 = nn.Linear(196, 80) + self.gru = nn.GRU(input_size, hidden_size, num_layers=2) + self.fc3 = nn.Linear(164, 10) + + def forward(self, x, hx=None): + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(-1, int(x.nelement() / x.shape[0])) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.gru(x, hx=hx)[0] + x = self.fc3(x) + return x + + +class GRUTestNetBatchFirst(torch.nn.Module): + """Simple test network for GRU pruning (batch_first=True).""" + def __init__(self, input_size=80, hidden_size=164): + super(GRUTestNetBatchFirst, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(256, 196) + self.fc2 = nn.Linear(196, 80) + self.gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True) + self.fc3 = nn.Linear(164, 10) + + def forward(self, x, hx=None): + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(-1, int(x.nelement() / x.shape[0])) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = x.unsqueeze(1) # Add sequence dimension + x = self.gru(x, hx=hx)[0] + x = x.squeeze(1) # Remove sequence dimension + x = self.fc3(x) + return x + + +class TestPrunableGRUBehavior: + """Test that prunable GRU modules behave identically to torch.nn.GRU.""" + + @pytest.fixture + def gru_params(self): + 
return { + "input_size": 80, + "hidden_size": 164, + "num_layers": 2 + } + + def test_single_layer_gru_equivalence(self, gru_params): + """Test single layer GRU equivalence.""" + params = gru_params.copy() + params["num_layers"] = 1 + + torch_gru = nn.GRU(**params) + prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) + + # Test with sequence-first input (default) + gru_input = torch.randn([3, params["input_size"]]) # seq_len=3 + gru_state = torch.randn([1, params["hidden_size"]]) + + torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] + prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] + + assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ + "Single-layer GRU outputs do not match!" + + def test_multi_layer_gru_equivalence(self, gru_params): + """Test multi-layer GRU equivalence.""" + torch_gru = nn.GRU(**gru_params) + prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) + + # Test with sequence-first input (default) + gru_input = torch.randn([2, gru_params["input_size"]]) # seq_len=2 + gru_state = torch.randn([gru_params["num_layers"], gru_params["hidden_size"]]) + + torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] + prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] + + assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ + "Multi-layer GRU outputs do not match!" 
+ + def test_batch_first_equivalence(self, gru_params): + """Test batch_first=True equivalence.""" + batch_size = 1 + seq_len = 2 + + torch_gru = nn.GRU(**gru_params, batch_first=True) + prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) + + # Test with batch-first input + gru_input = torch.randn([batch_size, seq_len, gru_params["input_size"]]) + gru_state = torch.randn([gru_params["num_layers"], batch_size, gru_params["hidden_size"]]) + + torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] + prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] + + assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ + "Batch-first GRU outputs do not match!" + + def test_no_hidden_state_equivalence(self, gru_params): + """Test equivalence when no hidden state is provided.""" + torch_gru = nn.GRU(**gru_params) + prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) + + gru_input = torch.randn([2, gru_params["input_size"]]) + + torch_gru_out = torch_gru(gru_input)[0] + prunable_gru_out = prunable_gru(gru_input)[0] + + assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ + "GRU outputs without hidden state do not match!" + + +class TestModelConversion: + """Test model-level conversion between torch and prunable GRU implementations.""" + + def test_model_conversion_seq_first(self): + """Test model conversion with seq_first=True (default).""" + original_model = GRUTestNet() + example_input = torch.randn(1, 1, 28, 28) + + # Get original output + original_output = original_model(example_input) + + # Convert to prunable GRU + prunable_model = replace_torchgru_with_prunablegru_equalhidden(original_model) + prunable_output = prunable_model(example_input) + + assert torch.allclose(original_output, prunable_output, atol=1e-6), \ + "Model outputs do not match after conversion to prunable GRU!" 
+ + # Convert back to torch GRU + converted_back_model = replace_prunablegru_with_torchgru_equalhidden(prunable_model) + converted_back_output = converted_back_model(example_input) + + assert torch.allclose(prunable_output, converted_back_output, atol=1e-6), \ + "Model outputs do not match after conversion back to torch GRU!" + + def test_model_conversion_batch_first(self): + """Test model conversion with batch_first=True.""" + original_model = GRUTestNetBatchFirst() + example_input = torch.randn(1, 1, 28, 28) + + # Get original output + original_output = original_model(example_input) + + # Convert to prunable GRU + prunable_model = replace_torchgru_with_prunablegru_equalhidden(original_model) + prunable_output = prunable_model(example_input) + + assert torch.allclose(original_output, prunable_output, atol=1e-6), \ + "Batch-first model outputs do not match after conversion to prunable GRU!" + + # Convert back to torch GRU + converted_back_model = replace_prunablegru_with_torchgru_equalhidden(prunable_model) + converted_back_output = converted_back_model(example_input) + + assert torch.allclose(prunable_output, converted_back_output, atol=1e-6), \ + "Batch-first model outputs do not match after conversion back to torch GRU!" 
+ + +class TestPruningWorkflow: + """Test the complete pruning workflow.""" + + @pytest.mark.parametrize("batch_first", [True, False]) + def test_pruning_workflow(self, batch_first): + """Test complete pruning workflow.""" + # Create model + if batch_first: + model = GRUTestNetBatchFirst() + else: + model = GRUTestNet() + + example_input = torch.randn(1, 1, 28, 28) + input_data = {"x": example_input, "hx": None} + + # Get original output size for comparison + original_output = model(**input_data) + original_gru_hidden_size = model.gru.hidden_size + + # Convert to prunable GRU + model = replace_torchgru_with_prunablegru_equalhidden(model) + + # Build dependency graph + DG = tp.DependencyGraph().build_dependency(model, example_inputs=input_data) + + # Set up pruning + imp = tp.importance.GroupMagnitudeImportance(p=2) + pruning_ratio = 0.5 + pruner = tp.pruner.MetaPruner( + model, + input_data, + importance=imp, + pruning_ratio=pruning_ratio, + isomorphic=True, + global_pruning=True, + root_module_types=(nn.Linear, nn.LayerNorm, nn.Conv2d), + ) + + # Execute pruning + pruner.step() + + # Test inference after pruning + pruned_output = model(**input_data) + assert pruned_output.shape == original_output.shape, \ + "Output shape changed after pruning!" + + # Convert back to torch GRU + final_model = replace_prunablegru_with_torchgru_equalhidden(model) + + # Test final inference + final_output = final_model(**input_data) + assert final_output.shape == original_output.shape, \ + "Final output shape changed after conversion!" 
+ + # Verify that hidden size was actually reduced + final_gru_hidden_size = final_model.gru.hidden_size + assert final_gru_hidden_size < original_gru_hidden_size, \ + f"Hidden size was not reduced: {final_gru_hidden_size} >= {original_gru_hidden_size}" + + print(f"Original hidden size: {original_gru_hidden_size}") + print(f"Pruned hidden size: {final_gru_hidden_size}") + print(f"Reduction ratio: {final_gru_hidden_size / original_gru_hidden_size:.2f}") + + def test_pruning_preserves_functionality(self): + """Test that pruned model maintains basic functionality.""" + model = GRUTestNet() + example_input = torch.randn(1, 1, 28, 28) + input_data = {"x": example_input, "hx": None} + + # Convert and prune + model = replace_torchgru_with_prunablegru_equalhidden(model) + + imp = tp.importance.GroupMagnitudeImportance(p=2) + pruner = tp.pruner.MetaPruner( + model, + input_data, + importance=imp, + pruning_ratio=0.3, # Moderate pruning + global_pruning=True, + ) + pruner.step() + + # Convert back + final_model = replace_prunablegru_with_torchgru_equalhidden(model) + + # Test with different inputs + test_inputs = [ + torch.randn(1, 1, 28, 28), + torch.randn(2, 1, 28, 28), # Different batch size + ] + + for test_input in test_inputs: + try: + output = final_model(test_input, hx=None) + assert output.shape[0] == test_input.shape[0], \ + f"Batch dimension mismatch: {output.shape[0]} != {test_input.shape[0]}" + assert torch.isfinite(output).all(), \ + "Model output contains non-finite values!" 
+ except Exception as e: + pytest.fail(f"Model failed on input shape {test_input.shape}: {e}") + + +if __name__ == "__main__": + # Run basic functionality test + test_behavior = TestPrunableGRUBehavior() + test_behavior.test_multi_layer_gru_equivalence({ + "input_size": 80, + "hidden_size": 164, + "num_layers": 2 + }) + + test_conversion = TestModelConversion() + test_conversion.test_model_conversion_seq_first() + test_conversion.test_model_conversion_batch_first() + + test_pruning = TestPruningWorkflow() + test_pruning.test_pruning_workflow(batch_first=True) + test_pruning.test_pruning_workflow(batch_first=False) + + print("All tests passed!") \ No newline at end of file From bf687cf52912ad6e1cb11e03eab98ab472ede92d Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 12:49:49 -0500 Subject: [PATCH 02/12] working tests --- examples/GRUs/test_gru.py | 361 +++++++++++--------------------------- 1 file changed, 103 insertions(+), 258 deletions(-) diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py index 2d717c1..ad5e539 100644 --- a/examples/GRUs/test_gru.py +++ b/examples/GRUs/test_gru.py @@ -1,308 +1,153 @@ """ -Comprehensive tests for GRU pruning functionality. -Tests both the prunable GRU implementation and the pruning workflow. +This script demonstrates the pruning of GRU modules in a PyTorch model and tests out some of the different building blocks. +This work was a precurser to testing GRU pruning in our actual DeepFilterNet model. 
""" - -import pytest import torch import torch.nn as nn -import torch.nn.functional as F import torch_pruning as tp - -from df.torch_prune_utils import ( - replace_prunablegru_with_torchgru_equalhidden, - replace_torchgru_with_prunablegru_equalhidden, - torchgru_to_prunablegru_equalhidden, +from gru_utils import ( + replace_prunablegru_with_torchgru, + replace_torchgru_with_prunablegru, + torchgru_to_prunablegru, + GRUTestNet, ) -class GRUTestNet(torch.nn.Module): - """Simple test network for GRU pruning (seq_first=True).""" - def __init__(self, input_size=80, hidden_size=164): - super(GRUTestNet, self).__init__() - self.conv1 = nn.Conv2d(1, 6, 5) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(256, 196) - self.fc2 = nn.Linear(196, 80) - self.gru = nn.GRU(input_size, hidden_size, num_layers=2) - self.fc3 = nn.Linear(164, 10) - - def forward(self, x, hx=None): - x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), 2) - x = x.view(-1, int(x.nelement() / x.shape[0])) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.gru(x, hx=hx)[0] - x = self.fc3(x) - return x - - -class GRUTestNetBatchFirst(torch.nn.Module): - """Simple test network for GRU pruning (batch_first=True).""" - def __init__(self, input_size=80, hidden_size=164): - super(GRUTestNetBatchFirst, self).__init__() - self.conv1 = nn.Conv2d(1, 6, 5) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(256, 196) - self.fc2 = nn.Linear(196, 80) - self.gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True) - self.fc3 = nn.Linear(164, 10) - - def forward(self, x, hx=None): - x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), 2) - x = x.view(-1, int(x.nelement() / x.shape[0])) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = x.unsqueeze(1) # Add sequence dimension - x = self.gru(x, hx=hx)[0] - x = x.squeeze(1) # Remove sequence dimension - x = self.fc3(x) - return x - - class 
TestPrunableGRUBehavior: """Test that prunable GRU modules behave identically to torch.nn.GRU.""" - - @pytest.fixture def gru_params(self): return { "input_size": 80, "hidden_size": 164, - "num_layers": 2 + "num_layers": 1 } - - def test_single_layer_gru_equivalence(self, gru_params): + + def test_single_layer_gru_equivalence(self): """Test single layer GRU equivalence.""" - params = gru_params.copy() - params["num_layers"] = 1 + # get params for single layer GRU + params = self.gru_params() + # create torch GRU and prunable GRU torch_gru = nn.GRU(**params) - prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) - # Test with sequence-first input (default) - gru_input = torch.randn([3, params["input_size"]]) # seq_len=3 - gru_state = torch.randn([1, params["hidden_size"]]) + # copy weights from torch GRU to prunable GRU + prunable_gru = torchgru_to_prunablegru(torch_gru) + + # run same input through both units + seq_length = 2 + gru_input = torch.randn([seq_length, params["input_size"]]) # test sequence length of 2 + gru_state = torch.randn([params["num_layers"], params["hidden_size"]]) torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] - assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ - "Single-layer GRU outputs do not match!" - - def test_multi_layer_gru_equivalence(self, gru_params): - """Test multi-layer GRU equivalence.""" - torch_gru = nn.GRU(**gru_params) - prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) - - # Test with sequence-first input (default) - gru_input = torch.randn([2, gru_params["input_size"]]) # seq_len=2 - gru_state = torch.randn([gru_params["num_layers"], gru_params["hidden_size"]]) + assert torch.allclose( + torch_gru_out, prunable_gru_out, atol=1e-6 + ), "Single layer GRU outputs do not match!" 
- torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] - prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] + def test_multi_layer_gru_equivalence(self): + """Test multi layer GRU equivalence.""" + # get params for single layer GRU + params = self.gru_params() + params["num_layers"] = 2 # test multiple layers - assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ - "Multi-layer GRU outputs do not match!" - - def test_batch_first_equivalence(self, gru_params): - """Test batch_first=True equivalence.""" - batch_size = 1 - seq_len = 2 + # create torch GRU and prunable GRU + torch_gru = nn.GRU(**params) - torch_gru = nn.GRU(**gru_params, batch_first=True) - prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) + # copy weights from torch GRU to prunable GRU + prunable_gru = torchgru_to_prunablegru(torch_gru) - # Test with batch-first input - gru_input = torch.randn([batch_size, seq_len, gru_params["input_size"]]) - gru_state = torch.randn([gru_params["num_layers"], batch_size, gru_params["hidden_size"]]) + # run same input through both units + seq_length = 2 + gru_input = torch.randn([seq_length, params["input_size"]]) # test sequence length of 2 + gru_state = torch.randn([params["num_layers"], params["hidden_size"]]) torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] - assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ - "Batch-first GRU outputs do not match!" - - def test_no_hidden_state_equivalence(self, gru_params): - """Test equivalence when no hidden state is provided.""" - torch_gru = nn.GRU(**gru_params) - prunable_gru = torchgru_to_prunablegru_equalhidden(torch_gru) - - gru_input = torch.randn([2, gru_params["input_size"]]) + assert torch.allclose( + torch_gru_out, prunable_gru_out, atol=1e-6 + ), "Multi layer GRU outputs do not match!" 
- torch_gru_out = torch_gru(gru_input)[0] - prunable_gru_out = prunable_gru(gru_input)[0] + def test_gru_equivalence_with_batch_dim_input(self): + """Test GRU equivalence with batch dimension input. (only can have batch size = 1 for now)""" + params = self.gru_params() - assert torch.allclose(torch_gru_out, prunable_gru_out, atol=1e-6), \ - "GRU outputs without hidden state do not match!" - - -class TestModelConversion: - """Test model-level conversion between torch and prunable GRU implementations.""" - - def test_model_conversion_seq_first(self): - """Test model conversion with seq_first=True (default).""" - original_model = GRUTestNet() - example_input = torch.randn(1, 1, 28, 28) - - # Get original output - original_output = original_model(example_input) - - # Convert to prunable GRU - prunable_model = replace_torchgru_with_prunablegru_equalhidden(original_model) - prunable_output = prunable_model(example_input) + # create torch GRU and prunable GRU + torch_gru = nn.GRU(**params) - assert torch.allclose(original_output, prunable_output, atol=1e-6), \ - "Model outputs do not match after conversion to prunable GRU!" + # copy weights from torch GRU to prunable GRU + prunable_gru = torchgru_to_prunablegru(torch_gru) - # Convert back to torch GRU - converted_back_model = replace_prunablegru_with_torchgru_equalhidden(prunable_model) - converted_back_output = converted_back_model(example_input) + # run same input through both units + batch_size = 1 + seq_len = 2 + gru_input = torch.randn([seq_len, batch_size, params["input_size"]]) # test sequence length of 2 + gru_state = torch.randn([params["num_layers"], batch_size, params["hidden_size"]]) + torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] + prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] - assert torch.allclose(prunable_output, converted_back_output, atol=1e-6), \ - "Model outputs do not match after conversion back to torch GRU!" 
+ assert torch.allclose( + torch_gru_out, prunable_gru_out, atol=1e-6 + ), "GRU outputs match with batch dim input!" - def test_model_conversion_batch_first(self): - """Test model conversion with batch_first=True.""" - original_model = GRUTestNetBatchFirst() - example_input = torch.randn(1, 1, 28, 28) + def test_gru_equivalence_with_batch_first(self): + params = self.gru_params() - # Get original output - original_output = original_model(example_input) + batch_size = 1 + seq_len = 2 - # Convert to prunable GRU - prunable_model = replace_torchgru_with_prunablegru_equalhidden(original_model) - prunable_output = prunable_model(example_input) + # try out batch_first=True, to match DFN network + gru_input = torch.randn([batch_size, seq_len, 80]) # batch, seq, input len + gru_state = torch.randn([params["num_layers"], batch_size, params["hidden_size"]]) - assert torch.allclose(original_output, prunable_output, atol=1e-6), \ - "Batch-first model outputs do not match after conversion to prunable GRU!" + # create torch GRU and prunable GRU + torch_gru = nn.GRU(input_size=params["input_size"], hidden_size=params["hidden_size"], num_layers=params["num_layers"], batch_first=True) + prunable_gru = torchgru_to_prunablegru(torch_gru) - # Convert back to torch GRU - converted_back_model = replace_prunablegru_with_torchgru_equalhidden(prunable_model) - converted_back_output = converted_back_model(example_input) + # pass same input through both units + torch_gru_out = torch_gru(gru_input, hx=gru_state)[0] + prunable_gru_out = prunable_gru(gru_input, hx=gru_state)[0] - assert torch.allclose(prunable_output, converted_back_output, atol=1e-6), \ - "Batch-first model outputs do not match after conversion back to torch GRU!" - + # verify outputs match + assert torch.allclose( + torch_gru_out, prunable_gru_out, atol=1e-6 + ), "Batch-first GRU outputs do not match!" 
-class TestPruningWorkflow: - """Test the complete pruning workflow.""" +class TestGRUPruneUtils: + """Test utility functions for GRU pruning.""" - @pytest.mark.parametrize("batch_first", [True, False]) - def test_pruning_workflow(self, batch_first): - """Test complete pruning workflow.""" - # Create model - if batch_first: - model = GRUTestNetBatchFirst() - else: - model = GRUTestNet() - - example_input = torch.randn(1, 1, 28, 28) - input_data = {"x": example_input, "hx": None} - - # Get original output size for comparison - original_output = model(**input_data) - original_gru_hidden_size = model.gru.hidden_size - - # Convert to prunable GRU - model = replace_torchgru_with_prunablegru_equalhidden(model) - - # Build dependency graph - DG = tp.DependencyGraph().build_dependency(model, example_inputs=input_data) - - # Set up pruning - imp = tp.importance.GroupMagnitudeImportance(p=2) - pruning_ratio = 0.5 - pruner = tp.pruner.MetaPruner( - model, - input_data, - importance=imp, - pruning_ratio=pruning_ratio, - isomorphic=True, - global_pruning=True, - root_module_types=(nn.Linear, nn.LayerNorm, nn.Conv2d), - ) - - # Execute pruning - pruner.step() - - # Test inference after pruning - pruned_output = model(**input_data) - assert pruned_output.shape == original_output.shape, \ - "Output shape changed after pruning!" - - # Convert back to torch GRU - final_model = replace_prunablegru_with_torchgru_equalhidden(model) - - # Test final inference - final_output = final_model(**input_data) - assert final_output.shape == original_output.shape, \ - "Final output shape changed after conversion!" 
- - # Verify that hidden size was actually reduced - final_gru_hidden_size = final_model.gru.hidden_size - assert final_gru_hidden_size < original_gru_hidden_size, \ - f"Hidden size was not reduced: {final_gru_hidden_size} >= {original_gru_hidden_size}" - - print(f"Original hidden size: {original_gru_hidden_size}") - print(f"Pruned hidden size: {final_gru_hidden_size}") - print(f"Reduction ratio: {final_gru_hidden_size / original_gru_hidden_size:.2f}") + # here we are testing the prunable gru where hidden size has to be the same across layers! + def get_model_input(self): + return torch.randn(1, 1, 28, 28) - def test_pruning_preserves_functionality(self): - """Test that pruned model maintains basic functionality.""" - model = GRUTestNet() - example_input = torch.randn(1, 1, 28, 28) - input_data = {"x": example_input, "hx": None} - - # Convert and prune - model = replace_torchgru_with_prunablegru_equalhidden(model) - - imp = tp.importance.GroupMagnitudeImportance(p=2) - pruner = tp.pruner.MetaPruner( - model, - input_data, - importance=imp, - pruning_ratio=0.3, # Moderate pruning - global_pruning=True, - ) - pruner.step() - - # Convert back - final_model = replace_prunablegru_with_torchgru_equalhidden(model) - - # Test with different inputs - test_inputs = [ - torch.randn(1, 1, 28, 28), - torch.randn(2, 1, 28, 28), # Different batch size - ] - - for test_input in test_inputs: - try: - output = final_model(test_input, hx=None) - assert output.shape[0] == test_input.shape[0], \ - f"Batch dimension mismatch: {output.shape[0]} != {test_input.shape[0]}" - assert torch.isfinite(output).all(), \ - "Model output contains non-finite values!" - except Exception as e: - pytest.fail(f"Model failed on input shape {test_input.shape}: {e}") + def test_replacement_utils(self): + model_input = self.get_model_input() + model_torchGRU = GRUTestNet() + + # test process of finding GRUs and replacing with prunable GRU. For use pre-pruning. 
+ model_prunableGRU = replace_torchgru_with_prunablegru(model_torchGRU) + model_torchGRU_out = model_torchGRU(model_input) + model_prunableGRU_out = model_prunableGRU(model_input) + assert torch.allclose( + model_torchGRU_out, model_prunableGRU_out, atol=1e-6 + ), "Outputs of original and prunable GRU models do not match!" + + # test process of going the other way, for post-pruning + model_torchGRU_copy = replace_prunablegru_with_torchgru(model_prunableGRU) + model_torchGRU_out = model_torchGRU_copy(model_input) + assert torch.allclose( + model_torchGRU_out, model_prunableGRU_out, atol=1e-6 + ), "Outputs of original and prunable GRU models do not match after conversion back!" if __name__ == "__main__": - # Run basic functionality test test_behavior = TestPrunableGRUBehavior() - test_behavior.test_multi_layer_gru_equivalence({ - "input_size": 80, - "hidden_size": 164, - "num_layers": 2 - }) - - test_conversion = TestModelConversion() - test_conversion.test_model_conversion_seq_first() - test_conversion.test_model_conversion_batch_first() - - test_pruning = TestPruningWorkflow() - test_pruning.test_pruning_workflow(batch_first=True) - test_pruning.test_pruning_workflow(batch_first=False) + test_behavior.test_single_layer_gru_equivalence() + test_behavior.test_multi_layer_gru_equivalence() + test_behavior.test_gru_equivalence_with_batch_dim_input() + test_behavior.test_gru_equivalence_with_batch_first() - print("All tests passed!") \ No newline at end of file + test_replacement_utils = TestGRUPruneUtils() + test_replacement_utils.test_replacement_utils() + print("All tests passed!") From 43527214567afd5163525a87a6012287e44c3a62 Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 12:50:17 -0500 Subject: [PATCH 03/12] gru pruning example --- examples/GRUs/gru_pruning_example.py | 40 +++------------------------- 1 file changed, 3 insertions(+), 37 deletions(-) diff --git a/examples/GRUs/gru_pruning_example.py b/examples/GRUs/gru_pruning_example.py index 
02012ba..d489621 100644 --- a/examples/GRUs/gru_pruning_example.py +++ b/examples/GRUs/gru_pruning_example.py @@ -42,49 +42,15 @@ import torch_pruning as tp import torch import torch.nn as nn -import torch.nn.functional as F # Import your utility functions from gru_utils import ( replace_prunablegru_with_torchgru, replace_torchgru_with_prunablegru, + GRUTestNet, ) -class GRUTestNet(torch.nn.Module): - """ - Simple test network demonstrating GRU pruning workflow. - - Architecture: Conv layers → FC layers → Multi-layer GRU → Output FC - This mimics common architectures where GRU processes encoded features. - """ - def __init__(self, input_size=80, hidden_size=164): - super(GRUTestNet, self).__init__() - # Feature extraction layers - self.conv1 = nn.Conv2d(1, 6, 5) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(256, 196) - self.fc2 = nn.Linear(196, 80) - - # Multi-layer GRU (this is what we want to prune) - self.gru = nn.GRU(input_size, hidden_size, num_layers=2) - - # Output layer - self.fc3 = nn.Linear(164, 10) - - def forward(self, x, hx=None): - # Feature extraction - x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), 2) - x = x.view(-1, int(x.nelement() / x.shape[0])) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - - # GRU processing (sequence length = 1 for this example) - x = self.gru(x, hx=hx)[0] - - # Final classification - x = self.fc3(x) - return x + def demonstrate_gru_pruning_workflow(): @@ -146,7 +112,7 @@ def demonstrate_gru_pruning_workflow(): # Step 6: Convert back to torch.nn.GRU (removes identity layers) print("\n6. 
Converting back to torch.nn.GRU (removes identity layers)...") final_model = replace_prunablegru_with_torchgru(model) - + pruned_output = final_model(example_inputs) # Show the results print(f" Final GRU hidden size: {final_model.gru.hidden_size}") print(f" Hidden size reduction: {model.gru.hidden_size} → {final_model.gru.hidden_size}") From 2802ad9dbda4f5fc8fae217cd42aa811ebcd1b8f Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 13:43:11 -0500 Subject: [PATCH 04/12] draft PR --- examples/GRUs/README.md | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 examples/GRUs/README.md diff --git a/examples/GRUs/README.md b/examples/GRUs/README.md new file mode 100644 index 0000000..3e083dc --- /dev/null +++ b/examples/GRUs/README.md @@ -0,0 +1,61 @@ +# GRU Pruning with Torch-Pruning + +This example demonstrates how to prune GRU (Gated Recurrent Unit) layers in PyTorch models using the torch-pruning library. The key challenge addressed here is making GRU layers compatible with torch-pruning through a prunable GRU implementation. + +## Key Innovations + +### Problem 1: Opaque C++ Implementation +This example demonstrates how to prune GRU layers using torch-pruning. Since `torch.nn.GRU` uses an opaque C++ implementation that torch-pruning cannot analyze, we provide a `PrunableGRU` implementation that exposes internal operations for structural pruning. + +### Problem 2: Circular Dependency in Hidden States +GRU layers create circular dependencies that prevent torch-pruning from properly modifying hidden dimensions. In other words, torch-pruning sees the hidden state as both a module input AND output and torch-pruning does not naturally have the freedom to update the size of model inputs. + +## Solution: Custom PrunableGRU with Hidden State Identity "Buffer" Layer + +Our approach addresses both problems: + +1. 
**Replace opaque torch.nn.GRU** with a custom `PrunableGRU` implementation that exposes all internal operations as standard PyTorch layers that are prunable by torch-pruning. These modules can be replaced by PyTorch GRU after pruning is performed. +2. **Identity layers (`hidden_map`)** in `PrunableGRU` to break the circular dependency: + +hidden_state → GRU → hidden_map (identity) → pruned_hidden_state + +This provides "safe" pruning points where torch-pruning can modify dimensions without worrying about the circular constraint. + +## Workflow + +1. **Convert**: Replace `torch.nn.GRU` with `PrunableGRU` (includes identity layers) +2. **Prune**: Run torch-pruning (can now safely modify hidden dimensions) +3. **Convert Back**: Convert back to `torch.nn.GRU` (removes identity layers, keeps pruned structure) + +## Usage + +A basic example is in `gru_pruning_example.py`. + +Run the example: +```bash +python gru_pruning_example.py +``` + +**What this example *does* include**: A demonstration of the mechanics of pruning GRUs in a very simple test network. In particular, we demonstrate that the GRU hidden and input sizes are smaller after pruning is performed and that the model can still perform inference after pruning. + +**What this example *does not* include**: An analysis of performance in a useful model after GRU pruning is performed. We leave this to the user to explore. + +## Files +- `gru_pruning_example.py` - Complete working example +- `gru_utils.py` - PrunableGRU implementation and utilities +- `tets_gru.py` - Unit tests of the gru pruning utilities in `gru_utils.py` +- `README.md` - This file + +## Limitations + +- Sequence length: Supports sequence_length=1 only during pruning; the torch-pruning `example_inputs` argument must correspond to sequence_length=1. +- Batch size: Supports batch_size=1 only during pruning; if the GRU input data has a batch dimension, the torch-pruning `example_inputs` argument must correspond to batch_size=1.
+- Multi-layer: Tested with single-layer GRUs only. We recommend "unwrapping" a multi-layer `torch.nn.GRU` into multiple cascaded single-layer `torch.nn.GRU`s prior to pruning, so that pruning can achieve different hidden state sizes across layers. + +## Extensions + +The approach demonstrated here can be extended to: +- LSTM layers +- minGRU/minLSTM architectures +- Other recurrent architectures with similar circular dependency issues that can be decomposed into prunable building blocks + From 91d9b74c083aca298f98ef3d9f6a5b954ed5580a Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 13:44:30 -0500 Subject: [PATCH 05/12] delete old file --- examples/GRUs/readme.md | 43 ----------------------------------------- 1 file changed, 43 deletions(-) delete mode 100644 examples/GRUs/readme.md diff --git a/examples/GRUs/readme.md b/examples/GRUs/readme.md deleted file mode 100644 index 046d2fb..0000000 --- a/examples/GRUs/readme.md +++ /dev/null @@ -1,43 +0,0 @@ -# GRU Pruning with Torch-Pruning - -This example demonstrates how to prune GRU (Gated Recurrent Unit) layers in PyTorch models using the torch-pruning library. The key challenge addressed here is making GRU layers compatible with torch-pruning through custom implementations that resolve fundamental architectural constraints. - -## Key Innovations - -### Problem 1: Opaque C++ Implementation -The standard `torch.nn.GRU` module uses an optimized C++ implementation under the hood that torch-pruning cannot analyze or modify. This black-box nature prevents the pruning library from understanding the internal structure needed for safe pruning operations. - -### Problem 2: Circular Dependency in Hidden States -GRU layers create circular dependencies that prevent torch-pruning from modifying hidden dimensions: - -torch-pruning sees the hidden state as both input AND output, so it refuses to change the hidden dimension to avoid breaking this cycle.
- -## Solution: Custom PrunableGRU with Identity Layers - -Our approach addresses both problems: - -1. **Replace opaque torch.nn.GRU** with a custom `PrunableGRU` implementation that exposes all internal operations as standard PyTorch layers -2. **Insert identity layers (`hidden_map`)** to break the circular dependency: - -hidden_state → GRU → hidden_map (identity) → pruned_hidden_state - - -This provides "safe" pruning points where torch-pruning can modify dimensions without worrying about the circular constraint. - -## Workflow - -1. **Convert**: Replace `torch.nn.GRU` with `PrunableGRU` (includes identity layers) -2. **Prune**: Run torch-pruning (can now safely modify hidden dimensions) -3. **Convert Back**: Convert back to `torch.nn.GRU` (removes identity layers, keeps pruned structure) - -## Usage - -### Basic Example - -```python -import torch -import torch.nn as nn -from gru_utils import replace_torchgru_with_prunablegru, replace_prunablegru_with_torchgru - - -#### Original model with torch.nn.GRU From c7b52b066423649b5372ac991fcf7d4d93616b4f Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 14:07:23 -0500 Subject: [PATCH 06/12] docstring edits --- examples/GRUs/gru_pruning_example.py | 31 ++++++++------------ examples/GRUs/gru_utils.py | 44 +++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/examples/GRUs/gru_pruning_example.py b/examples/GRUs/gru_pruning_example.py index d489621..b1af059 100644 --- a/examples/GRUs/gru_pruning_example.py +++ b/examples/GRUs/gru_pruning_example.py @@ -13,19 +13,12 @@ making it impossible for the pruning framework to understand the internal structure and dependencies of the GRU operations. -Problem 2: GRU hidden states create circular dependencies that prevent torch-pruning -from modifying hidden dimensions: +Problem 2: Past GRU hidden states are inputs to model so torch-pruning will not alter its size. 
- Hidden State (t-1) → GRU → Hidden State (t) → GRU → Hidden State (t+1) - ↑ ↓ - └──────────────── Circular dependency ─────────┘ - -torch-pruning sees the hidden state as both input AND output, so it refuses to -change the hidden dimension to avoid breaking this cycle. Solution: Create a custom PrunableGRUEqualHiddenSize that: 1. Implements GRU operations in pure Python/PyTorch (transparent to torch-pruning) -2. Inserts identity layers (hidden_map) to break circular dependencies: +2. Inserts identity layers (hidden_map) to break circular dependencies and allow pruning of the hidden state size hidden_state → GRU → hidden_map (identity) → pruned_hidden_state @@ -57,10 +50,7 @@ def demonstrate_gru_pruning_workflow(): """ Complete workflow showing GRU pruning with the identity layer solution. - This function demonstrates: - 1. The circular dependency problem - 2. How our solution works - 3. End-to-end pruning workflow + This function demonstrates an end-to-end GRU pruning workflow. """ print("=" * 60) print("GRU Pruning Workflow Demonstration") @@ -69,7 +59,11 @@ def demonstrate_gru_pruning_workflow(): # Step 1: Create original model with standard torch.nn.GRU print("\n1. Creating original model with torch.nn.GRU...") model = GRUTestNet() - print(f" Original GRU hidden size: {model.gru.hidden_size}") + original_gru_hidden_size = model.gru.hidden_size + original_gru_input_size = model.gru.input_size + print(f" Original GRU hidden size: {original_gru_hidden_size}") + print(f" Original GRU input size: {original_gru_input_size}") + # Step 2: Prepare inputs for torch-pruning dependency analysis print("\n2.
Preparing example inputs for dependency graph...") @@ -94,7 +88,7 @@ def demonstrate_gru_pruning_workflow(): model, input_data, importance=imp, - pruning_ratio=0.5, # Remove 50% of parameters + pruning_ratio=0.2, # Remove 20% of input/output channels isomorphic=True, global_pruning=True, root_module_types=(nn.Linear, nn.LayerNorm, nn.Conv2d), @@ -106,20 +100,19 @@ def demonstrate_gru_pruning_workflow(): # Verify model still works after pruning pruned_output = model(example_inputs) - print(f" Pruned model output shape: {pruned_output.shape}") print(" ✓ Model still functional after pruning") # Step 6: Convert back to torch.nn.GRU (removes identity layers) print("\n6. Converting back to torch.nn.GRU (removes identity layers)...") final_model = replace_prunablegru_with_torchgru(model) pruned_output = final_model(example_inputs) + # Show the results - print(f" Final GRU hidden size: {final_model.gru.hidden_size}") - print(f" Hidden size reduction: {model.gru.hidden_size} → {final_model.gru.hidden_size}") + print(f" Hidden size reduction: {original_gru_hidden_size} → {final_model.gru.hidden_size}") + print(f" Hidden size reduction: {original_gru_input_size} → {final_model.gru.input_size}") # Final verification final_output = final_model(example_inputs) - print(f" Final model output shape: {final_output.shape}") print(" ✓ Successfully pruned GRU while maintaining functionality!") print("\n" + "=" * 60) diff --git a/examples/GRUs/gru_utils.py b/examples/GRUs/gru_utils.py index 1418f70..46c8c09 100644 --- a/examples/GRUs/gru_utils.py +++ b/examples/GRUs/gru_utils.py @@ -7,7 +7,7 @@ identity layers that provide safe pruning points. 
Key Components: -- PrunableGRU: Custom GRU implementation with identity layers for pruning +- PrunableGRU: Custom GRU implementation composed of exposed, prunable operators and an identity layer for pruning - Conversion functions between nn.GRU and PrunableGRU - Model-wide replacement utilities for seamless integration """ @@ -15,6 +15,8 @@ import torch import torch.nn as nn import copy +import torch.nn.functional as F + class PrunableGRU(nn.Module): """ @@ -145,6 +147,8 @@ def forward(self, x, hx=None): h_new = (1 - z) * n + z * h if layer_idx > 0: + # this forces the hidden state size to be the same across all layers; + # it reduces pruning flexibility but is required for torch.nn.GRU compatibility when num_layers > 1 outputs.append(h_new + 0 * out[t, 0, :].unsqueeze(0)) else: outputs.append(h_new) @@ -350,3 +354,41 @@ def _replace_prunablegru(module): # Move the copied model to the same device as the original model model_copy.to(device) return model_copy + +class GRUTestNet(torch.nn.Module): + """ + Simple test network demonstrating GRU pruning workflow. + + Architecture: Conv layers → FC layers → Multi-layer GRU → Output FC + This mimics common architectures where GRU processes encoded features.
+ """ + def __init__(self, input_size=80, hidden_size=164): + super(GRUTestNet, self).__init__() + # Feature extraction layers + self.input_size = input_size + self.hidden_size = hidden_size + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(256, 196) + self.fc2 = nn.Linear(196, 80) + + # Multi-layer GRU (this is what we want to prune) + self.gru = nn.GRU(input_size, hidden_size, num_layers=2) + + # Output layer + self.fc3 = nn.Linear(164, 10) + + def forward(self, x, hx=None): + # Feature extraction + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(-1, int(x.nelement() / x.shape[0])) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + + # GRU processing (sequence length = 1 for this example) + x = self.gru(x, hx=hx)[0] + + # Final classification + x = self.fc3(x) + return x \ No newline at end of file From 047317ae6239a9f9777b302417a27bd2f6a67b75 Mon Sep 17 00:00:00 2001 From: "Seidel, Sheila" Date: Mon, 28 Jul 2025 15:14:07 -0400 Subject: [PATCH 07/12] Update examples/GRUs/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/GRUs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/GRUs/README.md b/examples/GRUs/README.md index 3e083dc..f55a12a 100644 --- a/examples/GRUs/README.md +++ b/examples/GRUs/README.md @@ -43,7 +43,7 @@ python gru_pruning_example.py ## Files - `gru_pruning_example.py` - Complete working example - `gru_utils.py` - PrunableGRU implementation and utilities -- `tets_gru.py` - Unit tests of the gru pruning utilities in `gru_utils.py` +- `test_gru.py` - Unit tests of the gru pruning utilities in `gru_utils.py` - `README.md` - This file ## Limitations From 07d3caa6d87c8c932be142ee72855bf191319638 Mon Sep 17 00:00:00 2001 From: "Seidel, Sheila" Date: Mon, 28 Jul 2025 15:15:07 -0400 Subject: [PATCH 08/12] Update examples/GRUs/test_gru.py Co-authored-by: Copilot 
<175728472+Copilot@users.noreply.github.com> --- examples/GRUs/test_gru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py index ad5e539..6133d67 100644 --- a/examples/GRUs/test_gru.py +++ b/examples/GRUs/test_gru.py @@ -47,7 +47,7 @@ def test_single_layer_gru_equivalence(self): def test_multi_layer_gru_equivalence(self): """Test multi layer GRU equivalence.""" - # get params for single layer GRU + # get params for multi-layer GRU params = self.gru_params() params["num_layers"] = 2 # test multiple layers From dbc4f9abaa0d7fd1042d17d6f280ed8603122b08 Mon Sep 17 00:00:00 2001 From: "Seidel, Sheila" Date: Mon, 28 Jul 2025 15:15:20 -0400 Subject: [PATCH 09/12] Update examples/GRUs/test_gru.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/GRUs/test_gru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py index 6133d67..0f4b1af 100644 --- a/examples/GRUs/test_gru.py +++ b/examples/GRUs/test_gru.py @@ -1,6 +1,6 @@ """ This script demonstrates the pruning of GRU modules in a PyTorch model and tests out some of the different building blocks. -This work was a precurser to testing GRU pruning in our actual DeepFilterNet model. +This work was a precursor to testing GRU pruning in our actual DeepFilterNet model. 
""" import torch import torch.nn as nn From b0b66ee4ebbc205ce37925aa217501a3628fd83b Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Mon, 28 Jul 2025 15:41:22 -0500 Subject: [PATCH 10/12] 2 -> 1 layer example --- examples/GRUs/gru_utils.py | 2 +- examples/GRUs/test_gru.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/GRUs/gru_utils.py b/examples/GRUs/gru_utils.py index 46c8c09..c934757 100644 --- a/examples/GRUs/gru_utils.py +++ b/examples/GRUs/gru_utils.py @@ -373,7 +373,7 @@ def __init__(self, input_size=80, hidden_size=164): self.fc2 = nn.Linear(196, 80) # Multi-layer GRU (this is what we want to prune) - self.gru = nn.GRU(input_size, hidden_size, num_layers=2) + self.gru = nn.GRU(input_size, hidden_size, num_layers=1) # Output layer self.fc3 = nn.Linear(164, 10) diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py index 0f4b1af..c90cb2d 100644 --- a/examples/GRUs/test_gru.py +++ b/examples/GRUs/test_gru.py @@ -4,7 +4,6 @@ """ import torch import torch.nn as nn -import torch_pruning as tp from gru_utils import ( replace_prunablegru_with_torchgru, replace_torchgru_with_prunablegru, From 207b4170fd3feacf65c104f723ae411fa2c89c60 Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Wed, 30 Jul 2025 15:40:04 -0500 Subject: [PATCH 11/12] copyright --- examples/GRUs/gru_utils.py | 1 + examples/GRUs/{README.md => readme.md} | 1 + examples/GRUs/test_gru.py | 1 + 3 files changed, 3 insertions(+) rename examples/GRUs/{README.md => readme.md} (98%) diff --git a/examples/GRUs/gru_utils.py b/examples/GRUs/gru_utils.py index c934757..359ee6b 100644 --- a/examples/GRUs/gru_utils.py +++ b/examples/GRUs/gru_utils.py @@ -1,3 +1,4 @@ +# Copyright © \2025 Analog Devices, Inc. """ Utilities for GRU pruning with torch-pruning library. 
diff --git a/examples/GRUs/README.md b/examples/GRUs/readme.md similarity index 98% rename from examples/GRUs/README.md rename to examples/GRUs/readme.md index f55a12a..bc98aa8 100644 --- a/examples/GRUs/README.md +++ b/examples/GRUs/readme.md @@ -1,3 +1,4 @@ +Copyright © 2025 Analog Devices, Inc. # GRU Pruning with Torch-Pruning This example demonstrates how to prune GRU (Gated Recurrent Unit) layers in PyTorch models using the torch-pruning library. The key challenge addressed here is making GRU layers compatible with torch-pruning through a prunable GRU implementation. diff --git a/examples/GRUs/test_gru.py b/examples/GRUs/test_gru.py index c90cb2d..7bbc289 100644 --- a/examples/GRUs/test_gru.py +++ b/examples/GRUs/test_gru.py @@ -1,3 +1,4 @@ +# Copyright © 2025 Analog Devices, Inc. """ This script demonstrates the pruning of GRU modules in a PyTorch model and tests out some of the different building blocks. This work was a precursor to testing GRU pruning in our actual DeepFilterNet model. From b3427a0260f28ab1fa3f0238d7d77ec24cec14cc Mon Sep 17 00:00:00 2001 From: Sheila Seidel Date: Wed, 30 Jul 2025 15:40:46 -0500 Subject: [PATCH 12/12] copyright --- examples/GRUs/gru_pruning_example.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/GRUs/gru_pruning_example.py b/examples/GRUs/gru_pruning_example.py index b1af059..18298fb 100644 --- a/examples/GRUs/gru_pruning_example.py +++ b/examples/GRUs/gru_pruning_example.py @@ -1,3 +1,5 @@ +# Copyright © 2025 Analog Devices, Inc. + """ GRU Pruning Example for Torch-Pruning