Commit e17a53a

Merge pull request #493 from VainF/v2.0 (V2.0)
2 parents 6ca5595 + 582d899

4 files changed: 45 additions & 40 deletions
.github/workflows/test_torch_200.yml (1 addition, 1 deletion)

@@ -19,7 +19,7 @@ jobs:
       uses: actions/setup-python@v3
       with:
         python-version: "3.10"
-        cache: 'pip' # caching pip dependencies
+        # cache: 'pip' # caching pip dependencies
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip

tests/test_concat.py (35 additions, 27 deletions)

@@ -6,40 +6,48 @@
 import torch_pruning as tp
 import torch.nn as nn

-class Net(nn.Module):
-    def __init__(self, in_dim):
+class MLP(nn.Module):
+    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0):
         super().__init__()
-        self.block1 = nn.Sequential(
-            nn.Conv2d(in_dim, in_dim, 1),
-            nn.BatchNorm2d(in_dim),
-            nn.GELU(),
-            nn.Conv2d(in_dim, in_dim, 1),
-            nn.BatchNorm2d(in_dim)
-        )
-        self.parallel_path = nn.Sequential(
-            nn.Conv2d(in_dim, in_dim, 1),
-            nn.BatchNorm2d(in_dim),
-            nn.GELU(),
-            nn.Conv2d(in_dim, in_dim//2, 1),
-            nn.BatchNorm2d(in_dim//2)
-        )
-        self.block2 = nn.Sequential(
-            nn.Conv2d(in_dim * 2 + in_dim//2, in_dim, 1),
-            nn.BatchNorm2d(in_dim)
-        )
-
+        if dims is None:
+            dims = []
+        layers = list()
+        for i_dim in dims:
+            layers.append(nn.Linear(input_dim, i_dim))
+            layers.append(nn.BatchNorm1d(i_dim))
+            layers.append(nn.ReLU())
+            layers.append(nn.Dropout(p=dropout))
+            input_dim = i_dim
+        if output_layer:
+            layers.append(nn.Linear(input_dim, 1))
+        self.mlp = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.mlp(x)
+
+class widedeep(nn.Module):
+    def __init__(self, input_dim):
+        super(widedeep, self).__init__()
+        self.dims = input_dim
+
+        self.mlp = MLP(self.dims, True, dims=[32,16], dropout=0.2)
+        self.linear = nn.Linear(self.dims, 3)
+        self.lin2 = nn.Linear(4, 1)
+
     def forward(self, x):
-        x = self.block1(x)
-        x2 = self.parallel_path(x)
-        x = torch.cat([x, x, x2], dim=1)
-        x = self.block2(x)
+        x = x.reshape(x.shape[0], -1)
+        mlp_out = self.mlp(x)
+        linear_out = self.linear(x)
+        x = torch.concat([linear_out, mlp_out], dim=-1)
+        x = self.lin2(x)
+        x = torch.sigmoid(x)
         return x


 def test_pruner():
-    model = Net(512)
+    model = widedeep(32)
     print(model)
     # Global metrics
-    example_inputs = torch.randn(1, 512, 7, 7)
+    example_inputs = torch.randn(1, 32)
     imp = tp.importance.MagnitudeImportance(p=2)
     ignored_layers = []

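Note: the hunk ends before the test body is complete. The torch.concat([linear_out, mlp_out], dim=-1) in widedeep is exactly the pattern the dependency-tracing changes below must handle: lin2's 4 input features are the concatenation of linear's 3 outputs and the MLP's 1. A minimal sketch of exercising that dependency through the public DependencyGraph API (an illustration, not the remainder of the actual test file):

    import torch
    import torch_pruning as tp

    model = widedeep(32).eval()  # eval: BatchNorm1d rejects batch size 1 in training mode
    example_inputs = torch.randn(1, 32)

    # Build the dependency graph by tracing one forward pass.
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

    # Prune output feature 0 of the wide branch; the concat index mapping
    # must translate this into pruning input feature 0 of model.lin2.
    group = DG.get_pruning_group(model.linear, tp.prune_linear_out_channels, idxs=[0])
    if DG.check_pruning_group(group):  # avoid pruning any layer to zero channels
        group.prune()

    print(model.linear)  # out_features: 3 -> 2
    print(model.lin2)    # in_features:  4 -> 3
    print(model(example_inputs).shape)  # still torch.Size([1, 1])

If the concat mapping failed, lin2 would keep 4 input features and the forward pass would crash on a shape mismatch, which is what this test guards against.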
torch_pruning/_helpers.py (0 additions, 1 deletion)

@@ -69,7 +69,6 @@ def __init__(self, offset, reverse=False):
         self.reverse = reverse

     def __call__(self, idxs: _HybridIndex):
-
         if self.reverse == True:
             new_idxs = [
                 _HybridIndex(idx = i.idx - self.offset[0], root_idx=i.root_idx )

torch_pruning/dependency.py (9 additions, 11 deletions)

@@ -10,6 +10,7 @@

 INDEX_MAPPING_PLACEHOLDER = None
 MAX_RECURSION_DEPTH = 500
+MAX_LEGAL_DIM = 100

 class Node(object):
     """ Node of DepGraph.
@@ -365,7 +366,7 @@ def build_dependency(
         # Ignore layers & nn.Parameter
         if ignored_layers is not None:
            self.IGNORED_LAYERS_IN_TRACING.extend(ignored_layers)
-        self.ignored_params = ignored_params
+        self.ignored_params = ignored_params if ignored_params is not None else []

         # Ignore all sub-modules of customized layers since they will be handled by the customized pruner
         for layer_type_or_instance in self.CUSTOMIZED_PRUNERS.keys():
@@ -512,7 +513,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
            )

        _fix_dependency_graph_non_recursive(*group[0])
-
+
        # merge pruning ops
        merged_group = Group() # craft a new group for merging
        for dep, idxs in group.items:
@@ -941,9 +942,9 @@ def _update_slice_index_mapping(self, slice_node: Node):
            return
        grad_fn = slice_node.grad_fn
        if hasattr(grad_fn, '_saved_self_sym_sizes'):
-            if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1:
+            if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1 and grad_fn._saved_dim<MAX_LEGAL_DIM:
                return
-            elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2:
+            elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2 and grad_fn._saved_dim<MAX_LEGAL_DIM:
                return

        start, step, end, dim = slice_node.module.start, slice_node.module.step, slice_node.module.end, slice_node.module.dim
@@ -966,8 +967,6 @@ def _init_shape_information(self):
            grad_fn = node.grad_fn

            if hasattr(grad_fn, '_saved_self_sizes') or hasattr(grad_fn, '_saved_split_sizes'):
-                MAX_LEGAL_DIM = 100
-
                if hasattr(grad_fn, '_saved_split_sizes') and hasattr(grad_fn, '_saved_dim') :
                    if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359
                        continue
@@ -1095,7 +1094,7 @@ def _update_concat_index_mapping(self, cat_node: Node):
        if cat_node.type != ops.OPTYPE.CONCAT:
            return

-        if hasattr(cat_node.grad_fn, '_saved_dim') and cat_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
+        if hasattr(cat_node.grad_fn, '_saved_dim') and cat_node.grad_fn._saved_dim<MAX_LEGAL_DIM and cat_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
            return

        if cat_node.module.concat_sizes is not None:
@@ -1151,11 +1150,10 @@ def _update_split_index_mapping(self, split_node: Node):
        # There a issue in some pytorch version, where the _saved_dim is an uninitialized value like 118745347895359
        # So we need to check if the _saved_dim is a valid value (<len(_saved_self_sym_sizes) or a nominal value like 20)
        if hasattr(split_node.grad_fn, '_saved_self_sym_sizes'):
-            if split_node.grad_fn._saved_dim<len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim != 1:
+            if split_node.grad_fn._saved_dim<len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim<MAX_LEGAL_DIM and split_node.grad_fn._saved_dim != 1:
                return
        else:
-            THRESHOLD = 20
-            if split_node.grad_fn._saved_dim<THRESHOLD and split_node.grad_fn._saved_dim>=0 and split_node.grad_fn._saved_dim != 1:
+            if split_node.grad_fn._saved_dim>=0 and split_node.grad_fn._saved_dim<MAX_LEGAL_DIM and split_node.grad_fn._saved_dim != 1:
                return
        offsets = split_node.module.offsets

@@ -1189,7 +1187,7 @@ def _update_unbind_index_mapping(self, unbind_node: Node):
        if unbind_node.type != ops.OPTYPE.UNBIND:
            return

-        if hasattr(unbind_node.grad_fn, '_saved_dim') and unbind_node.grad_fn._saved_dim != 0: # For timm attention
+        if hasattr(unbind_node.grad_fn, '_saved_dim') and unbind_node.grad_fn._saved_dim<MAX_LEGAL_DIM and (unbind_node.grad_fn._saved_dim )!= 0: # this only works for Pytorch>=1.12
            return

        num_chunks = len(unbind_node.outputs)

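The recurring _saved_dim < MAX_LEGAL_DIM guard exists because autograd's saved-variable introspection can hand back an uninitialized integer on some builds; the in-code comments cite pytorch==1.11 values like 118745347895359. This commit hoists the previously local MAX_LEGAL_DIM = 100 (and the ad-hoc THRESHOLD = 20) into one module-level constant and applies it uniformly to the slice, concat, split, and unbind mappings. A quick sketch of inspecting the value the guard checks, assuming a PyTorch >= 1.12 build where _saved_dim is populated for concat:

    import torch

    x = torch.randn(1, 3, requires_grad=True)
    y = torch.cat([x, x], dim=1)   # backward node is CatBackward0

    if hasattr(y.grad_fn, '_saved_dim'):
        d = y.grad_fn._saved_dim
        # A real dim is tiny; uninitialized garbage fails this sanity check.
        print(d, d < 100)          # expected: 1 True on a healthy build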