1111from cinema import CineMA , patchify , unpatchify
1212
1313
14+ def plot_mae_reconstruction (
15+ batch : dict [str , torch .Tensor ],
16+ pred_dict : dict [str , torch .Tensor ],
17+ enc_mask_dict : dict [str , torch .Tensor ],
18+ patch_size_dict : dict [str , tuple [int , ...]],
19+ grid_size_dict : dict [str , tuple [int , ...]],
20+ sax_slices : int ,
21+ ) -> plt .Figure :
22+ """Plot MAE reconstruction."""
23+ n_rows = sax_slices + 3
24+ n_cols = 4
25+ fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols * 2 , n_rows * 2 ), dpi = 300 )
26+ for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
27+ patches = patchify (image = batch [view ], patch_size = patch_size_dict [view ])
28+ patches [enc_mask_dict [view ]] = pred_dict [view ]
29+ masks = torch .zeros_like (patches )
30+ masks [enc_mask_dict [view ]] = 1
31+ masks = unpatchify (masks , patch_size = patch_size_dict [view ], grid_size = grid_size_dict [view ])
32+ masks = masks [0 , 0 ]
33+ reconstructed = unpatchify (
34+ patches ,
35+ patch_size = patch_size_dict [view ],
36+ grid_size = grid_size_dict [view ],
37+ )
38+ reconstructed = reconstructed [0 , 0 ].numpy ()
39+ image = batch [view ][0 , 0 ].numpy ()
40+ error = np .abs (reconstructed - image )
41+
42+ if view == "sax" :
43+ reconstructed = reconstructed [..., :sax_slices ]
44+ for j in range (sax_slices ):
45+ axs [3 + j , 0 ].set_ylabel (f"SAX slice { j } " )
46+ axs [3 + j , 0 ].imshow (image [..., j ], cmap = "gray" )
47+ axs [3 + j , 1 ].imshow (masks [..., j ], cmap = "gray" )
48+ axs [3 + j , 2 ].imshow (reconstructed [..., j ], cmap = "gray" )
49+ axs [3 + j , 3 ].imshow (error [..., j ], cmap = "gray" )
50+ else :
51+ axs [i , 0 ].imshow (image , cmap = "gray" )
52+ axs [i , 1 ].imshow (masks , cmap = "gray" )
53+ axs [i , 2 ].imshow (reconstructed , cmap = "gray" )
54+ axs [i , 3 ].imshow (error , cmap = "gray" )
55+ axs [i , 0 ].set_ylabel ({"lax_2c" : "LAX 2C" , "lax_3c" : "LAX 3C" , "lax_4c" : "LAX 4C" }[view ])
56+ if i == 0 :
57+ axs [i , 0 ].set_title ("Original" )
58+ axs [i , 1 ].set_title ("Mask" )
59+ axs [i , 2 ].set_title ("Reconstructed" )
60+ axs [i , 3 ].set_title ("Error" )
61+ # remove the x and y ticks
62+ for i in range (n_rows ):
63+ for j in range (n_cols ):
64+ axs [i , j ].set_xticks ([])
65+ axs [i , j ].set_yticks ([])
66+ fig .tight_layout ()
67+ fig .subplots_adjust (wspace = 0 , hspace = 0 )
68+ return fig
69+
70+
1471def run (device : torch .device , dtype : torch .dtype ) -> None :
1572 """Run MAE reconstruction."""
73+ t = 25 # which time frame to use
74+
1675 # load model
1776 model = CineMA .from_pretrained ()
18- model .to (device )
1977 model .eval ()
78+ patch_size_dict = model .dec_patch_size_dict
79+ grid_size_dict = {k : v .patch_embed .grid_size for k , v in model .enc_down_dict .items ()}
80+ model .to (device )
2081
2182 # load sample data and form a batch of size 1
2283 transform = Compose (
@@ -46,7 +107,7 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
46107 lax_4c_image = torch .from_numpy (
47108 np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_4c.nii.gz" )))
48109 )
49- t = 25 # which time frame to use
110+ sax_slices = sax_image . shape [ - 2 ]
50111 batch = {
51112 "sax" : sax_image [None , ..., t ],
52113 "lax_2c" : lax_2c_image [None , ..., 0 , t ],
@@ -62,45 +123,9 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
62123 _ , pred_dict , enc_mask_dict , _ = model (batch , enc_mask_ratio = 0.75 )
63124
64125 # visualize
65- _ , axs = plt .subplots (6 , 4 , figsize = (12 , 18 ))
66- for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
67- patches = patchify (image = batch [view ], patch_size = model .dec_patch_size_dict [view ])
68- patches [enc_mask_dict [view ]] = pred_dict [view ]
69- masks = torch .zeros_like (patches )
70- masks [enc_mask_dict [view ]] = 1
71- masks = unpatchify (
72- masks , patch_size = model .dec_patch_size_dict [view ], grid_size = model .enc_down_dict [view ].patch_embed .grid_size
73- )
74- masks = masks [0 , 0 ]
75- reconstructed = unpatchify (
76- patches ,
77- patch_size = model .dec_patch_size_dict [view ],
78- grid_size = model .enc_down_dict [view ].patch_embed .grid_size ,
79- )
80- reconstructed = reconstructed [0 , 0 ].detach ().cpu ().numpy ()
81- image = batch [view ][0 , 0 ].detach ().cpu ().numpy ()
82- error = np .abs (reconstructed - image )
83-
84- if view == "sax" :
85- for j in range (3 ):
86- z = j * 3
87- axs [3 + j , 0 ].set_ylabel (f"{ view } slice { z } " )
88- axs [3 + j , 0 ].imshow (image [..., z ], cmap = "gray" )
89- axs [3 + j , 1 ].imshow (masks [..., z ], cmap = "gray" )
90- axs [3 + j , 2 ].imshow (reconstructed [..., z ], cmap = "gray" )
91- axs [3 + j , 3 ].imshow (error [..., z ], cmap = "gray" )
92- else :
93- axs [i , 0 ].imshow (image , cmap = "gray" )
94- axs [i , 1 ].imshow (masks , cmap = "gray" )
95- axs [i , 2 ].imshow (reconstructed , cmap = "gray" )
96- axs [i , 3 ].imshow (error , cmap = "gray" )
97- axs [i , 0 ].set_ylabel (view )
98- if i == 0 :
99- axs [i , 0 ].set_title ("Original" )
100- axs [i , 1 ].set_title ("Mask" )
101- axs [i , 2 ].set_title ("Reconstructed" )
102- axs [i , 3 ].set_title ("Error" )
103- plt .savefig ("mae_reconstruction.png" , dpi = 300 , bbox_inches = "tight" )
126+ batch = {k : v .detach ().cpu () for k , v in batch .items ()}
127+ fig = plot_mae_reconstruction (batch , pred_dict , enc_mask_dict , patch_size_dict , grid_size_dict , sax_slices )
128+ fig .savefig ("mae_reconstruction.png" , dpi = 300 , bbox_inches = "tight" )
104129 plt .show (block = False )
105130
106131
0 commit comments