@@ -905,3 +905,43 @@ def input_kwargs(shape, dtype, device):
905905 )
906906 bench .set_gems (flag_gems .per_token_group_quant_fp8 )
907907 bench .run ()
908+
909+
@pytest.mark.unfold
def test_perf_unfold_backward():
    """Benchmark flag_gems.unfold_backward against aten.unfold_backward."""

    def unfold_backward_input_fn(cfg, dtype, device):
        """Yield one (grad_in, input_sizes, dim, size, step) argument tuple.

        The gradient shape mirrors the output of ``Tensor.unfold``: the
        unfolded dimension is replaced by the window count and the window
        length is appended as the trailing dimension.
        """
        input_sizes, dim, size, step = cfg
        # Normalize a possibly-negative dim into [0, ndim).
        axis = dim % len(input_sizes)
        # Number of sliding windows the forward unfold produces.
        num_windows = (input_sizes[axis] - size) // step + 1
        grad_shape = (
            [*input_sizes[:axis], num_windows, *input_sizes[axis + 1 :], size]
        )
        grad_in = torch.randn(grad_shape, dtype=dtype, device=device)
        yield grad_in, list(input_sizes), dim, size, step

    class UnfoldBackwardBenchmark(Benchmark):
        """Benchmark harness driven by a fixed list of unfold configs."""

        def set_shapes(self, shape_file_path=None):
            # Each entry is (input_sizes, dim, size, step); covers negative
            # dims, a full-size window, and step larger than size.
            self.shapes = [
                ((32, 64), 1, 16, 16),
                ((16, 33), 0, 5, 2),
                ((4, 8, 12), -1, 6, 4),
                ((7, 13), 1, 13, 3),
                ((6, 20), 1, 7, 4),
                ((2, 3, 17), -1, 9, 1),
                ((2, 17), 1, 4, 6),
            ]

        def set_more_shapes(self):
            # No auto-generated extra shapes for this op.
            return None

        def get_input_iter(self, cur_dtype):
            for config in self.shapes:
                yield from unfold_backward_input_fn(config, cur_dtype, self.device)

    bench = UnfoldBackwardBenchmark(
        op_name="unfold_backward",
        torch_op=torch.ops.aten.unfold_backward,
        dtypes=[torch.float16, torch.float32, torch.bfloat16],
    )
    bench.set_gems(flag_gems.unfold_backward)
    bench.run()
0 commit comments