@@ -10,6 +10,7 @@
 
 INDEX_MAPPING_PLACEHOLDER = None
 MAX_RECURSION_DEPTH = 500
+MAX_LEGAL_DIM = 100
 
 class Node(object):
     """ Node of DepGraph.
@@ -365,7 +366,7 @@ def build_dependency(
         # Ignore layers & nn.Parameter
         if ignored_layers is not None:
             self.IGNORED_LAYERS_IN_TRACING.extend(ignored_layers)
-        self.ignored_params = ignored_params
+        self.ignored_params = ignored_params if ignored_params is not None else []
 
         # Ignore all sub-modules of customized layers since they will be handled by the customized pruner
         for layer_type_or_instance in self.CUSTOMIZED_PRUNERS.keys():
@@ -512,7 +513,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
             )
 
         _fix_dependency_graph_non_recursive(*group[0])
-        
+
         # merge pruning ops
         merged_group = Group() # craft a new group for merging
         for dep, idxs in group.items:
@@ -941,9 +942,9 @@ def _update_slice_index_mapping(self, slice_node: Node):
             return
         grad_fn = slice_node.grad_fn
         if hasattr(grad_fn, '_saved_self_sym_sizes'):
-            if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1:
+            if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM:
                 return
-            elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2:
+            elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2 and grad_fn._saved_dim < MAX_LEGAL_DIM:
                 return
 
         start, step, end, dim = slice_node.module.start, slice_node.module.step, slice_node.module.end, slice_node.module.dim
@@ -966,8 +967,6 @@ def _init_shape_information(self):
             grad_fn = node.grad_fn
 
             if hasattr(grad_fn, '_saved_self_sizes') or hasattr(grad_fn, '_saved_split_sizes'):
-                MAX_LEGAL_DIM = 100
-
                 if hasattr(grad_fn, '_saved_split_sizes') and hasattr(grad_fn, '_saved_dim'):
                     if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359
                         continue
@@ -1095,7 +1094,7 @@ def _update_concat_index_mapping(self, cat_node: Node):
         if cat_node.type != ops.OPTYPE.CONCAT:
             return
 
-        if hasattr(cat_node.grad_fn, '_saved_dim') and cat_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
+        if hasattr(cat_node.grad_fn, '_saved_dim') and cat_node.grad_fn._saved_dim < MAX_LEGAL_DIM and cat_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
             return
 
         if cat_node.module.concat_sizes is not None:
@@ -1151,11 +1150,10 @@ def _update_split_index_mapping(self, split_node: Node):
         # There a issue in some pytorch version, where the _saved_dim is an uninitialized value like 118745347895359
         # So we need to check if the _saved_dim is a valid value (<len(_saved_self_sym_sizes) or a nominal value like 20)
         if hasattr(split_node.grad_fn, '_saved_self_sym_sizes'):
-            if split_node.grad_fn._saved_dim < len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim != 1:
+            if split_node.grad_fn._saved_dim < len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim < MAX_LEGAL_DIM and split_node.grad_fn._saved_dim != 1:
                 return
         else:
-            THRESHOLD = 20
-            if split_node.grad_fn._saved_dim < THRESHOLD and split_node.grad_fn._saved_dim >= 0 and split_node.grad_fn._saved_dim != 1:
+            if split_node.grad_fn._saved_dim >= 0 and split_node.grad_fn._saved_dim < MAX_LEGAL_DIM and split_node.grad_fn._saved_dim != 1:
                 return
         offsets = split_node.module.offsets
 
@@ -1189,7 +1187,7 @@ def _update_unbind_index_mapping(self, unbind_node: Node):
         if unbind_node.type != ops.OPTYPE.UNBIND:
             return
 
-        if hasattr(unbind_node.grad_fn, '_saved_dim') and unbind_node.grad_fn._saved_dim != 0: # For timm attention
+        if hasattr(unbind_node.grad_fn, '_saved_dim') and unbind_node.grad_fn._saved_dim < MAX_LEGAL_DIM and (unbind_node.grad_fn._saved_dim) != 0: # this only works for Pytorch>=1.12
             return
 
         num_chunks = len(unbind_node.outputs)
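
A note on the pattern this commit consolidates: on some PyTorch builds (e.g. 1.11), grad_fn._saved_dim can be read back as an uninitialized integer such as 118745347895359 instead of the real dimension, so each index-mapping update only trusts _saved_dim when it sits below a nominal bound, now the single module-level MAX_LEGAL_DIM = 100 instead of per-function constants like THRESHOLD = 20. A minimal sketch of that validity check follows; the helper name is illustrative and not part of the commit:

# Minimal sketch of the guard pattern used above; the helper is hypothetical.
MAX_LEGAL_DIM = 100  # nominal upper bound; real tensor ranks are far smaller

def saved_dim_is_trustworthy(saved_dim, ndim=None):
    # Reject negative or absurdly large values, which indicate an
    # uninitialized _saved_dim rather than a real dimension index.
    if saved_dim < 0 or saved_dim >= MAX_LEGAL_DIM:
        return False
    # When the tensor rank is known (e.g. via _saved_self_sym_sizes),
    # the dimension must also be smaller than that rank.
    if ndim is not None and saved_dim >= ndim:
        return False
    return True

assert saved_dim_is_trustworthy(1, ndim=4)
assert not saved_dim_is_trustworthy(118745347895359)  # garbage value seen on pytorch==1.11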