Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/test_dependency_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,17 @@ def test_depgraph():
#for g in groups:
# print(g)

def test_depgraph_inpruning():
    """Prune resnet18 along *input* channels using DG.get_all_groups(in_pruning=True).

    Builds the dependency graph, then prunes a fixed set of input-channel
    indices from every group and prints the resulting model for inspection.
    """
    model = resnet18().eval()

    # 1. build dependency graph for resnet18
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

    # The pruning indices are loop-invariant, so hoist them out of the loop
    # instead of rebuilding the list on every iteration.
    idxs = [2, 4, 6]  # my pruning indices
    for group in DG.get_all_groups(in_pruning=True):
        group.prune(idxs=idxs)
    print(model)

if __name__ == '__main__':
    # Run every test in this file when executed as a script; the in-pruning
    # test was previously defined but never invoked here.
    test_depgraph()
    test_depgraph_inpruning()
16 changes: 15 additions & 1 deletion tests/test_single_channel_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,19 @@ def test_single_channel_output():
print(all_groups[0])
assert len(all_groups[0])==3

def test_single_channel_input():
    """Verify group sizes when the graph is traversed from input channels."""
    net = Model()
    dummy_input = torch.randn(1, 3, 224, 224)
    graph = tp.DependencyGraph().build_dependency(net, example_inputs=dummy_input)

    groups = list(graph.get_all_groups(in_pruning=True))
    # Dump the first three groups for manual inspection.
    for g in groups[:3]:
        print(g)
    assert len(groups[0]) == 4
    assert len(groups[1]) == 4
    assert len(groups[2]) == 1

if __name__ == "__main__":
    # Run each test exactly once; test_single_channel_output was previously
    # invoked twice (duplicated call).
    test_single_channel_output()
    test_single_channel_input()
16 changes: 9 additions & 7 deletions torch_pruning/dependency/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,15 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
merged_group[i].root_idxs = root_idxs
return merged_group

def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)):
def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR), in_pruning:bool = False):
"""
Get all pruning groups for the given module. Groups are generated based on root module types.
All groups will contain a full indices of the prunable elements or channels.

Args:
ignored_layers (list): List of layers to be ignored during pruning.
root_module_types (tuple): Tuple of root module types to consider for pruning.

in_pruning (bool): Set to True to use input channels as root instead of output channels (default: False)
Yields:
list: A pruning group containing dependencies and their corresponding pruning handlers.

Expand Down Expand Up @@ -309,16 +309,18 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o
if m in visited_layers:
continue

# use output pruning as the root
layer_channels = pruner.get_out_channels(m)
group = self.get_pruning_group(
m, pruner.prune_out_channels, list(range(layer_channels)))
if not in_pruning:
layer_channels = pruner.get_out_channels(m)
group = self.get_pruning_group(m, pruner.prune_out_channels, list(range(layer_channels)))
else:
layer_channels = pruner.get_in_channels(m)
group = self.get_pruning_group(m, pruner.prune_in_channels, list(range(layer_channels)))

prunable_group = True
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
if self.is_out_channel_pruning_fn(pruning_fn):
if (not in_pruning and self.is_out_channel_pruning_fn(pruning_fn)) or (in_pruning and self.is_in_channel_pruning_fn(pruning_fn)):
visited_layers.append(module)
if module in ignored_layers:
prunable_group = False
Expand Down