@@ -20,8 +20,8 @@ def smooth_map(data: torch.Tensor, kernel_size: int = 5):
2020 # set kernel
2121 kernel = torch .zeros (data .shape [:2 ], dtype = data .dtype , device = data .device )
2222 kernel [
23- (data .shape [0 ] - kernel_size ) // 2 :(data .shape [0 ] + kernel_size ) // 2 ,
24- (data .shape [1 ] - kernel_size ) // 2 :(data .shape [1 ] + kernel_size ) // 2
23+ (data .shape [0 ] - kernel_size ) // 2 :(data .shape [0 ] + kernel_size ) // 2 ,
24+ (data .shape [1 ] - kernel_size ) // 2 :(data .shape [1 ] + kernel_size ) // 2
2525 ] = 1
2626 while kernel .shape .__len__ () < data .shape .__len__ ():
2727 kernel = kernel .unsqueeze (- 1 )
@@ -611,7 +611,7 @@ def unflatten_batched_indices(flat_indices, shape):
611611 Convert batched flattened indices back to individual dimension indices.
612612
613613 Args:
614- flat_indices (torch.Tensor): Batched flattened indices with shape [b1, b2, flat_inds]
614+ flat_indices (torch.Tensor): Batched flattened indices with shape [b1, b2, ..., flat_inds]
615615 shape (tuple): Original tensor shape to unflatten into
616616
617617 Returns:
@@ -640,14 +640,15 @@ def unflatten_batched_indices(flat_indices, shape):
640640 return out_indices
641641
642642
643- def estimate_b1_from_db (
643+ def estimate_b1_b0_from_db (
644644 data : torch .Tensor , db_t1t2b1b0 : torch .Tensor , device : torch .device ,
645- t1t2b1b0_vals : torch .Tensor ,batch_size : int = 10 ):
645+ b1_vals : torch .Tensor , b0_vals : torch . Tensor , batch_size : int = 10 ):
646646 logger .info (f"l2 fit - b1 estimate" )
647647 num_batches = int (np .ceil (data .shape [0 ] / batch_size ))
648648 b1_alloc = torch .zeros (data .shape [:- 1 ])
649+ b0_alloc = torch .zeros (data .shape [:- 1 ])
649650 db_shape = db_t1t2b1b0 .shape
650- db_t1t2b1b0 = db_t1t2b1b0 .to (device )
651+ db_t1t2b1b0 = db_t1t2b1b0 .view ( - 1 , db_t1t2b1b0 . shape [ - 1 ]). to (device = device , dtype = data . dtype )
651652
652653 if data .ndim < 5 :
653654 msg = f"Assume 4D input data, added singular channel dim for processing"
@@ -662,20 +663,27 @@ def estimate_b1_from_db(
662663
663664 for idx_c in tqdm .trange (data .shape [- 2 ], desc = "channel dim processing" ):
664665 for idx_z in range (data .shape [2 ]):
666+ # we process per B0 simulated value, otherwise the memory explodes
665667 for idx_x in range (num_batches ):
666668 start = idx_x * batch_size
667669 end = min ((idx_x + 1 ) * batch_size , data .shape [0 ])
668670 data_batch = data [start :end , :, idx_z , idx_c , :].to (device )
671+ d_shape = data_batch .shape
672+ data_batch = data_batch .view (- 1 , d_shape [- 1 ])
669673
670- l2 = torch .linalg .norm (data_batch [:, None ] - db_t1t2b1b0 .view (- 1 , db_shape [- 1 ])[None , :, None ], dim = - 1 )
671- vals , indices = torch .min (l2 , dim = 1 )
672- batch_t1t2b1b0_vals = t1t2b1b0_vals [indices ]
674+ l2 = torch .cdist (data_batch , db_t1t2b1b0 ).reshape (* d_shape [:- 1 ], db_t1t2b1b0 .shape [0 ])
675+ vals , indices = torch .min (l2 , dim = - 1 )
676+ indices = unflatten_batched_indices (indices , db_shape [:- 1 ])
677+ # l2 = torch.linalg.norm(data_batch[:, None] - db_t1t2b1b0.view(-1, db_shape[-1])[None, :, None], dim=-1)
678+ # vals, indices = torch.min(l2, dim=1)
673679
674- b1_alloc [start :end , :, idx_z , idx_c ] = batch_t1t2b1b0_vals [..., 2 ].cpu ()
680+ b1_alloc [start :end , :, idx_z , idx_c ] = b1_vals [indices [:, :, 2 ].cpu ()]
681+ b0_alloc [start :end , :, idx_z , idx_c ] = b0_vals [indices [:, :, 3 ].cpu ()]
675682
676683 logger .info ("B1 smoothing" )
677- b1_sm = smooth_map (b1_alloc , kernel_size = min (b1_alloc .shape [:2 ]) // 35 )
678- return b1_sm , b1_alloc
684+ b1_sm = smooth_map (b1_alloc , kernel_size = min (b1_alloc .shape [:2 ]) // 28 )
685+ b0_sm = smooth_map (b0_alloc , kernel_size = min (b0_alloc .shape [:2 ]) // 28 )
686+ return b1_sm , b1_alloc , b0_sm , b0_alloc
679687
680688
681689def regularised_fit (
@@ -865,15 +873,19 @@ def fit_mese(
865873 combined_data = root_sum_of_squares (data_xyzce , dim_channel = - 2 ).unsqueeze (- 2 )
866874 else :
867875 combined_data = data_xyzce .clone ()
868- b1_data , b1_est = estimate_b1_from_db (
869- data = combined_data , db_t1t2b1b0 = db_t1t2b1b0 , t1t2b1b0_vals = t1t2b1b0_vals , device = device ,
870- batch_size = 1
876+ b1_data , b1_est , b0_data , b0_est = estimate_b1_b0_from_db (
877+ data = combined_data , db_t1t2b1b0 = db_t1t2b1b0 , b1_vals = b1_vals , b0_vals = b0_vals ,
878+ device = device , batch_size = 1
871879 )
872- nifti_save (b1_est , img_aff = input_affine , path_to_dir = path_out , file_name = "b1_estimate" )
873- nifti_save (b1_data , img_aff = input_affine , path_to_dir = path_out , file_name = "b1_estimate_smoothed" )
880+ for i , d in enumerate ([b1_data , b1_est , b0_data , b0_est ]):
881+ nifti_save (
882+ d , img_aff = input_affine , path_to_dir = path_out ,
883+ file_name = ["b1_estimate_smoothed" , "b1_estimate" , "b0_estimate_smoothed" , "b0_estimate" ][i ]
884+ )
874885 if data_xyzce .shape [- 2 ] > 1 :
875886 # need to expand back to channels
876887 b1_data = b1_data .expand_as (data_xyzce [..., 0 ])
888+ b0_data = b0_data .expand_as (data_xyzce [..., 0 ])
877889
878890
879891 # now use this for input to to the regularised fit
0 commit comments