Skip to content

Commit 4cf5f4b

Browse files
committed
optimise matching, add b0 fit
1 parent 4a8da9c commit 4cf5f4b

1 file changed

Lines changed: 29 additions & 17 deletions

File tree

pymritools/modeling/dictionary/grid_search_channels.py

Lines changed: 29 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@ def smooth_map(data: torch.Tensor, kernel_size: int = 5):
2020
# set kernel
2121
kernel = torch.zeros(data.shape[:2], dtype=data.dtype, device=data.device)
2222
kernel[
23-
(data.shape[0] - kernel_size) // 2:(data.shape[0] + kernel_size) // 2,
24-
(data.shape[1] - kernel_size) // 2:(data.shape[1] + kernel_size) // 2
23+
(data.shape[0] - kernel_size) // 2:(data.shape[0] + kernel_size) // 2,
24+
(data.shape[1] - kernel_size) // 2:(data.shape[1] + kernel_size) // 2
2525
] = 1
2626
while kernel.shape.__len__() < data.shape.__len__():
2727
kernel = kernel.unsqueeze(-1)
@@ -611,7 +611,7 @@ def unflatten_batched_indices(flat_indices, shape):
611611
Convert batched flattened indices back to individual dimension indices.
612612
613613
Args:
614-
flat_indices (torch.Tensor): Batched flattened indices with shape [b1, b2, flat_inds]
614+
flat_indices (torch.Tensor): Batched flattened indices with shape [b1, b2, ..., flat_inds]
615615
shape (tuple): Original tensor shape to unflatten into
616616
617617
Returns:
@@ -640,14 +640,15 @@ def unflatten_batched_indices(flat_indices, shape):
640640
return out_indices
641641

642642

643-
def estimate_b1_from_db(
643+
def estimate_b1_b0_from_db(
644644
data: torch.Tensor, db_t1t2b1b0: torch.Tensor, device: torch.device,
645-
t1t2b1b0_vals: torch.Tensor,batch_size: int = 10):
645+
b1_vals: torch.Tensor, b0_vals: torch.Tensor, batch_size: int = 10):
646646
logger.info(f"l2 fit - b1 estimate")
647647
num_batches = int(np.ceil(data.shape[0] / batch_size))
648648
b1_alloc = torch.zeros(data.shape[:-1])
649+
b0_alloc = torch.zeros(data.shape[:-1])
649650
db_shape = db_t1t2b1b0.shape
650-
db_t1t2b1b0 = db_t1t2b1b0.to(device)
651+
db_t1t2b1b0 = db_t1t2b1b0.view(-1,db_t1t2b1b0.shape[-1]).to(device=device, dtype=data.dtype)
651652

652653
if data.ndim < 5:
653654
msg = f"Assume 4D input data, added singular channel dim for processing"
@@ -662,20 +663,27 @@ def estimate_b1_from_db(
662663

663664
for idx_c in tqdm.trange(data.shape[-2], desc="channel dim processing"):
664665
for idx_z in range(data.shape[2]):
666+
# we process per B0 simulated value, otherwise the memory explodes
665667
for idx_x in range(num_batches):
666668
start = idx_x * batch_size
667669
end = min((idx_x + 1) * batch_size, data.shape[0])
668670
data_batch = data[start:end, :, idx_z, idx_c, :].to(device)
671+
d_shape = data_batch.shape
672+
data_batch = data_batch.view(-1, d_shape[-1])
669673

670-
l2 = torch.linalg.norm(data_batch[:, None] - db_t1t2b1b0.view(-1, db_shape[-1])[None, :, None], dim=-1)
671-
vals, indices = torch.min(l2, dim=1)
672-
batch_t1t2b1b0_vals = t1t2b1b0_vals[indices]
674+
l2 = torch.cdist(data_batch, db_t1t2b1b0).reshape(*d_shape[:-1], db_t1t2b1b0.shape[0])
675+
vals, indices = torch.min(l2, dim=-1)
676+
indices = unflatten_batched_indices(indices, db_shape[:-1])
677+
# l2 = torch.linalg.norm(data_batch[:, None] - db_t1t2b1b0.view(-1, db_shape[-1])[None, :, None], dim=-1)
678+
# vals, indices = torch.min(l2, dim=1)
673679

674-
b1_alloc[start:end, :, idx_z, idx_c] = batch_t1t2b1b0_vals[..., 2].cpu()
680+
b1_alloc[start:end, :, idx_z, idx_c] = b1_vals[indices[:, :, 2].cpu()]
681+
b0_alloc[start:end, :, idx_z, idx_c] = b0_vals[indices[:, :, 3].cpu()]
675682

676683
logger.info("B1 smoothing")
677-
b1_sm = smooth_map(b1_alloc, kernel_size=min(b1_alloc.shape[:2]) // 35)
678-
return b1_sm, b1_alloc
684+
b1_sm = smooth_map(b1_alloc, kernel_size=min(b1_alloc.shape[:2]) // 28)
685+
b0_sm = smooth_map(b0_alloc, kernel_size=min(b0_alloc.shape[:2]) // 28)
686+
return b1_sm, b1_alloc, b0_sm, b0_alloc
679687

680688

681689
def regularised_fit(
@@ -865,15 +873,19 @@ def fit_mese(
865873
combined_data = root_sum_of_squares(data_xyzce, dim_channel=-2).unsqueeze(-2)
866874
else:
867875
combined_data = data_xyzce.clone()
868-
b1_data, b1_est = estimate_b1_from_db(
869-
data=combined_data, db_t1t2b1b0=db_t1t2b1b0, t1t2b1b0_vals=t1t2b1b0_vals, device=device,
870-
batch_size=1
876+
b1_data, b1_est, b0_data, b0_est = estimate_b1_b0_from_db(
877+
data=combined_data, db_t1t2b1b0=db_t1t2b1b0, b1_vals=b1_vals, b0_vals=b0_vals,
878+
device=device, batch_size=1
871879
)
872-
nifti_save(b1_est, img_aff=input_affine, path_to_dir=path_out, file_name="b1_estimate")
873-
nifti_save(b1_data, img_aff=input_affine, path_to_dir=path_out, file_name="b1_estimate_smoothed")
880+
for i, d in enumerate([b1_data, b1_est, b0_data, b0_est]):
881+
nifti_save(
882+
d, img_aff=input_affine, path_to_dir=path_out,
883+
file_name=["b1_estimate_smoothed", "b1_estimate", "b0_estimate_smoothed", "b0_estimate"][i]
884+
)
874885
if data_xyzce.shape[-2] > 1:
875886
# need to expand back to channels
876887
b1_data = b1_data.expand_as(data_xyzce[..., 0])
888+
b0_data = b0_data.expand_as(data_xyzce[..., 0])
877889

878890

879891
# now use this for input to to the regularised fit

0 commit comments

Comments (0)