From 817f929b4b008cdd5b0ea3cf680d4fb9b8859a85 Mon Sep 17 00:00:00 2001 From: Dhruva Kashyap <40919082+DhruvaKashyap@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:27:16 +0530 Subject: [PATCH 1/2] feat: add in_pruning option to get_all_groups --- torch_pruning/dependency/graph.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torch_pruning/dependency/graph.py b/torch_pruning/dependency/graph.py index 30cd9da..6c37dca 100644 --- a/torch_pruning/dependency/graph.py +++ b/torch_pruning/dependency/graph.py @@ -273,7 +273,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): merged_group[i].root_idxs = root_idxs return merged_group - def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)): + def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR), in_pruning:bool = False): """ Get all pruning groups for the given module. Groups are generated based on root module types. All groups will contain a full indices of the prunable elements or channels. @@ -281,7 +281,7 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o Args: ignored_layers (list): List of layers to be ignored during pruning. root_module_types (tuple): Tuple of root module types to consider for pruning. - + in_pruning (bool): Set to True to use input channels as root instead of output channels (default: False) Yields: list: A pruning group containing dependencies and their corresponding pruning handlers. @@ -309,16 +309,18 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o if m in visited_layers: continue - # use output pruning as the root - layer_channels = pruner.get_out_channels(m) - group = self.get_pruning_group( - m, pruner.prune_out_channels, list(range(layer_channels))) + if not in_pruning: + layer_channels = pruner.get_out_channels(m) + group = self.get_pruning_group(m, pruner.prune_out_channels, list(range(layer_channels))) + else: + layer_channels = pruner.get_in_channels(m) + group = self.get_pruning_group(m, pruner.prune_in_channels, list(range(layer_channels))) prunable_group = True for dep, _ in group: module = dep.target.module pruning_fn = dep.handler - if self.is_out_channel_pruning_fn(pruning_fn): + if (not in_pruning and self.is_out_channel_pruning_fn(pruning_fn)) or (in_pruning and self.is_in_channel_pruning_fn(pruning_fn)): visited_layers.append(module) if module in ignored_layers: prunable_group = False From 96b6ceda756017a1f98618a46e683bdad02a4d5b Mon Sep 17 00:00:00 2001 From: Dhruva Kashyap <40919082+DhruvaKashyap@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:40:33 +0530 Subject: [PATCH 2/2] tests: add tests for in_pruning in compute_all_groups --- tests/test_dependency_graph.py | 12 ++++++++++++ tests/test_single_channel_output.py | 16 +++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/test_dependency_graph.py b/tests/test_dependency_graph.py index 78823f1..2712af0 100644 --- a/tests/test_dependency_graph.py +++ b/tests/test_dependency_graph.py @@ -38,5 +38,17 @@ def test_depgraph(): #for g in groups: # print(g) +def test_depgraph_inpruning(): + model = resnet18().eval() + + # 1. build dependency graph for resnet18 + DG = tp.DependencyGraph() + DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) + + for group in DG.get_all_groups(in_pruning=True): + idxs = [2,4,6] # my pruning indices + group.prune(idxs=idxs) + print(model) + if __name__=='__main__': test_depgraph() \ No newline at end of file diff --git a/tests/test_single_channel_output.py b/tests/test_single_channel_output.py index 99f5a84..b8ea180 100644 --- a/tests/test_single_channel_output.py +++ b/tests/test_single_channel_output.py @@ -28,5 +28,19 @@ def test_single_channel_output(): print(all_groups[0]) assert len(all_groups[0])==3 +def test_single_channel_input(): + model = Model() + example_inputs = torch.randn(1, 3, 224, 224) + DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs) + + all_groups = list(DG.get_all_groups(in_pruning=True)) + print(all_groups[0]) + print(all_groups[1]) + print(all_groups[2]) + assert len(all_groups[0])==4 + assert len(all_groups[1])==4 + assert len(all_groups[2])==1 + if __name__ == "__main__": - test_single_channel_output() \ No newline at end of file + test_single_channel_output() + test_single_channel_input() \ No newline at end of file