diff --git a/.codespellrc b/.codespellrc index 7035fe17..fdf6f1d6 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] skip = [setup.cfg] -ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus \ No newline at end of file +ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo \ No newline at end of file diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 9a62f20a..78cd6b0e 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -4,6 +4,8 @@ from math import radians import numpy as np +from dipy.align import vector_fields as vfu +from dipy.align.imwarp import DiffeomorphicMap, mult_aff from dipy.data import default_sphere from dipy.reconst.gqi import squared_radial_component from dipy.tracking.streamline import set_number_of_points @@ -15,6 +17,75 @@ logger = logging.getLogger("AFQ") +def get_simplified_transform(self): + """Constructs a simplified version of this Diffeomorhic Map + + The simplified version incorporates the pre-align transform, as well as + the domain and codomain affine transforms into the displacement field. + The resulting transformation may be regarded as operating on the + image spaces given by the domain and codomain discretization. As a + result, self.prealign, self.disp_grid2world, self.domain_grid2world and + self.codomain affine will be None (denoting Identity) in the resulting + diffeomorphic map. 
+ """ + if self.dim == 2: + simplify_f = vfu.simplify_warp_function_2d + else: + simplify_f = vfu.simplify_warp_function_3d + # Simplify the forward transform + D = self.domain_grid2world + P = self.prealign + Rinv = self.disp_world2grid + Cinv = self.codomain_world2grid + + # this is the matrix which we need to multiply the voxel coordinates + # to interpolate on the forward displacement field ("in"side the + # 'forward' brackets in the expression above) + affine_idx_in = mult_aff(Rinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the voxel coordinates + # to add to the displacement ("out"side the 'forward' brackets in the + # expression above) + affine_idx_out = mult_aff(Cinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the displacement vector + # prior to adding to the transformed input point + affine_disp = Cinv + + new_forward = simplify_f( + self.forward, affine_idx_in, affine_idx_out, affine_disp, self.domain_shape + ) + + # Simplify the backward transform + C = self.codomain_grid2world + Pinv = self.prealign_inv + Dinv = self.domain_world2grid + + affine_idx_in = mult_aff(Rinv, C) + affine_idx_out = mult_aff(Dinv, mult_aff(Pinv, C)) + affine_disp = mult_aff(Dinv, Pinv) + new_backward = simplify_f( + self.backward, + affine_idx_in, + affine_idx_out, + affine_disp, + self.codomain_shape, + ) + simplified = DiffeomorphicMap( + dim=self.dim, + disp_shape=self.disp_shape, + disp_grid2world=None, + domain_shape=self.domain_shape, + domain_grid2world=None, + codomain_shape=self.codomain_shape, + codomain_grid2world=None, + prealign=None, + ) + simplified.forward = new_forward + simplified.backward = new_backward + return simplified + + def gwi_odf(gqmodel, data): gqi_vector = np.real( squared_radial_component( @@ -152,7 +223,9 @@ def tensor_odf(evals, evecs, sphere, num_batches=100): return odf -def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean): +def gaussian_weights( + bundle, 
assignment_idxs=None, n_points=100, return_mahalnobis=False, stat=np.mean +): """ Calculate weights for each streamline/node in a bundle, based on a Mahalanobis distance from the core the bundle, at that node (mean, per @@ -162,6 +235,8 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean ---------- bundle : Streamlines The streamlines to weight. + assignment_idxs : array of shape (n_streamlines, n_points), optional + BUAN assignments, optional. n_points : int or None, optional The number of points to resample to. If this is None, we assume bundle is already resampled, and do not do any resampling. Default: 100. @@ -209,41 +284,51 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean return weights / np.sum(weights, 0) else: weights = np.zeros((n_sls, n_nodes)) - diff = stat(sls, axis=0) - sls - for i in range(n_nodes): - # This should come back as a 3D covariance matrix with the spatial - # variance covariance of this node across the different streamlines, - # converted to a positive semi-definite matrix if necessary - cov = np.cov(sls[:, i, :].T, ddof=0) + + if assignment_idxs is None: + working_groups = np.tile(np.arange(n_nodes), (n_sls, 1)) + else: + working_groups = np.asarray(assignment_idxs) + + flat_coords = sls.reshape(-1, 3) + flat_groups = working_groups.reshape(-1) + unique_ids = np.unique(flat_groups) + + for gid in unique_ids: + mask = flat_groups == gid + group_data = flat_coords[mask] + + if len(group_data) < 15: + continue + + mu = stat(group_data, axis=0) + diff = group_data - mu + + cov = np.cov(group_data.T, ddof=0) + + # Ensure positive semi-definite if np.any(np.linalg.eigvals(cov) < 0): eigenvalues, eigenvectors = np.linalg.eigh((cov + cov.T) / 2) eigenvalues[eigenvalues < 0] = 0 cov = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T - # calculate Mahalanobis for node in every fiber if np.any(cov > 0): - weights[:, i] = np.sqrt( - np.einsum("ij,jk,ik->i", diff[:, i, :], 
pinvh(cov), diff[:, i, :]) + weights.ravel()[mask] = np.sqrt( + np.einsum("ij,jk,ik->i", diff, pinvh(cov), diff) ) - # In the special case where all the streamlines have the exact same - # coordinate in this node, the covariance matrix is all zeros, so - # we can't calculate the Mahalanobis distance, we will instead give - # each streamline an identical weight, equal to the number of - # streamlines: - else: - weights[:, i] = 0 if return_mahalnobis: return weights - # weighting is inverse to the distance (the further you are, the less you - # should be weighted) - weights = 1 / weights - # Normalize before returning, so that the weights in each node sum to 1: - return weights / np.sum(weights, 0) + with np.errstate(divide="ignore"): + w_inv = 1.0 / weights + w_inv[np.isinf(w_inv)] = 0 + denom = np.sum(w_inv, axis=0) + return np.divide(w_inv, denom, out=np.zeros_like(w_inv), where=denom != 0) -def make_gif(show_m, out_path, n_frames=36, az_ang=-10): + +def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150): """ Make a video from a Fury Show Manager. @@ -263,6 +348,10 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10): The angle to rotate the camera around the z-axis for each frame, in degrees. Default: -10 + + duration : int + The duration of each frame in the output GIF, in milliseconds. 
+ Default: 150 """ video = [] @@ -270,15 +359,49 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10): show_m.window.draw() with tempfile.TemporaryDirectory() as tmp_dir: - for ii in tqdm(range(n_frames), desc="Generating GIF"): + for ii in tqdm(range(n_frames), desc="Generating GIF", leave=False): frame_fname = f"{tmp_dir}/{ii}.png" show_m.screens[0].controller.rotate((radians(az_ang), 0), None) show_m.render() show_m.window.draw() show_m.snapshot(frame_fname) - video.append(frame_fname) - - video = [Image.open(frame) for frame in video] - video[0].save( - out_path, save_all=True, append_images=video[1:], duration=300, loop=1 + video.append(Image.open(frame_fname).convert("RGB")) + + all_left, all_upper = float("inf"), float("inf") + all_right, all_lower = 0, 0 + + for img in video: + arr = np.array(img) + bg_color = arr[0, 0] + + mask = np.any(arr != bg_color, axis=-1) + + if np.any(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + + all_left = min(all_left, xmin) + all_upper = min(all_upper, ymin) + all_right = max(all_right, xmax) + all_lower = max(all_lower, ymax) + + if all_left < all_right: + crop_box = ( + max(0, all_left), + max(0, all_upper), + min(video[0].width, all_right), + min(video[0].height, all_lower), + ) + cropped_video = [img.crop(crop_box) for img in video] + else: + cropped_video = video + + cropped_video[0].save( + out_path, + save_all=True, + append_images=cropped_video[1:], + duration=duration, + loop=1, ) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 22c8e96e..2fbe5f42 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -10,7 +10,8 @@ from AFQ.definitions.utils import find_file from AFQ.tasks.utils import get_fname, str_to_desc -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) __all__ = [ @@ -118,13 +119,45 @@ def append_l_r(bundle_list, 
no_lr_list): DIPY_GH = "https://github.com/dipy/dipy/blob/master/dipy/" +def OR_bd(): + or_rois = afd.read_or_templates() + + return BundleDict( + { + "Left Optic Radiation": { + "include": [or_rois["left_OR_1"], or_rois["left_OR_2"]], + "exclude": [ + or_rois["left_OP_MNI"], + or_rois["left_TP_MNI"], + or_rois["left_pos_thal_MNI"], + ], + "start": or_rois["left_thal_MNI"], + "end": or_rois["left_V1_MNI"], + "cross_midline": False, + }, + "Right Optic Radiation": { + "include": [or_rois["right_OR_1"], or_rois["right_OR_2"]], + "exclude": [ + or_rois["right_OP_MNI"], + or_rois["right_TP_MNI"], + or_rois["right_pos_thal_MNI"], + ], + "start": or_rois["right_thal_MNI"], + "end": or_rois["right_V1_MNI"], + "cross_midline": False, + }, + }, + citations={"Caffarra2021"}, + ) + + def default_bd(): templates = afd.read_templates(as_img=False) templates["ARC_roi1_L"] = templates["SLF_roi1_L"] templates["ARC_roi1_R"] = templates["SLF_roi1_R"] templates["ARC_roi2_L"] = templates["SLFt_roi2_L"] templates["ARC_roi2_R"] = templates["SLFt_roi2_R"] - return BundleDict( + return OR_bd() + BundleDict( { "Left Anterior Thalamic": { "cross_midline": False, @@ -134,6 +167,7 @@ def default_bd(): "prob_map": templates["ATR_L_prob_map"], "start": templates["ATR_L_start"], "end": templates["ATR_L_end"], + "length": {"min_len": 30}, }, "Right Anterior Thalamic": { "cross_midline": False, @@ -143,6 +177,7 @@ def default_bd(): "prob_map": templates["ATR_R_prob_map"], "start": templates["ATR_R_start"], "end": templates["ATR_R_end"], + "length": {"min_len": 30}, }, "Left Cingulum Cingulate": { "cross_midline": False, @@ -151,6 +186,7 @@ def default_bd(): "space": "template", "prob_map": templates["CGC_L_prob_map"], "end": templates["CGC_L_start"], + "length": {"min_len": 30}, }, "Right Cingulum Cingulate": { "cross_midline": False, @@ -159,6 +195,7 @@ def default_bd(): "space": "template", "prob_map": templates["CGC_R_prob_map"], "end": templates["CGC_R_start"], + "length": {"min_len": 30}, }, 
"Left Corticospinal": { "cross_midline": False, @@ -167,6 +204,7 @@ def default_bd(): "space": "template", "prob_map": templates["CST_L_prob_map"], "end": templates["CST_L_start"], + "length": {"min_len": 40}, }, "Right Corticospinal": { "cross_midline": False, @@ -175,24 +213,27 @@ def default_bd(): "space": "template", "prob_map": templates["CST_R_prob_map"], "end": templates["CST_R_start"], + "length": {"min_len": 40}, }, "Left Inferior Fronto-occipital": { "cross_midline": False, "include": [templates["IFO_roi2_L"], templates["IFO_roi1_L"]], - "exclude": [templates["ARC_roi1_L"]], + "exclude": [templates["ARC_roi1_L"], templates["CGC_roi1_L"]], "space": "template", "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], + "length": {"min_len": 80}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, "include": [templates["IFO_roi2_R"], templates["IFO_roi1_R"]], - "exclude": [templates["ARC_roi1_R"]], + "exclude": [templates["ARC_roi1_R"], templates["CGC_roi1_R"]], "space": "template", "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], + "length": {"min_len": 80}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -202,6 +243,7 @@ def default_bd(): "prob_map": templates["ILF_L_prob_map"], "start": templates["ILF_L_end"], "end": templates["ILF_L_start"], + "length": {"min_len": 40}, }, "Right Inferior Longitudinal": { "cross_midline": False, @@ -211,24 +253,27 @@ def default_bd(): "prob_map": templates["ILF_R_prob_map"], "start": templates["ILF_R_end"], "end": templates["ILF_R_start"], + "length": {"min_len": 40}, }, "Left Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_L"], templates["SLFt_roi2_L"]], - "exclude": [], + "exclude": [templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], + "length": {"min_len": 40}, }, 
"Right Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_R"], templates["SLFt_roi2_R"]], - "exclude": [], + "exclude": [templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], + "length": {"min_len": 40}, }, "Left Uncinate": { "cross_midline": False, @@ -251,122 +296,253 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"]], + "exclude": [ + templates["SLF_roi1_L"], + templates["IFO_roi1_L"], + templates["pARC_xroi1_L"], + ], "space": "template", "start": templates["pARC_L_start"], + "end": templates["pARC_L_end"], "Left Arcuate": {"overlap": 30}, + "Left Inferior Fronto-occipital": {"core": "Right"}, + "Left Optic Radiation": {"core": "Right"}, + "length": {"min_len": 30}, "primary_axis": "I/S", - "primary_axis_percentage": 40, }, "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"]], + "exclude": [ + templates["SLF_roi1_R"], + templates["IFO_roi1_R"], + templates["pARC_xroi1_R"], + ], "space": "template", "start": templates["pARC_R_start"], + "end": templates["pARC_R_end"], "Right Arcuate": {"overlap": 30}, + "Right Inferior Fronto-occipital": {"core": "Left"}, + "Right Optic Radiation": {"core": "Left"}, + "length": {"min_len": 30}, "primary_axis": "I/S", - "primary_axis_percentage": 40, }, "Left Vertical Occipital": { "cross_midline": False, "space": "template", + "prob_map": templates["VOF_L_prob_map"], "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 20}, + "include": [templates["VOF_roi1_L"], templates["VOF_roi2_L"]], + "exclude": [ + templates["Cerebellar_Hemi_L"], + ], + "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 20, - "entire_core": "Anterior", + "project": "L/R", + "core": "Anterior", }, - "Left 
Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, - "length": {"min_len": 25}, - "isolation_forest": {}, + "Left Optic Radiation": {"core": "Right"}, + "length": {"min_len": 30}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Left V1V3": { + "cluster_IDs": [61, 63, 77, 82], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 2, + "clean_rounds": 2, + "remove_lengths": "short", + }, + }, + "Left Posterior Vertical Occipital": { + "cluster_IDs": [1, 72, 81, 83], + "Left Inferior Fronto-occipital": {"core": "Right"}, + "orient_mahal": { + "distance_threshold": 4, + "length_threshold": 2, + "clean_rounds": 10, + "remove_lengths": "short", + }, + }, + "Left Anterior Vertical Occipital": { + "cluster_IDs": [2, 7, 18, 21, 25, 51], + "Left Inferior Fronto-occipital": {"core": "Right"}, + "exclude": [ + templates["pARC_xroi1_L"], + templates["pVOF_xroi1_L"], + ], + "orient_mahal": { + "distance_threshold": 4, + "length_threshold": 2, + "clean_rounds": 10, + "remove_lengths": "short", + }, + }, + }, + ), }, "Right Vertical Occipital": { "cross_midline": False, "space": "template", + "prob_map": templates["VOF_R_prob_map"], "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 20}, + "include": [templates["VOF_roi1_R"], templates["VOF_roi2_R"]], + "exclude": [ + templates["Cerebellar_Hemi_R"], + ], + "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 20, - "entire_core": "Anterior", + "project": "L/R", + "core": "Anterior", }, - "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, - "length": {"min_len": 25}, - "isolation_forest": {}, + "Right Optic Radiation": {"core": "Left"}, + "length": {"min_len": 30}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - 
"primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Right V1V3": { + "cluster_IDs": [61, 63, 77, 82], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 2, + "clean_rounds": 2, + "remove_lengths": "short", + }, + }, + "Right Posterior Vertical Occipital": { + "cluster_IDs": [1, 72, 81, 83], + "Right Inferior Fronto-occipital": {"core": "Left"}, + "orient_mahal": { + "distance_threshold": 4, + "length_threshold": 2, + "clean_rounds": 10, + "remove_lengths": "short", + }, + }, + "Right Anterior Vertical Occipital": { + "cluster_IDs": [2, 7, 18, 21, 25, 51], + "Right Inferior Fronto-occipital": {"core": "Left"}, + "exclude": [ + templates["pARC_xroi1_R"], + templates["pVOF_xroi1_R"], + ], + "orient_mahal": { + "distance_threshold": 4, + "length_threshold": 2, + "clean_rounds": 10, + "remove_lengths": "short", + }, + }, + }, + ), + }, + }, + citations={ + "Yeatman2012", + "takemura2016major", + "Tzourio-Mazoyer2002", + "zhang2018anatomically", + "Hua2008", + }, + ) + + +def mdlf_bd(): + """ + Work in Progress. 
+ """ + templates = afd.read_templates(as_img=False) + return default_bd() + BundleDict( + { + "Left Middle Longitudinal": { + "cross_midline": False, + "start": templates["Temporal_Sup_L"], + "end": templates["MdLF_L_end"], + "exclude": [templates["SLF_roi1_L"]], + "space": "template", + "Left Inferior Longitudinal": {"node_thresh": 20}, + "length": {"min_len": 50}, }, + "Right Middle Longitudinal": { + "cross_midline": False, + "start": templates["Temporal_Sup_R"], + "end": templates["MdLF_R_end"], + "exclude": [templates["SLF_roi1_R"]], + "space": "template", + "Right Inferior Longitudinal": {"node_thresh": 20}, + "length": {"min_len": 50}, + }, + }, + citations={ + "wang2013rethinking", }, - citations={"Yeatman2012", "takemura2017occipital"}, ) def slf_bd(): templates = afd.read_slf_templates(as_img=False) + templates_afq = afd.read_templates(as_img=False) + templates["Frontal_Lobe_L"] = templates_afq["ATR_L_start"] + templates["Frontal_Lobe_R"] = templates_afq["ATR_R_start"] return BundleDict( { "Left Superior Longitudinal I": { "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + "node_thresh": 20, }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + "node_thresh": 20, }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + 
"node_thresh": 20, }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, }, @@ -894,38 +1070,6 @@ def cerebellar_bd(): ) -def OR_bd(): - or_rois = afd.read_or_templates() - - return BundleDict( - { - "Left Optic Radiation": { - "include": [or_rois["left_OR_1"], or_rois["left_OR_2"]], - "exclude": [ - or_rois["left_OP_MNI"], - or_rois["left_TP_MNI"], - or_rois["left_pos_thal_MNI"], - ], - "start": or_rois["left_thal_MNI"], - "end": or_rois["left_V1_MNI"], - "cross_midline": False, - }, - "Right Optic Radiation": { - "include": [or_rois["right_OR_1"], or_rois["right_OR_2"]], - "exclude": [ - or_rois["right_OP_MNI"], - or_rois["right_TP_MNI"], - or_rois["right_pos_thal_MNI"], - ], - "start": or_rois["right_thal_MNI"], - "end": or_rois["right_V1_MNI"], - "cross_midline": False, - }, - }, - citations={"Caffarra2021"}, - ) - - class _BundleEntry(Mapping): """Describes how to recognize a single bundle, immutable""" @@ -1052,7 +1196,6 @@ def __init__( self.resample_to = resample_to self.resample_subject_to = resample_subject_to 
self.keep_in_memory = keep_in_memory - self.max_includes = 3 self.citations = citations if self.citations is None: self.citations = set() @@ -1106,12 +1249,8 @@ def __init__( def __print__(self): print(self._dict) - def update_max_includes(self, new_max): - if new_max > self.max_includes: - self.max_includes = new_max - def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): - if isinstance(roi_or_sl, dict): + if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") roi_or_sl = find_file( bids_layout, bids_path, roi_or_sl, suffix, session, subject @@ -1124,6 +1263,41 @@ def _cond_load(self, roi_or_sl, resample_to): """ Load ROI or streamline if not already loaded """ + if isinstance(roi_or_sl, dict): + space = roi_or_sl.get("space", None) + roi_or_sl = roi_or_sl.get("roi", None) + if roi_or_sl is None or space is None: + raise ValueError( + ( + f"Unclear ROI definition for {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." + ) + ) + if space == "template": + resample_to = self.resample_to + elif space == "subject": + resample_to = self.resample_subject_to + if resample_to is False: + raise ValueError( + ( + "When using mixed ROI bundle definitions, " + "and subject space ROIs, " + "resample_subject_to cannot be False." + ) + ) + else: + raise ValueError( + ( + f"Unknown space {space} for ROI definition {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." 
+ ) + ) + + logger.debug(f"Loading ROI or streamlines: {roi_or_sl}") + logger.debug(f"Loading ROI or streamlines from space: {resample_to}") + if isinstance(roi_or_sl, str): if ".nii" in roi_or_sl: return afd.read_resample_roi(roi_or_sl, resample_to=resample_to) @@ -1139,6 +1313,31 @@ def _cond_load(self, roi_or_sl, resample_to): def get_b_info(self, b_name): return self._dict[b_name] + def relax_cleaning(self, delta_distance=1, delta_length=1): + """ + This can be useful for PTT + """ + cleaner_keys = ["mahal", "isolation_forest", "orient_mahal"] + + for b_name in self.bundle_names: + bundle_data = self._dict[b_name] + + for key in cleaner_keys: + if key in bundle_data: + target = bundle_data[key] + if ( + "distance_threshold" in target + and target["distance_threshold"] != 0 + ): + target["distance_threshold"] += delta_distance + if "length_threshold" in target and target["length_threshold"] != 0: + target["length_threshold"] += delta_length + + if "ORG_spectral_subbundles" in bundle_data: + bundle_data["ORG_spectral_subbundles"].relax_cleaning( + delta_distance, delta_length + ) + def __getitem__(self, key): if isinstance(key, tuple) or isinstance(key, list): # Generates a copy of this BundleDict with only the bundle names @@ -1177,8 +1376,6 @@ def __getitem__(self, key): def __setitem__(self, key, item): self._dict[key] = item - if hasattr(item, "get"): - self.update_max_includes(len(item.get("include", []))) if key not in self.bundle_names: self.bundle_names.append(key) @@ -1261,24 +1458,31 @@ def is_bundle_in_template(self, bundle_name): return ( "space" not in self._dict[bundle_name] or self._dict[bundle_name]["space"] == "template" + or self._dict[bundle_name]["space"] == "mixed" ) - def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): + def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): roi_or_sl = self._cond_load(roi_or_sl, self.resample_to) if isinstance(roi_or_sl, nib.Nifti1Image): + if ( + 
np.allclose(roi_or_sl.affine, new_img.affine) + and roi_or_sl.shape == new_img.shape[:3] + ): + # This is the case of a mixed bundle definition, where + # some ROIs need transformed and others do not + return roi_or_sl + fdata = roi_or_sl.get_fdata() if len(np.unique(fdata)) <= 2: boolean_ = True else: boolean_ = False - warped_img = auv.transform_inverse_roi( - fdata, mapping, bundle_name=bundle_name - ) + warped_img = auv.transform_roi(fdata, mapping, bundle_name=bundle_name) if boolean_: warped_img = warped_img.astype(np.uint8) - warped_img = nib.Nifti1Image(warped_img, new_affine) + warped_img = nib.Nifti1Image(warped_img, new_img.affine) return warped_img else: return roi_or_sl @@ -1287,7 +1491,7 @@ def transform_rois( self, bundle_name, mapping, - new_affine, + new_img, base_fname=None, to_space="subject", apply_to_recobundles=False, @@ -1306,8 +1510,8 @@ def transform_rois( Name of the bundle to be transformed. mapping : DiffeomorphicMap object A mapping between DWI space and a template. - new_affine : array - Affine of space transformed into. + new_img : Nifti1Image + Image of space transformed into. base_fname : str, optional Base file path to construct file path from. Additional BIDS descriptors will be added to this file path. If None, @@ -1333,7 +1537,7 @@ def transform_rois( bundle_name, self._roi_transform_helper, mapping, - new_affine, + new_img, bundle_name, dry_run=True, apply_to_recobundles=apply_to_recobundles, @@ -1382,7 +1586,9 @@ def transform_rois( def __add__(self, other): for resample in ["resample_to", "resample_subject_to"]: - if ( + if getattr(self, resample) == getattr(other, resample): + pass + elif ( not getattr(self, resample) or not getattr(other, resample) or getattr(self, resample) is None @@ -1429,6 +1635,41 @@ def __add__(self, other): ) +class SpectralSubbundleDict(BundleDict): + """ + A BundleDict where each bundle is defined as a spectral subbundle of a + larger bundle. 
See `Defining Custom Bundle Dictionaries` in the documentation + for details. + """ + + def __init__( + self, + bundle_info, + resample_to=None, + resample_subject_to=False, + keep_in_memory=False, + citations=None, + remove_cluster_IDs=None, + ): + super().__init__( + bundle_info, resample_to, resample_subject_to, keep_in_memory, citations + ) + if remove_cluster_IDs is None: + remove_cluster_IDs = [] + self.remove_cluster_IDs = remove_cluster_IDs + self.cluster_IDs = [] + for b_name, b_info in bundle_info.items(): + if "cluster_IDs" not in b_info: + raise ValueError( + ( + f"Bundle {b_name} does not have cluster_IDs. " + "All bundles in a SpectralSubbundleDict must have cluster_IDs." + ) + ) + self.cluster_IDs.extend(b_info["cluster_IDs"]) + self.all_cluster_IDs = self.remove_cluster_IDs + self.cluster_IDs + + def apply_to_roi_dict( dict_, func, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 2c225a9d..ea517031 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -8,7 +8,6 @@ import warnings from time import time -import dipy.tracking.streamline as dts import dipy.tracking.streamlinespeed as dps import nibabel as nib import numpy as np @@ -49,9 +48,11 @@ __all__ = ["GroupAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) + warnings.simplefilter(action="ignore", category=FutureWarning) @@ -553,9 +554,12 @@ def load_next_subject(): this_bundles_file = self.export("bundles", collapse=False)[sub][ses] this_mapping = self.export("mapping", collapse=False)[sub][ses] this_img = self.export("dwi", collapse=False)[sub][ses] + this_reg_template = self.export("reg_template", collapse=False)[sub][ + ses + ] seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img) seg_sft.sft.to_rasmm() - subses_info.append((seg_sft, this_mapping)) + subses_info.append((seg_sft, this_mapping, this_img, this_reg_template)) bundle_dict = self.export("bundle_dict", collapse=False)[ self.valid_sub_list[0] @@ -565,7 
+569,7 @@ def load_next_subject(): load_next_subject() # load first subject for b in bundle_dict.bundle_names: for i in range(len(self.valid_sub_list)): - seg_sft, mapping = subses_info[i] + seg_sft, mapping, img, reg_template = subses_info[i] idx = seg_sft.bundle_idxs[b] # use the first subses that works # otherwise try each successive subses @@ -581,14 +585,11 @@ def load_next_subject(): idx = np.random.choice(idx, size=100, replace=False) these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) - tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - delta = dts.values_from_volume( - mapping.forward, tg.streamlines, np.eye(4) - ) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)] + tg = StatefulTractogram(these_sls, img, Space.RASMM) + moved_sl = aus.move_streamlines( + tg, "template", mapping, reg_template ) - moved_sl = np.asarray(moved_sl) + moved_sl = np.asarray(moved_sl.streamlines) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} for ii, sl_idx in enumerate(idx): @@ -786,7 +787,7 @@ def cmd_outputs( clobber = cmd_outputs # alias for default of cmd_outputs - def make_all_participant_montages(self, images_per_row=2): + def make_all_participant_montages(self, images_per_row=3): """ Generate montage of all bundles for a all subjects. @@ -794,7 +795,7 @@ def make_all_participant_montages(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. 
- Default: 2 + Default: 3 Returns ------- @@ -1019,15 +1020,17 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - seg_sft.sft.to_vox() - sls = seg_sft.get_bundle(bundle_name).streamlines + sls = seg_sft.get_bundle(bundle_name) mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - delta = dts.values_from_volume(mapping.forward, sls, np.eye(4)) - sls_mni.extend([d + s for d, s in zip(delta, sls)]) + sls_mni.extend( + aus.move_streamlines( + sls, "template", mapping, reg_template + ).streamlines + ) - moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) + moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) save_path = op.abspath( op.join(self.afq_path, f"bundle-{bundle_name}_subjects-all_MNI.trk") diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 827b38b5..dc0fb525 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -2,12 +2,15 @@ import math import os.path as op import tempfile +from math import radians from time import time import nibabel as nib +from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm +import AFQ.utils.streamlines as aus from AFQ.api.utils import ( AFQclass_doc, check_attribute, @@ -33,6 +36,11 @@ __all__ = ["ParticipantAFQ"] +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + + class ParticipantAFQ(object): f"""{AFQclass_doc}""" @@ -157,9 +165,6 @@ def make_workflow(self): plan_kwargs[key] = self.kwargs[key] elif key in previous_plans: plan_kwargs[key] = previous_plans[key] - elif name not in ["data", "structural"] and key == "dwi_affine": - # simplifies syntax to access commonly used dwi_affine - plan_kwargs[key] = previous_plans["data_imap"][key] else: raise NotImplementedError( f"Missing required parameter {key} for {name} plan" @@ -252,7 +257,7 @@ 
def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=2): + def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None): """ Generate montage of all bundles for a given subject. @@ -260,7 +265,15 @@ def participant_montage(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. - Default: 2 + Default: 3 + + anatomy : bool + Whether to include anatomical images in the montage. + Default: True + + bundle_names : list of str or None + List of bundle names to include in the montage. + Default: None (includes all bundles) Returns ------- @@ -269,62 +282,78 @@ def participant_montage(self, images_per_row=2): tdir = tempfile.gettempdir() all_fnames = [] - bundle_dict = self.export("bundle_dict") + seg_sft = aus.SegmentedSFT.fromfile(self.export("bundles")) + if bundle_names is None: + bundle_names = list(seg_sft.bundle_names) self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.export(self.export("best_scalar")) t1 = nib.load(self.export("t1_masked")) - size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) - for ii, bundle_name in enumerate(tqdm(bundle_dict)): + best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) + best_scalar = resample(best_scalar, t1) + size = (images_per_row, math.ceil(3 * len(bundle_names) / images_per_row)) + for ii, bundle_name in enumerate(tqdm(bundle_names)): flip_axes = [False, False, False] for i in range(3): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 - figure = viz_backend.visualize_volume( - t1, flip_axes=flip_axes, interact=False, inline=False - ) + if anatomy: + figure = viz_backend.visualize_volume( + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False + ) + else: + figure = None figure = 
viz_backend.visualize_bundles( - self.export("bundles"), - affine=t1.affine, - shade_by_volume=best_scalar, + seg_sft, + img=t1, + shade_by_volume=best_scalar.get_fdata(), color_by_direction=True, flip_axes=flip_axes, bundle=bundle_name, figure=figure, + n_points=40, interact=False, inline=False, ) - view, direc = BEST_BUNDLE_ORIENTATIONS.get(bundle_name, ("Axial", "Top")) - eye = get_eye(view, direc) - - this_fname = tdir + f"/t{ii}.png" - if "plotly" in viz_backend.backend: - figure.update_layout( - scene_camera=dict( - projection=dict(type="orthographic"), - up={"x": 0, "y": 0, "z": 1}, - eye=eye, - center=dict(x=0, y=0, z=0), - ), - showlegend=False, - ) - figure.write_image(this_fname, scale=4) - - # temporary fix for memory leak - import plotly.io as pio - - pio.kaleido.scope._shutdown_kaleido() - else: - from fury import window + for jj, view in enumerate(["Sagittal", "Coronal", "Axial"]): + direc = BEST_BUNDLE_ORIENTATIONS.get( + bundle_name, ("Left", "Front", "Top") + )[jj] + + eye = get_eye(view, direc) + + this_fname = tdir + f"/t{ii}_{view}.png" + if "plotly" in viz_backend.backend: + figure.update_layout( + scene_camera=dict( + projection=dict(type="orthographic"), + up={"x": 0, "y": 0, "z": 1}, + eye=eye, + center=dict(x=0, y=0, z=0), + ), + showlegend=False, + ) + figure.write_image(this_fname, scale=4) + # temporary fix for memory leak + import plotly.io as pio - from AFQ.viz.fury_backend import scene_rotate_forward + pio.kaleido.scope._shutdown_kaleido() + else: + from fury import window - show_m = window.ShowManager( - scene=figure, window_type="offscreen", size=(600, 600) - ) - scene_rotate_forward(show_m, figure) - show_m.snapshot(this_fname) + show_m = window.ShowManager( + scene=figure, window_type="offscreen", size=(600, 600) + ) + window.update_camera(show_m.screens[0].camera, None, figure) + if view == "Coronal": + show_m.screens[0].controller.rotate((0, radians(-90)), None) + elif view == "Axial": + 
show_m.screens[0].controller.rotate((radians(90), 0), None) + elif view == "Sagittal": + pass + show_m.render() + show_m.window.draw() + show_m.snapshot(this_fname) def _save_file(curr_img): save_path = op.abspath( @@ -336,45 +365,47 @@ def _save_file(curr_img): this_img_trimmed = {} max_height = 0 max_width = 0 - for ii, bundle_name in enumerate(bundle_dict): - this_img = Image.open(tdir + f"/t{ii}.png") - try: - this_img_trimmed[ii] = trim(this_img) - except IndexError: # this_img is a picture of nothing - this_img_trimmed[ii] = this_img - - text_sz = 70 - width, height = this_img_trimmed[ii].size - height = height + text_sz - result = Image.new( - this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) - ) - result.paste(this_img_trimmed[ii], (0, text_sz)) - this_img_trimmed[ii] = result - - draw = ImageDraw.Draw(this_img_trimmed[ii]) - draw.text( - (0, 0), - bundle_name, - (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), - ) + ii = 0 + for b_idx, bundle_name in enumerate(bundle_names): + for view in ["Axial", "Coronal", "Sagittal"]: + this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") + try: + this_img_trimmed[ii] = trim(this_img) + except IndexError: # this_img is a picture of nothing + this_img_trimmed[ii] = this_img + + text_sz = 40 + width, height = this_img_trimmed[ii].size + height = height + text_sz + result = Image.new( + this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) + ) + result.paste(this_img_trimmed[ii], (0, text_sz)) + this_img_trimmed[ii] = result + + draw = ImageDraw.Draw(this_img_trimmed[ii]) + draw.text( + (0, 0), + f"{bundle_name} - {view}", + (0, 0, 0), + font=ImageFont.load_default(text_sz), + ) - if this_img_trimmed[ii].size[0] > max_width: - max_width = this_img_trimmed[ii].size[0] - if this_img_trimmed[ii].size[1] > max_height: - max_height = this_img_trimmed[ii].size[1] + if this_img_trimmed[ii].size[0] > max_width: + max_width = this_img_trimmed[ii].size[0] + if this_img_trimmed[ii].size[1] > 
max_height: + max_height = this_img_trimmed[ii].size[1] + ii += 1 curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" ) - for ii in range(len(bundle_dict)): - x_pos = ii % size[0] - _ii = ii // size[0] + for jj in range(ii): + x_pos = jj % size[0] + _ii = jj // size[0] y_pos = _ii % size[1] - _ii = _ii // size[1] - this_img = this_img_trimmed[ii].resize((max_width, max_height)) + this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) _save_file(curr_img) diff --git a/AFQ/api/utils.py b/AFQ/api/utils.py index 0c015b49..4dd83930 100644 --- a/AFQ/api/utils.py +++ b/AFQ/api/utils.py @@ -172,7 +172,7 @@ def export_all_helper(api_afq_object, xforms, indiv, viz): api_afq_object.export("indiv_bundles") api_afq_object.export("rois") api_afq_object.export("sl_counts") - api_afq_object.export("median_bundle_lengths") + api_afq_object.export("bundle_lengths") api_afq_object.export("profiles") if viz: diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9400509e..90f71fcc 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -24,7 +24,9 @@ from dipy.segment.metric import AveragePointwiseEuclideanMetric from tqdm import tqdm +from AFQ._fixes import get_simplified_transform from AFQ.data.utils import aws_import_msg_error +from AFQ.registration import read_old_mapping from AFQ.utils.path import apply_cmd_to_afq_derivs, drop_extension # capture templateflow resource warning and log @@ -719,10 +721,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "UNC_R_prob_map.nii.gz", "ARC_L_prob_map.nii.gz", "ARC_R_prob_map.nii.gz", - "VOF_R_end.nii.gz", - "VOF_R_start.nii.gz", - "VOF_L_end.nii.gz", - "VOF_L_start.nii.gz", "pARC_R_start.nii.gz", "pARC_L_start.nii.gz", "ARC_R_end.nii.gz", @@ -759,6 +757,26 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_R_start.nii.gz", "ATR_L_end.nii.gz", "ATR_L_start.nii.gz", + "pARC_xroi1_L.nii.gz", + 
"pARC_xroi1_R.nii.gz", + "Cerebellar_Hemi_L.nii.gz", + "Cerebellar_Hemi_R.nii.gz", + "VOF_L_end.nii.gz", + "VOF_R_end.nii.gz", + "VOF_L_prob_map.nii.gz", + "VOF_R_prob_map.nii.gz", + "VOF_roi1_L.nii.gz", + "VOF_roi1_R.nii.gz", + "VOF_roi2_L.nii.gz", + "VOF_roi2_R.nii.gz", + "Temporal_Sup_L.nii.gz", + "Temporal_Sup_R.nii.gz", + "pARC_L_end.nii.gz", + "pARC_R_end.nii.gz", + "MdLF_L_end.nii.gz", + "MdLF_R_end.nii.gz", + "pVOF_xroi1_R.nii.gz", + "pVOF_xroi1_L.nii.gz", ] @@ -821,10 +839,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "11458229", "11458232", "11458235", - "40943957", - "40943960", - "40943966", - "40943969", "40943972", "40943975", "40943978", @@ -861,6 +875,26 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944074", "40944077", "40944080", + "61737616", + "61737619", + "61970155", + "61970158", + "62134578", + "62134581", + "62031442", + "62031445", + "62213933", + "62213936", + "62213939", + "62213942", + "62282968", + "62282971", + "62316400", + "62316403", + "62283226", + "62283229", + "62316817", + "62316820", ] @@ -924,10 +958,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "19590c712f1776da1fdba64d4eb7f1f6", "04d5af0feb2c1b5b52a87ccbbf148e4b", "53c277be990d00f7de04f2ea35e74d73", - "d37d815fd1bdaaf3a9d2dcfc3ccb1345", - "95ed3189d8ac152945e6be1eb24381a3", - "a9007e6f2d6ae13ef182f65057c06573", - "c6eb9ee33b7caf691749e266f89e8ec4", "a06b2e2e52c09a601f683dc39859a7f1", "bee876a34fdb03e69a418b791f90975a", "680749c9e4565bc02492019d57d8e7d7", @@ -964,6 +994,26 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ffc157e9f73a43eff23821f2cfca614a", "a8d308a93b26242c04b878c733cb252f", "1c0b570bb2d622718b01ee2c429a5d15", + "51c8a6b5fbb0834b03986093b9ee4fa3", + "7cf5800a4efa6bac7e70d84095bc259b", + "f65b3f9133820921d023517a68d4ea41", + "4476935f5aadfcdd633b9a23779625ef", + "11ba79ff1f9a01c6b064428323d01013", + "84df5abfefbed5e3e310f2db0b62fcea", + "db5bd2d1e810e366f5ef67a9cce205c2", + 
"6891cfc038ce7db21e0cc307ae2b1b37", + "ad5407fa6c058c9317a5ba51e5e188bf", + "5c52e20d74608da784ee874d23322385", + "76baf26294c8430afabcdc9a6d756b12", + "eff69ba30619bbf2ff500cd714f07894", + "088e850d38fed2d62fc768071d42a43d", + "a3d249535ca0452ce9f59c754e13695a", + "df8a7480c507e91976c5a82d5826d521", + "243ddae33bf84e7b24da4d1d9f90a121", + "827b21f9069cd0192ede3f7153f1ca80", + "e2912622d36db6723c48a0a1de887807", + "9ebb676cf213605063934981c7c8fe9e", + "cc30b09f5e4c084675380654fc691b95", ] fetch_templates = _make_reusable_fetcher( @@ -1089,6 +1139,174 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +org800_fnames = [ + "ORG_atlas_tracks_reoriented.trx", + "ORG800_atlas_centroids.npy", + "ORG800_atlas_e_val.npy", + "ORG800_atlas_e_vec_norm.npy", + "ORG800_atlas_e_vec.npy", + "ORG800_atlas_number_of_eigenvectors.npy", + "ORG800_atlas_row_sum_1.npy", + "ORG800_atlas_row_sum_matrix.npy", + "ORG800_atlas_sigma.npy", +] + + +org800_remote_fnames = [ + "61762231", + "61762267", + "61762270", + "61762273", + "61762276", + "61762279", + "61762282", + "61762285", + "61762288", +] + + +org800_md5_hashes = [ + "9022799a73359209080ea832b22ec09b", + "09bfa384f5c44801dfa382d31392a979", + "bab61eb26cb21035e38b5f68b5fdad3e", + "9325f4cb168624d4f275785b18c9f859", + "12d426c5a6fcfbe3b8146bc335bdac96", + "db74e055c47c5b6354c3cb7bbf165f2c", + "e7e51f53b30764f104b93f50d94b6c3c", + "7e894c57a820cd7604a1db6b7ab8cce6", + "1e195b7055e98eb473bbb5af05d48f7d", +] + +fetch_org800_templates = _make_reusable_fetcher( + "fetch_org800_templates", + op.join(afq_home, "org800_templates"), + baseurl, + org800_remote_fnames, + org800_fnames, + md5_list=org800_md5_hashes, + doc="Download AFQ org800 templates", +) + + +def read_org800_templates(load_npy=True, load_trx=True): + """ + Load O'Donnell Research Group (ORG) Fiber Clustering White + Matter Atlas 800 modified for pyAFQ templates from file + + Parameters + ---------- + load_npy : bool, optional + If True, values 
are loaded as numpy arrays. Otherwise, values are + paths to npy files. Default: True + load_trx : bool, optional + If True, the tractogram is loaded as a StatefulTractogram. Otherwise, + the value is the path to the trx file. Default: True + + Returns + ------- + dict with: keys: names of atlas info + values: Floats, arrays, and StatefulTractogram for the atlas. + Any unloaded will instead be paths. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading org800 templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template(fetch_org800_templates) + + if load_trx: + template_dict["tracks_reoriented"] = load_tractogram( + template_dict.pop("ORG_atlas_tracks_reoriented"), + "same", + ) + if load_npy: + template_dict["centroids"] = np.load( + template_dict.pop("ORG800_atlas_centroids") + ) + template_dict["e_val"] = np.load(template_dict.pop("ORG800_atlas_e_val")) + template_dict["e_vec_norm"] = np.load( + template_dict.pop("ORG800_atlas_e_vec_norm") + ) + template_dict["e_vec"] = np.load(template_dict.pop("ORG800_atlas_e_vec")) + template_dict["number_of_eigenvectors"] = float( + np.load(template_dict.pop("ORG800_atlas_number_of_eigenvectors")) + ) + template_dict["row_sum_1"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_1") + ) + template_dict["row_sum_matrix"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_matrix") + ) + template_dict["sigma"] = float(np.load(template_dict.pop("ORG800_atlas_sigma"))) + + toc = time.perf_counter() + logger.debug( + f"O'Donnell Research Group 800 templates loaded in {toc - tic:0.4f} seconds" + ) + + return template_dict + + +massp_fnames = [ + "left_VTA.nii.gz", + "right_VTA.nii.gz", +] + +massp_remote_fnames = [ + "34892325", + "34892319", +] + +massp_md5_hashes = [ + "03d65d85abb161ea25501c343c136e40", + "440874b899d2c1057e5fd77b8b350bc4", +] + +fetch_massp_templates = _make_reusable_fetcher( + "fetch_massp_templates", + op.join(afq_home, "massp_templates"), + baseurl, + 
massp_remote_fnames, + massp_fnames, + md5_list=massp_md5_hashes, + doc="Download AFQ MassP templates", +) + + +def read_massp_templates(as_img=True, resample_to=False): + """Load AFQ MASSP templates from file + + Parameters + ---------- + as_img : bool, optional + If True, values are `Nifti1Image`. Otherwise, values are + paths to Nifti files. Default: True + resample_to : str or nibabel image class instance, optional + A template image to resample to. Typically, this should be the + template to which individual-level data are registered. Defaults to + the MNI template. Default: False + + Returns + ------- + dict with: keys: names of template ROIs and values: nibabel Nifti1Image + objects from each of the ROI nifti files. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading MASSP templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template( + fetch_massp_templates, as_img=as_img, resample_to=resample_to + ) + + toc = time.perf_counter() + logger.debug(f"MASSP templates loaded in {toc - tic:0.4f} seconds") + + return template_dict + + cp_fnames = [ "ICP_L_inferior_prob.nii.gz", "ICP_L_superior_prob.nii.gz", @@ -1457,22 +1675,26 @@ def read_stanford_hardi_tractography(): Reads a minimal tractography from the Stanford dataset. 
""" files, folder = fetch_stanford_hardi_tractography() - files_dict = {} - files_dict["mapping.nii.gz"] = nib.load( - op.join(afq_home, "stanford_hardi_tractography", "mapping.nii.gz") - ) # We need the original data as reference dwi_img, gtab = dpd.read_stanford_hardi() + reg_template = read_mni_template() + + files_dict = {} + files_dict["dwi"] = dwi_img - files_dict["tractography_subsampled.trk"] = load_tractogram( + mapping_file = op.join(afq_home, "stanford_hardi_tractography", "mapping.nii.gz") + old_mapping = read_old_mapping(mapping_file, dwi_img, reg_template) + files_dict["mapping"] = get_simplified_transform(old_mapping) + + files_dict["tractography_subsampled"] = load_tractogram( op.join(afq_home, "stanford_hardi_tractography", "tractography_subsampled.trk"), dwi_img, bbox_valid_check=False, trk_header_check=False, ).streamlines - files_dict["full_segmented_cleaned_tractography.trk"] = load_tractogram( + files_dict["full_segmented_cleaned_tractography"] = load_tractogram( op.join( afq_home, "stanford_hardi_tractography", @@ -1747,10 +1969,10 @@ def read_hcp_atlas(n_bundles=16, as_file=False): bundle_dict[bundle]["recobundles"]["centroid"] = centroid_file # For some reason, this file-name has a 0 in it, instead of an O: - bundle_dict["IFOF_R"] = bundle_dict["IF0F_R"] - # In the 80-bundle case, there are two files, and both have identical - # content, so this is fine: - del bundle_dict["IF0F_R"] + if "IF0F_R" in bundle_dict: + bundle_dict["IFOF_R"] = bundle_dict["IF0F_R"] + del bundle_dict["IF0F_R"] + return bundle_dict diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 4208056c..4d68fb8b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -3,8 +3,10 @@ import nibabel as nib import numpy as np from dipy.align import resample +from scipy.ndimage import distance_transform_edt from AFQ.definitions.utils import Definition, find_file, name_from_path +from AFQ.recognition.utils import tolerance_mm_to_vox from 
AFQ.tasks.utils import get_tp __all__ = [ @@ -324,6 +326,13 @@ class RoiImage(ImageDefinition): use_endpoints : bool Whether to use the endpoints ("start" and "end") to generate the image. + only_wmgmi : bool + Whether to only include portion of ROIs in the WM-GM interface. + only_wm : bool + Whether to only include portion of ROIs in the white matter. + dilate : bool + Whether to dilate the ROIs before combining them, according to the + tolerance that will be used during bundle recognition. tissue_property : str or None Tissue property from `scalars` to multiply the ROI image with. Can be useful to limit seed mask to the core white matter. @@ -350,14 +359,20 @@ def __init__( use_presegment=False, use_endpoints=False, only_wmgmi=False, + only_wm=False, + dilate=True, tissue_property=None, tissue_property_n_voxel=None, tissue_property_threshold=None, ): + if only_wmgmi and only_wm: + raise ValueError("only_wmgmi and only_wm cannot both be True") self.use_waypoints = use_waypoints self.use_presegment = use_presegment self.use_endpoints = use_endpoints self.only_wmgmi = only_wmgmi + self.only_wm = only_wm + self.dilate = dilate self.tissue_property = tissue_property self.tissue_property_n_voxel = tissue_property_n_voxel self.tissue_property_threshold = tissue_property_threshold @@ -384,23 +399,38 @@ def _image_getter_helper( for bundle_name in bundle_dict: bundle_entry = bundle_dict.transform_rois( - bundle_name, mapping_imap["mapping"], data_imap["dwi_affine"] + bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) - rois = [] + rois = {} if self.use_endpoints: - rois.extend( - [ - bundle_entry[end_type] + rois.update( + { + bundle_entry[end_type]: end_type for end_type in ["start", "end"] if end_type in bundle_entry - ] + } ) if self.use_waypoints: - rois.extend(bundle_entry.get("include", [])) - for roi in rois: + rois.update( + dict.fromkeys(bundle_entry.get("include", []), "waypoint") + ) + + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + 
data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) + for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: image_data = np.zeros(warped_roi.shape) + if self.dilate: + edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) + if roi_type == "waypoint": + warped_roi = edt <= dist_to_waypoint + else: + warped_roi = edt <= dist_to_atlas + image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: tp = nib.load( @@ -456,6 +486,10 @@ def _image_getter_helper( ) ) + if self.only_wm: + wm = nib.load(tissue_imap["pve_internal"]).get_fdata()[..., 2] >= 0.5 + image_data = np.logical_and(image_data, wm) + return nib.Nifti1Image( image_data.astype(np.float32), data_imap["dwi_affine"] ), dict(source="ROIs") @@ -898,7 +932,7 @@ def _image_getter_helper(mapping, reg_template, reg_subject): static_affine=reg_template.affine, ).get_fdata() - scalar_data = mapping.transform_inverse(img_data, interpolation="nearest") + scalar_data = mapping.transform(img_data, interpolation="nearest") return nib.Nifti1Image( scalar_data.astype(np.float32), reg_subject.affine ), dict(source=self.path) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index dfcedef7..ed5ca043 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -5,9 +5,10 @@ import nibabel as nib import numpy as np from dipy.align import affine_registration, syn_registration -from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr import AFQ.registration as reg +from AFQ._fixes import get_simplified_transform from AFQ.definitions.utils import Definition, find_file from AFQ.tasks.utils import get_fname from AFQ.utils.path import space_from_fname, write_json @@ -177,11 +178,11 @@ def __init__(self, warp, ref_affine): self.ref_affine = ref_affine self.warp = warp - def transform_inverse(self, data, **kwargs): + def 
transform(self, data, **kwargs): data_img = Image(nib.Nifti1Image(data.astype(np.float32), self.ref_affine)) return np.asarray(applyDeformation(data_img, self.warp).data) - def transform_inverse_pts(self, pts): + def transform_pts(self, pts): # This should only be used for curvature analysis, # Because I think the results still need to be shifted pts = nib.affines.apply_affine(self.warp.src.getAffine("voxel", "world"), pts) @@ -189,39 +190,13 @@ def transform_inverse_pts(self, pts): pts = self.warp.transform(pts, "fsl", "world") return pts - def transform(self, data, **kwargs): + def transform_inverse(self, data, **kwargs): raise NotImplementedError( "Fnirt based mappings can currently" + " only transform from template to subject space" ) -class IdentityMap(Definition): - """ - Does not perform any transformations from MNI to subject where - pyAFQ normally would. - - Examples - -------- - my_example_mapping = IdentityMap() - api.GroupAFQ(mapping=my_example_mapping) - """ - - def __init__(self): - pass - - def get_for_subses( - self, base_fname, dwi, dwi_data_file, reg_subject, reg_template, tmpl_name - ): - return ConformedAffineMapping( - np.identity(4), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) - - class GeneratedMapMixin(object): """ Helper Class @@ -236,24 +211,17 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): mapping_file = mapping_file + extension return mapping_file, meta_fname - def prealign( - self, base_fname, sub_name, tmpl_name, reg_subject, reg_template, save=True - ): - prealign_file_desc = f"_desc-prealign_from-{sub_name}_to-{tmpl_name}_xform" - prealign_file = get_fname(base_fname, f"{prealign_file_desc}.npy") - if not op.exists(prealign_file): - start_time = time() - _, aff = affine_registration( - reg_subject, reg_template, **self.affine_kwargs - ) - meta = 
dict(type="rigid", dependent="dwi", timing=time() - start_time) - if not save: - return aff - logger.info(f"Saving {prealign_file}") - np.save(prealign_file, aff) - meta_fname = get_fname(base_fname, f"{prealign_file_desc}.json") - write_json(meta_fname, meta) - return prealign_file if save else np.load(prealign_file) + def prealign(self, reg_subject, reg_template): + logger.info("Calculating affine pre-alignment...") + _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) + return aff + + +class AffineMapMixin(GeneratedMapMixin): + """ + Helper Class + Useful for maps that are generated by pyAFQ + """ def get_for_subses( self, @@ -268,34 +236,22 @@ def get_for_subses( ): sub_space = space_from_fname(dwi_data_file) mapping_file, meta_fname = self.get_fnames( - self.extension, base_fname, sub_space, tmpl_name + ".npy", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - reg_prealign = np.load( - self.prealign( - base_fname, sub_space, tmpl_name, reg_subject, reg_template - ) - ) - else: - reg_prealign = None if not op.exists(mapping_file): start_time = time() - mapping = self.gen_mapping( - base_fname, - sub_space, - tmpl_name, + affine = self.gen_mapping( reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - reg.write_mapping(mapping, mapping_file) - meta = dict(type="displacementfield", timing=total_time) + np.save(mapping_file, affine) + meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" else: @@ -305,10 +261,7 @@ def get_for_subses( if isinstance(reg_template, str): meta["reg_template"] = reg_template write_json(meta_fname, meta) - reg_prealign_inv = np.linalg.inv(reg_prealign) if self.use_prealign else None - mapping = reg.read_mapping( - mapping_file, dwi, reg_template, prealign=reg_prealign_inv - ) + mapping = reg.read_affine_mapping(mapping_file, dwi, reg_template) return mapping @@ -353,33 
+306,71 @@ def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None): self.use_prealign = use_prealign self.affine_kwargs = affine_kwargs self.syn_kwargs = syn_kwargs - self.extension = ".nii.gz" - def gen_mapping( + def get_for_subses( self, base_fname, - sub_space, - tmpl_name, + dwi, + dwi_data_file, reg_subject, reg_template, - subject_sls, - template_sls, - reg_prealign, + tmpl_name, + subject_sls=None, + template_sls=None, ): - _, mapping = syn_registration( - reg_subject.get_fdata(), - reg_template.get_fdata(), - moving_affine=reg_subject.affine, - static_affine=reg_template.affine, - prealign=reg_prealign, - **self.syn_kwargs, + sub_space = space_from_fname(dwi_data_file) + mapping_file_forward, meta_forward_fname = self.get_fnames( + ".nii.gz", base_fname, sub_space, tmpl_name + ) + mapping_file_backward, meta_backward_fname = self.get_fnames( + ".nii.gz", base_fname, tmpl_name, sub_space ) - if self.use_prealign: - mapping.codomain_world2grid = np.linalg.inv(reg_prealign) + + if not op.exists(mapping_file_forward) or not op.exists(mapping_file_backward): + meta = dict(type="displacementfield") + meta["dependent"] = "dwi" + if isinstance(reg_subject, str): + meta["reg_subject"] = reg_subject + if isinstance(reg_template, str): + meta["reg_template"] = reg_template + + start_time = time() + if self.use_prealign: + reg_prealign = self.prealign(reg_subject, reg_template) + else: + reg_prealign = None + + logger.info("Calculating SyN registration...") + _, mapping = syn_registration( + reg_subject.get_fdata(), + reg_template.get_fdata(), + moving_affine=reg_subject.affine, + static_affine=reg_template.affine, + prealign=reg_prealign, + **self.syn_kwargs, + ) + mapping = get_simplified_transform(mapping) + + total_time = time() - start_time + meta["total_time"] = total_time + + logger.info(f"Saving {mapping_file_forward}") + nib.save( + nib.Nifti1Image(mapping.forward, reg_subject.affine), + mapping_file_forward, + ) + 
write_json(meta_forward_fname, meta) + logger.info(f"Saving {mapping_file_backward}") + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, + ) + write_json(meta_backward_fname, meta) + mapping = reg.read_syn_mapping(mapping_file_forward, mapping_file_backward) return mapping -class SlrMap(GeneratedMapMixin, Definition): +class SlrMap(AffineMapMixin, Definition): """ Calculate a SLR registration for each subject/session using reg_subject and reg_template. @@ -407,33 +398,23 @@ class SlrMap(GeneratedMapMixin, Definition): def __init__(self, slr_kwargs=None): if slr_kwargs is None: slr_kwargs = {} - self.slr_kwargs = {} - self.use_prealign = False - self.extension = ".npy" + self.slr_kwargs = slr_kwargs def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, - reg_template, reg_subject, + reg_template, subject_sls, template_sls, - reg_prealign, ): - return reg.slr_registration( - subject_sls, - template_sls, - moving_affine=reg_subject.affine, - moving_shape=reg_subject.shape, - static_affine=reg_template.affine, - static_shape=reg_template.shape, - **self.slr_kwargs, + _, transform, _, _ = whole_brain_slr( + subject_sls, template_sls, x0="affine", verbose=False, **self.slr_kwargs ) + return transform + -class AffMap(GeneratedMapMixin, Definition): +class AffMap(AffineMapMixin, Definition): """ Calculate an affine registration for each subject/session using reg_subject and reg_template. 
@@ -457,47 +438,37 @@ class AffMap(GeneratedMapMixin, Definition): def __init__(self, affine_kwargs=None): if affine_kwargs is None: affine_kwargs = {} - self.use_prealign = False self.affine_kwargs = affine_kwargs - self.extension = ".npy" def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ): - return ConformedAffineMapping( - np.linalg.inv( - self.prealign( - base_fname, - sub_space, - tmpl_name, - reg_subject, - reg_template, - save=False, - ) - ), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) + return np.linalg.inv(self.prealign(reg_subject, reg_template)) -class ConformedAffineMapping(AffineMap): +class IdentityMap(AffineMapMixin, Definition): """ - Modifies AffineMap API to match DiffeomorphicMap API. - Important for SLR maps API to be indistinguishable from SYN maps API. + Does not perform any transformations from MNI to subject where + pyAFQ normally would. 
+ + Examples + -------- + my_example_mapping = IdentityMap() + api.GroupAFQ(mapping=my_example_mapping) """ - def transform(self, *args, **kwargs): - return super().transform_inverse(*args, **kwargs) + def __init__(self): + pass - def transform_inverse(self, *args, **kwargs): - return super().transform(*args, **kwargs) + def gen_mapping( + self, + reg_subject, + reg_template, + subject_sls, + template_sls, + ): + return np.identity(4) diff --git a/AFQ/nn/brainchop.py b/AFQ/nn/brainchop.py index ea2a6e15..13bfc968 100644 --- a/AFQ/nn/brainchop.py +++ b/AFQ/nn/brainchop.py @@ -4,7 +4,8 @@ import nibabel as nib import nibabel.processing as nbp import numpy as np -from scipy.ndimage import binary_dilation, gaussian_filter +from scipy.ndimage import gaussian_filter +from skimage.morphology import dilation from skimage.segmentation import find_boundaries from AFQ.data.fetch import afq_home, fetch_brainchop_models @@ -29,7 +30,7 @@ def _get_model(model_name): return model_fname -def run_brainchop(ort, t1_img, model_name): +def run_brainchop(ort, t1_img, model_name, onnx_kwargs): """ Run the Brainchop command line interface with the provided arguments. @@ -55,7 +56,7 @@ def run_brainchop(ort, t1_img, model_name): image = t1_data.astype(np.float32)[None, None, ...] 
logger.info(f"Running {model_name}...") - sess = ort.InferenceSession(model) + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name output_channels = sess.run([output_name], {input_name: image})[0] @@ -66,7 +67,8 @@ def run_brainchop(ort, t1_img, model_name): # Mindgrab can be tight sometimes, # better to include a bit more, # than to miss some - output = binary_dilation(output, iterations=2) + for _ in range(2): + output = dilation(output) output_img = nbp.resample_from_to( nib.Nifti1Image(output.astype(np.uint8), t1_img_conformed.affine), t1_img diff --git a/AFQ/nn/multiaxial.py b/AFQ/nn/multiaxial.py index d5f5a12a..f0d72307 100644 --- a/AFQ/nn/multiaxial.py +++ b/AFQ/nn/multiaxial.py @@ -14,7 +14,9 @@ __all__ = ["run_multiaxial", "extract_brain_mask", "multiaxial"] -def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_model): +def multiaxial( + ort, img, model_sagittal, model_axial, model_coronal, consensus_model, onnx_kwargs +): """ Perform multiaxial segmentation using three ONNX models and a consensus model [1]. @@ -31,6 +33,8 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m Path to coronal ONNX model. consensus_model : str Path to consensus ONNX model. + onnx_kwargs : dict + ONNX kwargs to use for inference. 
Returns ------- @@ -48,25 +52,25 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m pbar = tqdm(total=4) input_ = img[..., None] - sagittal_results = _run_onnx_model(ort, model_sagittal, input_, coords) + sagittal_results = _run_onnx_model(ort, model_sagittal, input_, coords, onnx_kwargs) pbar.update(1) input_ = np.swapaxes(img, 0, 1)[..., None] coronal_results = np.swapaxes( - _run_onnx_model(ort, model_coronal, input_, coords), 0, 1 + _run_onnx_model(ort, model_coronal, input_, coords, onnx_kwargs), 0, 1 ) pbar.update(1) input_ = np.transpose(img, (2, 0, 1))[..., None] axial_results = np.transpose( - _run_onnx_model(ort, model_axial, input_, coords), (1, 2, 0, 3) + _run_onnx_model(ort, model_axial, input_, coords, onnx_kwargs), (1, 2, 0, 3) ) pbar.update(1) X = np.concatenate( [img[..., None], sagittal_results, coronal_results, axial_results], -1 ) - sess = ort.InferenceSession(consensus_model, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession(consensus_model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name yhat = sess.run([output_name], {input_name: X[None, ...]})[0] @@ -77,8 +81,8 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m return pred -def _run_onnx_model(ort, model, input_, coords): - sess = ort.InferenceSession(model, providers=["CPUExecutionProvider"]) +def _run_onnx_model(ort, model, input_, coords, onnx_kwargs): + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name coord_name = sess.get_inputs()[1].name output_name = sess.get_outputs()[0].name @@ -131,7 +135,7 @@ def _get_multiaxial_model(): return model_dict -def run_multiaxial(ort, t1_img): +def run_multiaxial(ort, t1_img, onnx_kwargs): """ Run the multiaxial model. 
""" @@ -155,6 +159,7 @@ def run_multiaxial(ort, t1_img): model_dict["axial_model"], model_dict["coronal_model"], model_dict["consensus_model"], + onnx_kwargs, ) output_img = nbp.resample_from_to( diff --git a/AFQ/nn/synthseg.py b/AFQ/nn/synthseg.py index 45f80f52..e67cfdf0 100644 --- a/AFQ/nn/synthseg.py +++ b/AFQ/nn/synthseg.py @@ -28,7 +28,7 @@ def _get_model(model_name): return model_fname -def run_synthseg(ort, t1_img, model_name): +def run_synthseg(ort, t1_img, model_name, onnx_kwargs): """ Run the Synthseg Model @@ -57,7 +57,7 @@ def run_synthseg(ort, t1_img, model_name): image = t1_data.astype(np.float32)[None, ..., None] logger.info(f"Running {model_name}...") - sess = ort.InferenceSession(model) + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name output_channels = sess.run([output_name], {input_name: image})[0] @@ -86,8 +86,8 @@ def pve_from_synthseg(synthseg_data): PVE data with CSF, GM, and WM segmentations. 
""" - CSF_labels = [0, 3, 4, 11, 12, 21, 22, 17] - GM_labels = [2, 7, 8, 9, 10, 14, 15, 16, 20, 25, 26, 27, 28, 29, 30, 31] + CSF_labels = [0, 3, 4, 11, 12, 21, 22, 16] + GM_labels = [2, 7, 8, 9, 10, 14, 15, 17, 20, 25, 26, 27, 28, 29, 30, 31] WM_labels = [1, 5, 19, 23] mixed_labels = [13, 18, 32] diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index db741fe0..93a58b71 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -1,9 +1,9 @@ import logging import dipy.tracking.streamline as dts -import nibabel as nib import numpy as np from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.stats.analysis import assignment_map from scipy.stats import zscore from sklearn.ensemble import IsolationForest @@ -13,15 +13,24 @@ logger = logging.getLogger("AFQ") -def clean_by_orientation(streamlines, primary_axis, affine, tol=None): +def clean_by_orientation(streamlines, primary_axis, core_only=0.6): """ - Compute the cardinal orientation of each streamline + Retain streamlines whose core is oriented along the primary axis + and have endpoints that are also oriented along the primary axis + and have a majority of their steps along the primary axis. Parameters ---------- streamlines : sequence of N by 3 arrays Where N is number of nodes in the array, the collection of streamlines to filter down to. + core_only : float, optional + If non-zero, only the core of the bundle is used for cleaning. + The core is defined as the middle 60% of each streamline, + thus our default is 0.6. This means streamlines are allowed to + deviate in the starting and ending 20% of the bundle. This is useful + for allowing more diverse endpoints. 
+ Default: 0.6 Returns ------- @@ -32,37 +41,38 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): raise ValueError( f"Primary axis must be one of {axes_names}, got {primary_axis}" ) - orientation = nib.orientations.aff2axcodes(affine) - for idx, axis_label in enumerate(orientation): - if axis_label in primary_axis: - primary_axis = idx - break - axis_diff = np.zeros((len(streamlines), 3)) + primary_axis = abu.axes_dict[primary_axis] + + core_accepted_idx = np.zeros(len(streamlines), dtype=bool) + if core_only != 0: + crop_edge = (1.0 - core_only) / 2 + for ii, sl in enumerate(streamlines): + n_points = len(sl) + along_diff = np.abs( + np.diff( + sl[int(n_points * crop_edge) : int(n_points * (1 - crop_edge))], + axis=0, + ) + ) + # The majority of steps must be in the primary axis direction: + core_accepted_idx[ii] = np.sum( + np.argmax(along_diff, axis=1) == primary_axis + ) > (len(along_diff) / 2) + endpoint_diff = np.zeros((len(streamlines), 3)) + along_diff = np.zeros((len(streamlines), 3)) for ii, sl in enumerate(streamlines): - # endpoint diff is between first and last endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :]) - # axis diff is difference between the nodes, along - axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0) - - orientation_along = np.argmax(axis_diff, axis=1) - along_accepted_idx = orientation_along == primary_axis - if tol is not None: - percentage_primary = ( - 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) - ) - logger.debug( - (f"Maximum primary percentage found: {np.max(percentage_primary)}") - ) - along_accepted_idx = np.logical_and( - along_accepted_idx, percentage_primary > tol - ) - + along_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0) orientation_end = np.argmax(endpoint_diff, axis=1) + orientation_along = np.argmax(along_diff, axis=1) end_accepted_idx = orientation_end == primary_axis + along_accepted_idx = orientation_along == primary_axis - cleaned_idx = 
np.logical_and(along_accepted_idx, end_accepted_idx) + cleaned_idx = np.logical_and( + along_accepted_idx, np.logical_and(end_accepted_idx, core_accepted_idx) + ) return cleaned_idx @@ -70,45 +80,76 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): def clean_by_orientation_mahalanobis( streamlines, n_points=100, - core_only=0.6, + core_only=0, min_sl=20, distance_threshold=3, + length_threshold=4, clean_rounds=5, + remove_lengths="long", ): - fgarray = np.array(abu.resample_tg(streamlines, n_points)) + if length_threshold == 0: + length_threshold = np.inf + fgarray = abu.resample_tg(streamlines, n_points) + + assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, 100)) + assignment_idxs = assignment_idxs.reshape((len(fgarray), n_points)) + fgarray = np.asarray(fgarray) if core_only != 0: crop_edge = (1.0 - core_only) / 2 fgarray = fgarray[ :, int(n_points * crop_edge) : int(n_points * (1 - crop_edge)), : - ] # Crop to middle 60% + ] fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + assignment_idxs = assignment_idxs[:, 1:] + lengths = np.array([sl.shape[0] for sl in streamlines]) idx = np.arange(len(fgarray)) rounds_elapsed = 0 while rounds_elapsed < clean_rounds: - # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( - fgarray_dists, return_mahalnobis=True, n_points=None, stat=np.mean + fgarray_dists, + assignment_idxs=assignment_idxs, + return_mahalnobis=True, + n_points=None, + stat=np.mean, ) + length_z = zscore(lengths) + logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}") logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}")) - if not (np.any(m_dist >= distance_threshold)): + if not ( + np.any(m_dist >= distance_threshold) or np.any(length_z >= length_threshold) + ): break + idx_dist = np.all(m_dist < distance_threshold, axis=-1) + if remove_lengths == "long": + idx_len = length_z < length_threshold + elif remove_lengths == "short": + idx_len = length_z 
> -length_threshold + elif remove_lengths == "both": + idx_len = np.abs(length_z) < length_threshold + else: + raise ValueError( + f"Invalid value for remove_lengths: {remove_lengths}. " + "Expected 'long', 'short', or 'both'." + ) + idx_belong = np.logical_and(idx_dist, idx_len) - if np.sum(idx_dist) < min_sl: - # need to sort and return exactly min_sl: + if np.sum(idx_belong) < min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) break else: - # Update by selection: - idx = idx[idx_dist] - fgarray_dists = fgarray_dists[idx_dist] + idx = idx[idx_belong] + fgarray_dists = fgarray_dists[idx_belong] + lengths = lengths[idx_belong] + assignment_idxs = assignment_idxs[idx_belong] rounds_elapsed += 1 logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}")) logger.debug(f"Kept indices: {idx}") @@ -120,12 +161,13 @@ def clean_bundle( tg, n_points=100, clean_rounds=5, - distance_threshold=3, + distance_threshold=4, length_threshold=4, min_sl=20, stat=np.mean, core_only=0.6, return_idx=False, + remove_lengths="long", ): """ Clean a segmented fiber group based on the Mahalnobis distance of @@ -143,7 +185,7 @@ def clean_bundle( the mean of extracted bundles. Default: 5 distance_threshold : float, optional. Threshold of cleaning based on the Mahalanobis distance (the units are - standard deviations). Default: 3. + standard deviations). Default: 4. length_threshold: float, optional Threshold for cleaning based on length (in standard deviations). Length of any streamline should not be *more* than this number of stdevs from @@ -156,7 +198,7 @@ def clean_bundle( calculated. Default: `np.mean` (but can also use median, etc.) core_only : float, optional If non-zero, only the core of the bundle is used for cleaning. 
- The core is commonly defined as the middle 60% of each streamline, + The core is defined as the middle 60% of each streamline, thus our default is 0.6. This means streamlines are allowed to deviate in the starting and ending 20% of the bundle. This is useful for allowing more diverse endpoints. @@ -164,6 +206,13 @@ def clean_bundle( return_idx : bool Whether to return indices in the original streamlines. Default: False. + remove_lengths : str + Specifies which streamlines to remove based on their length. + Options are "long" (remove long streamlines), "short" + (remove short streamlines), or "both" + (remove both long and short streamlines). + Default: "long" + Returns ------- A StatefulTractogram class instance containing only the streamlines @@ -189,6 +238,9 @@ def clean_bundle( else: return tg + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) if core_only != 0: @@ -225,12 +277,23 @@ def clean_bundle( # Select the fibers that have Mahalanobis smaller than the # threshold for all their nodes: idx_dist = np.all(m_dist < distance_threshold, axis=-1) - idx_len = length_z < length_threshold + if remove_lengths == "long": + idx_len = length_z < length_threshold + elif remove_lengths == "short": + idx_len = length_z > -length_threshold + elif remove_lengths == "both": + idx_len = np.abs(length_z) < length_threshold + else: + raise ValueError( + f"Invalid value for remove_lengths: {remove_lengths}. " + "Expected 'long', 'short', or 'both'." 
+ ) idx_belong = np.logical_and(idx_dist, idx_len) if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -318,6 +381,9 @@ def clean_by_isolation_forest( ) return np.ones(len(streamlines), dtype=bool) + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) fgarray_dists = np.zeros_like(fgarray) @@ -360,6 +426,7 @@ def clean_by_isolation_forest( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(-sl_outliers)[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py new file mode 100644 index 00000000..335f68b4 --- /dev/null +++ b/AFQ/recognition/clustering.py @@ -0,0 +1,189 @@ +# Original source: github.com/SlicerDMRI/whitematteranalysis +# Copyright 2026 BWH and 3D Slicer contributors +# Licensed under 3D Slicer license (BSD style; https://github.com/SlicerDMRI/whitematteranalysis/blob/master/License.txt) # noqa +# Modified by John Kruper for pyAFQ +# Modifications: +# 1. Only mean distance included, and mean distance replaced with numba version. +# 2. Uses atlas data from dictionary and numpy files rather than pickled files, +# to avoid additional dependencies. +# 3. Added function to move template streamlines +# to subject space to calculate distances. 
+ +import numpy as np +import scipy +from dipy.io.stateful_tractogram import Space +from numba import njit, prange + +import AFQ.data.fetch as afd +import AFQ.recognition.utils as abu +import AFQ.utils.streamlines as aus + + +@njit(parallel=True) +def _compute_mean_euclidean_matrix(group_n, group_m): + len_n = group_n.shape[0] + len_m = group_m.shape[0] + num_points = group_n.shape[1] + + dist_matrix = np.empty((len_n, len_m), dtype=np.float64) + + for i in prange(len_n): + for j in range(len_m): + sum_dist = 0.0 + sum_dist_ref = 0.0 + + for k in range(num_points): + dx = group_n[i, k, 0] - group_m[j, k, 0] + dx_ref = group_n[i, k, 0] + group_m[j, k, 0] + dy = group_n[i, k, 1] - group_m[j, k, 1] + dz = group_n[i, k, 2] - group_m[j, k, 2] + + sum_dist += np.sqrt(dx * dx + dy * dy + dz * dz) + sum_dist_ref += np.sqrt(dx_ref * dx_ref + dy * dy + dz * dz) + + mean_d = sum_dist / num_points + mean_d_ref = sum_dist_ref / num_points + + final_d = min(mean_d, mean_d_ref) + dist_matrix[i, j] = final_d * final_d + + return dist_matrix.T + + +def _distance_to_similarity(distance, sigmasq): + similarities = np.exp(-distance / (sigmasq)) + + return similarities + + +def _rectangular_similarity_matrix(fgarray_sub, fgarray_atlas, sigma): + distances = _compute_mean_euclidean_matrix(fgarray_sub, fgarray_atlas) + + sigmasq = sigma * sigma + similarity_matrix = _distance_to_similarity(distances, sigmasq) + + return similarity_matrix + + +def spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=None, + sigma_multiplier=1.0, + cluster_indices=None, +): + """ + Use an existing atlas to label a new streamlines. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group to be labeled. + atlas_fgarray : ndarray + Resampled atlas to use for labelling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. 
+ sigma_multiplier : float, optional + Multiplier for the sigma value used in computing the similarity + matrix. Default is 1.0. + cluster_indices : list of int, optional + If provided, only these cluster indices from the atlas will be used + for labeling. Default is None, which uses all clusters. + + Returns + ------- + tuple of (ndarray, ndarray) + Cluster indices for all the fibers and their embedding + """ + if atlas_data is None: + atlas_data = afd.read_org800_templates(load_trx=False) + + number_fibers = sub_fgarray.shape[0] + sz = atlas_fgarray.shape[0] + + # Compute fiber similarities. + B = _rectangular_similarity_matrix( + sub_fgarray, atlas_fgarray, sigma=atlas_data["sigma"] * sigma_multiplier + ) + + # Do Normalized Cuts transform of similarity matrix. + # row sum estimate for current B part of the matrix + row_sum_2 = np.sum(B, axis=0) + np.dot(atlas_data["row_sum_matrix"], B) + + # This happens plenty in our cases. Why? + # Maybe a probabilistic vs UKF thing? + # In practice, this is not an issue since we just set to a small value. 
+ if any(row_sum_2 <= 0): + row_sum_2[row_sum_2 < 0] = 1e-4 + + # Normalized cuts normalization + row_sum = np.concatenate((atlas_data["row_sum_1"], row_sum_2)) + dhat = np.sqrt(np.divide(1, row_sum)) + B = np.multiply(B, np.outer(dhat[0:sz], dhat[sz:].T)) + + # Compute embedding using eigenvectors + V = np.dot( + np.dot(B.T, atlas_data["e_vec"]), np.diag(np.divide(1.0, atlas_data["e_val"])) + ) + V = np.divide(V, atlas_data["e_vec_norm"]) + n_eigen = int(atlas_data["number_of_eigenvectors"]) + embed = np.zeros((number_fibers, n_eigen)) + for i in range(0, n_eigen): + embed[:, i] = np.divide(V[:, -(i + 2)], V[:, -1]) + + # Label streamlines using centroids from atlas + if cluster_indices is not None: + centroids = atlas_data["centroids"][cluster_indices, :] + cluster_idx, _ = scipy.cluster.vq.vq(embed, centroids) + cluster_idx = np.array([cluster_indices[i] for i in cluster_idx]) + else: + cluster_idx, _ = scipy.cluster.vq.vq(embed, atlas_data["centroids"]) + + return cluster_idx, embed + + +def subcluster_by_atlas( + sub_trk, mapping, dwi_ref, cluster_indices, atlas_data=None, n_points=20 +): + """ + Use an existing atlas to label a new set of streamlines, and return the + cluster indices for each streamline. + + Parameters + ---------- + sub_trk : StatefulTractogram + streamlines to be labeled. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + dwi_ref : Nifti1Image + Image defining reference for where the atlas streamlines move to. + cluster_indices : list of int + Cluster indices from the atlas to use for labeling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + n_points : int, optional + Number of points to resample streamlines to for labeling. Default is 20. 
+ """ + + if atlas_data is None: + atlas_data = afd.read_org800_templates() + atlas_sft = atlas_data["tracks_reoriented"] + + moved_atlas_sft = aus.move_streamlines( + atlas_sft, "subject", mapping, dwi_ref, to_space=Space.RASMM + ) + atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) + + sub_trk.to_rasmm() + sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) + + cluster_idxs, _ = spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + + return cluster_idxs diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index e68e6251..b054dda6 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -19,14 +19,16 @@ import AFQ.recognition.roi as abr import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict +from AFQ.recognition.clustering import subcluster_by_atlas from AFQ.utils.stats import chunk_indices +from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ - "prob_map", + "length", "cross_midline", "start", "end", - "length", + "prob_map", "primary_axis", "include", "exclude", @@ -44,17 +46,23 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", + "exact_endpoints", + "ORG_spectral_subbundles", + "cluster_IDs", + "startpoint_location", + "endpoint_location", ] logger = logging.getLogger("AFQ") -def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): +def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, img, **kwargs): b_sls.initiate_selection("Prob. 
Map") - # using entire fgarray here only because it is the first step fiber_probabilities = dts.values_from_volume( - bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"], np.eye(4) + bundle_def["prob_map"].get_fdata(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + img.affine, ) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. Map") @@ -69,57 +77,82 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("Startpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("Startpoint") + exact_endpoints = bundle_def.get("exact_endpoints", False) + if exact_endpoints: + tol = 0 + else: + tol = preproc_imap["dist_to_atlas"] + + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], 0, - tol=preproc_imap["dist_to_atlas"], + tol=tol, flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], -1, - tol=preproc_imap["dist_to_atlas"], + tol=tol, + ) + new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) + special_idx = np.logical_and(accept_idx, accepted_idx_flipped) + special_idx_to_flip = abu.manual_orient_sls( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] ) + accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor(accepted_idx_flipped, accept_idx) + accept_idx = new_accept_idx + b_sls.select(accept_idx, "Startpoint") def end(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("endpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("endpoint") + exact_endpoints = 
bundle_def.get("exact_endpoints", False) + if exact_endpoints: + tol = 0 + else: + tol = preproc_imap["dist_to_atlas"] + + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], -1, - tol=preproc_imap["dist_to_atlas"], + tol=tol, flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], 0, - tol=preproc_imap["dist_to_atlas"], + tol=tol, ) + new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) + special_idx = np.logical_and(accept_idx, accepted_idx_flipped) + special_idx_to_flip = abu.manual_orient_sls( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] + ) + accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor(accepted_idx_flipped, accept_idx) + accept_idx = new_accept_idx b_sls.select(accept_idx, "endpoint") def length(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("length") - min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] - max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] - for idx, sl in enumerate(b_sls.get_selected_sls()): - sl_len = np.sum(np.linalg.norm(np.diff(sl, axis=0), axis=1)) - if sl_len >= min_len and sl_len <= max_len: - accept_idx[idx] = 1 + b_sls.initiate_selection("length") + min_len = bundle_def["length"].get("min_len", 0) + max_len = bundle_def["length"].get("max_len", np.inf) + + # No need to use b_sls.selected_fiber_idxs + # because this is first step + sl_lens = preproc_imap["lengths"] + + accept_idx = (sl_lens >= min_len) & (sl_lens <= max_len) b_sls.select(accept_idx, "length") @@ -128,13 +161,12 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): accept_idx = abc.clean_by_orientation( b_sls.get_selected_sls(), 
bundle_def["primary_axis"], - img.affine, - bundle_def.get("primary_axis_percentage", None), + bundle_def.get("core_only", 0.6), ) b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): +def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet @@ -147,6 +179,15 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): else: include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) + # For now I am turning ray parallelization here off. + # It is never worthwhile considering other changes we + # have made to speed up this step, + # so spinning up ray and transferring data back + # and forth is not worth it. + # In the future, I think we should redo this with numba and + # use multithreading + n_cpus = 1 + # with parallel segmentation, the first for loop will # only collect streamlines and does not need tqdm if n_cpus > 1: @@ -172,15 +213,18 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols ) - roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) + n_inc = len(bundle_def["include"]) + roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32) + roi_dists = np.zeros((n_inc, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): - sl_accepted, sl_closest = inc_result + sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: + roi_closest[:, sl_idx] = sl_closest + roi_dists[:, sl_idx] = sl_dists if len(sl_closest) > 1: - roi_closest[: len(sl_closest), sl_idx] = sl_closest # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second 
ROI @@ -188,12 +232,14 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): if flip_using_include: to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: - roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) + roi_closest[:, sl_idx] = np.flip(sl_closest) + roi_dists[:, sl_idx] = np.flip(sl_dists) accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 b_sls.roi_closest = roi_closest.T + b_sls.roi_dists = roi_dists.T if flip_using_include: b_sls.reorient(to_flip) b_sls.select(accept_idx, "include") @@ -211,10 +257,9 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): ref_sl = load_tractogram( bundle_def["curvature"]["path"], "same", bbox_valid_check=False ) - moved_ref_sl = abu.move_streamlines( + moved_ref_sl = move_streamlines( ref_sl, "subject", mapping, img, save_intermediates=save_intermediates ) - moved_ref_sl.to_vox() moved_ref_sl = moved_ref_sl.streamlines[0] moved_ref_curve = abv.sl_curve(moved_ref_sl, len(moved_ref_sl)) ref_curve_threshold = np.radians(bundle_def["curvature"].get("thresh", 10)) @@ -257,11 +302,12 @@ def recobundles( **kwargs, ): b_sls.initiate_selection("Recobundles") - moved_sl = abu.move_streamlines( - StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), + moved_sl = move_streamlines( + StatefulTractogram(b_sls.get_selected_sls(), img, Space.RASMM), "template", mapping, reg_template, + to_space=Space.RASMM, save_intermediates=save_intermediates, ).streamlines moved_sl_resampled = abu.resample_tg(moved_sl, 100) @@ -277,6 +323,7 @@ def recobundles( standard_sl = next(iter(bundle_def["recobundles"]["centroid"])) oriented_idx = abu.orient_by_streamline(moved_sl[rec_labels], standard_sl) b_sls.reorient(rec_labels[oriented_idx]) + rec_labels = sorted(rec_labels) b_sls.select(rec_labels, "Recobundles") @@ -284,7 +331,7 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): b_sls.initiate_selection("qb_thresh") cut = clip_edges or 
("bundlesection" in bundle_def) qbx = QuickBundles( - bundle_def["qb_thresh"] / preproc_imap["vox_dim"], + bundle_def["qb_thresh"], AveragePointwiseEuclideanMetric(ResampleFeature(nb_points=12)), ) clusters = qbx.cluster(b_sls.get_selected_sls(cut=cut, flip=True)) @@ -293,36 +340,39 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): def clean_by_other_bundle( - b_sls, bundle_def, img, preproc_imap, other_bundle_name, other_bundle_sls, **kwargs + b_sls, bundle_def, img, other_bundle_name, other_bundle_sls, **kwargs ): cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 + flipped_sls = b_sls.get_selected_sls(flip=True) if "overlap" in bundle_def[other_bundle_name]: cleaned_idx_overlap = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, - False, + remove=False, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap) if "node_thresh" in bundle_def[other_bundle_name]: cleaned_idx_node_thresh = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, - True, + remove=True, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh) if "core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, ) @@ -331,8 +381,8 @@ def clean_by_other_bundle( if "entire_core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), - 
preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) @@ -381,10 +431,7 @@ def run_bundle_rec_plan( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_decisions, + recognized_bundles_dict, **segmentation_params, ): # Warp ROIs @@ -392,9 +439,7 @@ def run_bundle_rec_plan( start_time = time() bundle_def = dict(bundle_dict.get_b_info(bundle_name)) bundle_def.update( - bundle_dict.transform_rois( - bundle_name, mapping, img.affine, apply_to_recobundles=True - ) + bundle_dict.transform_rois(bundle_name, mapping, img, apply_to_recobundles=True) ) def check_space(roi): @@ -420,20 +465,25 @@ def check_space(roi): ) logger.info(f"Time to prep ROIs: {time() - start_time}s") - b_sls = abu.SlsBeingRecognized( - tg.streamlines, - logger, - segmentation_params["save_intermediates"], - bundle_name, - img, - len(bundle_def.get("include", [])), - ) + if isinstance(tg, abu.SlsBeingRecognized): + # This only occurs when your inside a subbundle, + # in which case we want to keep the same SlsBeingRecognized object so that + # we can keep track of the same streamlines and their orientations + b_sls = tg + else: + b_sls = abu.SlsBeingRecognized( + tg.streamlines, + logger, + segmentation_params["save_intermediates"], + bundle_name, + img, + len(bundle_def.get("include", [])), + ) inputs = {} inputs["b_sls"] = b_sls inputs["preproc_imap"] = preproc_imap inputs["bundle_def"] = bundle_def - inputs["max_includes"] = bundle_dict.max_includes inputs["mapping"] = mapping inputs["img"] = img inputs["reg_template"] = reg_template @@ -444,7 +494,7 @@ def check_space(roi): if ( (potential_criterion not in criteria_order_post_other_bundles) and (potential_criterion not in criteria_order_pre_other_bundles) - and (potential_criterion not in bundle_dict.bundle_names) + and 
(potential_criterion not in recognized_bundles_dict.keys()) and (potential_criterion not in valid_noncriterion) ): raise ValueError( @@ -454,7 +504,7 @@ def check_space(roi): "Valid criteria are:\n" f"{criteria_order_pre_other_bundles}\n" f"{criteria_order_post_other_bundles}\n" - f"{bundle_dict.bundle_names}\n" + f"{recognized_bundles_dict.keys()}\n" f"{valid_noncriterion}\n" ) ) @@ -463,13 +513,14 @@ def check_space(roi): if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) if b_sls: - for ii, bundle_name in enumerate(bundle_dict.bundle_names): - if bundle_name in bundle_def.keys(): - idx = np.where(bundle_decisions[:, ii])[0] + for o_bundle_name in recognized_bundles_dict.keys(): + if o_bundle_name in bundle_def.keys(): clean_by_other_bundle( **inputs, - other_bundle_name=bundle_name, - other_bundle_sls=tg.streamlines[idx], + other_bundle_name=o_bundle_name, + other_bundle_sls=recognized_bundles_dict[ + o_bundle_name + ].get_selected_sls(flip=True), ) for criterion in criteria_order_post_other_bundles: if b_sls and criterion in bundle_def: @@ -480,6 +531,19 @@ def check_space(roi): ): mahalanobis(**inputs) + # If you don't cross the midline, we remove streamliens + # entirely on the wrong side of the midline here after filtering + if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: + b_sls.initiate_selection("Wrong side of mid.") + avg_side = np.sign( + np.mean( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, 0], + axis=1, + ) + ) + majority_side = np.sign(np.sum(avg_side)) + b_sls.select(avg_side == majority_side, "Wrong side of mid.") + if b_sls and not b_sls.oriented_yet: raise ValueError( "pyAFQ was unable to consistently orient streamlines " @@ -490,9 +554,44 @@ def check_space(roi): ) if b_sls: - bundle_to_flip[b_sls.selected_fiber_idxs, bundle_idx] = b_sls.sls_flipped.copy() - bundle_decisions[b_sls.selected_fiber_idxs, bundle_idx] = 1 - if hasattr(b_sls, "roi_closest"): - 
bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_closest.copy() + if "ORG_spectral_subbundles" in bundle_def: + subdict = bundle_def["ORG_spectral_subbundles"] + b_sls.initiate_selection( + ( + f"ORG spectral clustering, {len(subdict.bundle_names)} " + "subbundles being recognized" + ) ) + + sub_sft = StatefulTractogram( + b_sls.get_selected_sls(flip=True), img, Space.RASMM + ) + cluster_labels = subcluster_by_atlas( + sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 + ) + clusters_being_recognized = [] + for sub_b_name in subdict.bundle_names: + c_ids = subdict._dict[sub_b_name]["cluster_IDs"] + n_roi = len(subdict._dict[sub_b_name].get("include", [])) + cluster_b_sls = b_sls.copy(sub_b_name, n_roi) + selected = np.zeros(len(b_sls), dtype=bool) + for c_id in c_ids: + selected = np.logical_or(selected, cluster_labels == c_id) + cluster_b_sls.select(selected, f"Clusters {c_ids}") + clusters_being_recognized.append(cluster_b_sls) + + for ii, sub_b_name in enumerate(subdict.bundle_names): + run_bundle_rec_plan( + bundle_def["ORG_spectral_subbundles"], + clusters_being_recognized[ii], + mapping, + img, + reg_template, + preproc_imap, + sub_b_name, + recognized_bundles_dict, + **segmentation_params, + ) + else: + b_sls.bundle_def = bundle_def + recognized_bundles_dict[bundle_name] = b_sls diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 1e94106b..bff432c2 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -6,10 +6,20 @@ import numpy as np from scipy.spatial.distance import cdist +import AFQ.recognition.utils as abu + logger = logging.getLogger("AFQ") -def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=False): +def clean_by_overlap( + this_bundle_sls, + other_bundle_sls, + overlap, + img, + remove=False, + project=None, + other_bundle_min_density=0.05, +): """ Cleans a set of streamlines by only keeping (or removing) those with 
significant overlap with another set of streamlines. @@ -18,6 +28,7 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal ---------- this_bundle_sls : array-like A list or array of streamlines to be cleaned. + Assumed to be in RASMM space. other_bundle_sls : array-like A reference list or array of streamlines to determine overlapping regions. overlap : int @@ -32,6 +43,16 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal removed. If False, streamlines that overlap in more than `overlap` nodes are removed. Default: False. + project : {'A/P', 'I/S', 'L/R', None}, optional + If specified, the overlap calculation is projected along the given axis + before cleaning. For example, 'A/P' projects the streamlines along the + anterior-posterior axis. + Default: None. + other_bundle_min_density : float, optional + A threshold to binarize the density map of `other_bundle_sls`. Voxels + with density values above this threshold (as a fraction of the maximum + density) are considered occupied. + Default: 0.05. 
Returns ------- @@ -54,10 +75,34 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal >>> cleaned_bundle = [s for i, s in enumerate(bundle1) if clean_idx[i]] """ other_bundle_density_map = dtu.density_map( - other_bundle_sls, np.eye(4), img.shape[:3] + other_bundle_sls, img.affine, img.shape[:3] ) + + if remove: + max_val = other_bundle_density_map.max() + if max_val > 0: + other_bundle_density_map = ( + other_bundle_density_map / max_val + ) > other_bundle_min_density + else: + other_bundle_density_map = np.zeros_like( + other_bundle_density_map, dtype=bool + ) + + if project is not None: + orientation = nib.orientations.aff2axcodes(img.affine) + core_axis = next( + idx for idx, label in enumerate(orientation) if label in project.upper() + ) + + projection = np.sum(other_bundle_density_map, axis=core_axis) + + other_bundle_density_map = np.broadcast_to( + np.expand_dims(projection, axis=core_axis), other_bundle_density_map.shape + ) + fiber_probabilities = dts.values_from_volume( - other_bundle_density_map, this_bundle_sls, np.eye(4) + other_bundle_density_map, this_bundle_sls, img.affine ) cleaned_idx = np.zeros(len(this_bundle_sls), dtype=np.bool_) for ii, fp in enumerate(fiber_probabilities): @@ -83,8 +128,10 @@ def clean_relative_to_other_core( retained. this_fgarray : ndarray An array of streamlines to be cleaned. + Assumed to be in RASMM space. other_fgarray : ndarray An array of reference streamlines to define the core. + Assumed to be in RASMM space. affine : ndarray The affine transformation matrix. 
entire : bool, optional @@ -122,17 +169,7 @@ def clean_relative_to_other_core( return np.ones(this_fgarray.shape[0], dtype=np.bool_) # find dimension of core axis - orientation = nib.orientations.aff2axcodes(affine) - core_axis = None - core_upper = core[0].upper() - axis_groups = { - "L": ("L", "R"), - "R": ("L", "R"), - "P": ("P", "A"), - "A": ("P", "A"), - "I": ("I", "S"), - "S": ("I", "S"), - } + core_axis = abu.axes_dict[core[0].upper()] direction_signs = { "L": 1, @@ -143,18 +180,7 @@ def clean_relative_to_other_core( "S": -1, } - core_axis = None - for idx, axis_label in enumerate(orientation): - if core_upper in axis_groups[axis_label]: - core_axis = idx - core_direc = direction_signs[core_upper] - break - - if affine[core_axis, core_axis] < 0: - core_direc = -core_direc - - if core_axis is None: - raise ValueError(f"Invalid core axis: {core}") + core_direc = direction_signs[core[0].upper()] core_bundle = np.median(other_fgarray, axis=0) cleaned_idx_core = np.zeros(this_fgarray.shape[0], dtype=np.bool_) diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8bb41e67..ec47c7a4 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -1,9 +1,7 @@ import logging from time import time -import dipy.tracking.streamline as dts import immlib -import nibabel as nib import numpy as np import AFQ.recognition.utils as abu @@ -13,20 +11,7 @@ @immlib.calc("tol", "dist_to_atlas", "vox_dim") def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): - # We need to calculate the size of a voxel, so we can transform - # from mm to voxel units: - R = img.affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - - # Tolerance is set to the square of the distance to the corner - # because we are using the squared Euclidean distance in calls to - # `cdist` to make those calls faster. 
- if dist_to_waypoint is None: - tol = dts.dist_to_corner(img.affine) - else: - tol = dist_to_waypoint / vox_dim - dist_to_atlas = int(input_dist_to_atlas / vox_dim) - return tol, dist_to_atlas, vox_dim + return abu.tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas) @immlib.calc("fgarray") @@ -42,34 +27,40 @@ def fgarray(tg): @immlib.calc("crosses") -def crosses(fgarray, img): +def crosses(fgarray): """ Classify the streamlines by whether they cross the midline. Creates a crosses attribute which is an array of booleans. Each boolean corresponds to a streamline, and is whether or not that streamline crosses the midline. """ - # What is the x,y,z coordinate of 0,0,0 in the template space? - zero_coord = np.dot(np.linalg.inv(img.affine), np.array([0, 0, 0, 1])) - - orientation = nib.orientations.aff2axcodes(img.affine) - lr_axis = 0 - for idx, axis_label in enumerate(orientation): - if axis_label in ["L", "R"]: - lr_axis = idx - break - return np.logical_and( - np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), - np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + np.any(fgarray[:, :, 0] > 0, axis=1), + np.any(fgarray[:, :, 0] < 0, axis=1), ) +@immlib.calc("lengths") +def lengths(fgarray): + """ + Calculate the lengths of the streamlines. + Using resampled fgarray biases lengths to be lower. However, + this is not meant to be a precise selection requirement, and + is more meant for efficiency. 
+ """ + segments = np.diff(fgarray, axis=1) + segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) + return np.sum(segment_lengths, axis=1) + + # Things that can be calculated for multiple bundles at once # (i.e., for a whole tractogram) go here def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas): preproc_plan = immlib.plan( - tolerance_mm_to_vox=tolerance_mm_to_vox, fgarray=fgarray, crosses=crosses + tolerance_mm_to_vox=tolerance_mm_to_vox, + fgarray=fgarray, + crosses=crosses, + lengths=lengths, ) return preproc_plan( img=img, diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 1727bd0f..51861deb 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -6,10 +6,12 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram +import AFQ.recognition.sparse_decisions as ars import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict from AFQ.recognition.criteria import run_bundle_rec_plan from AFQ.recognition.preprocess import get_preproc_plan +from AFQ.utils.path import write_json logger = logging.getLogger("AFQ") @@ -153,13 +155,9 @@ def recognize( if not isinstance(bundle_dict, BundleDict): bundle_dict = BundleDict(bundle_dict) - tg.to_vox() + tg.to_rasmm() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.uint32 - ) + recognized_bundles_dict = {} fiber_groups = {} meta = {} @@ -167,7 +165,7 @@ def recognize( preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) logger.info("Assigning Streamlines to Bundles") - for bundle_idx, bundle_name in enumerate(bundle_dict.bundle_names): + for bundle_name in bundle_dict.bundle_names: logger.info(f"Finding Streamlines for {bundle_name}") run_bundle_rec_plan( 
bundle_dict, @@ -177,10 +175,7 @@ def recognize( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_decisions, + recognized_bundles_dict, clip_edges=clip_edges, n_cpus=n_cpus, rb_recognize_params=rb_recognize_params, @@ -195,68 +190,62 @@ def recognize( if save_intermediates is not None: os.makedirs(save_intermediates, exist_ok=True) - bc_path = op.join(save_intermediates, "sls_bundle_decisions.npy") - np.save(bc_path, bundle_decisions) + bc_path = op.join(save_intermediates, "sls_bundle_decisions.json") + write_json( + bc_path, + { + b_name: b_sls.selected_fiber_idxs.tolist() + for b_name, b_sls in recognized_bundles_dict.items() + }, + ) + + sparse_dists = ars.compute_sparse_decisions(recognized_bundles_dict, n_streamlines) - conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1) + conflicts = ars.get_conflict_count(sparse_dists) if conflicts > 0: logger.info( ( "Conflicts in bundle assignment detected. " f"{conflicts} conflicts detected in total out of " f"{n_streamlines} total streamlines. " - "Defaulting to whichever bundle appears first " + "Defaulting to whichever bundle is closest to the include ROI," + "followed by whichever appears first " "in the bundle_dict." ) ) - bundle_decisions = np.concatenate( - (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 - ) - bundle_decisions = np.argmax(bundle_decisions, -1) + + ars.remove_conflicts(sparse_dists, recognized_bundles_dict) # We do another round through, so that we can: # 1. Clip streamlines according to ROIs # 2. 
Re-orient streamlines logger.info("Re-orienting streamlines to consistent directions") - for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): - logger.info(f"Processing {bundle}") + for b_name, r_bd in recognized_bundles_dict.items(): + logger.info(f"Processing {b_name}") - select_idx = np.where(bundle_decisions == bundle_idx)[0] - - if len(select_idx) == 0: + if len(r_bd.selected_fiber_idxs) == 0: # There's nothing here, set and move to the next bundle: - if "bundlesection" in bundle_dict.get_b_info(bundle): - for sb_name in bundle_dict.get_b_info(bundle)["bundlesection"]: + if "bundlesection" in bundle_dict.get_b_info(b_name): + for sb_name in bundle_dict.get_b_info(b_name)["bundlesection"]: _return_empty(sb_name, return_idx, fiber_groups, img) else: - _return_empty(bundle, return_idx, fiber_groups, img) + _return_empty(b_name, return_idx, fiber_groups, img) continue - # Use a list here, because ArraySequence doesn't support item - # assignment: - select_sl = list(tg.streamlines[select_idx]) - roi_closest = bundle_roi_closest[select_idx, bundle_idx, :] - n_includes = len(bundle_dict.get_b_info(bundle).get("include", [])) - if clip_edges and n_includes > 1: - logger.info("Clipping Streamlines by ROI") - select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, (0, n_includes - 1), in_place=True - ) - - to_flip = bundle_to_flip[select_idx, bundle_idx] - b_def = dict(bundle_dict.get_b_info(bundle_name)) + b_def = r_bd.bundle_def if "bundlesection" in b_def: - for sb_name, sb_include_cuts in bundle_dict.get_b_info(bundle)[ - "bundlesection" - ].items(): + for sb_name, sb_include_cuts in b_def["bundlesection"].items(): bundlesection_select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, sb_include_cuts, in_place=False + r_bd.get_selected_sls(), + r_bd.roi_closest, + sb_include_cuts, + in_place=False, ) _add_bundle_to_fiber_group( sb_name, bundlesection_select_sl, - select_idx, - to_flip, + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, return_idx, 
fiber_groups, img, @@ -264,9 +253,15 @@ def recognize( _add_bundle_to_meta(sb_name, b_def, meta) else: _add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip, return_idx, fiber_groups, img + b_name, + r_bd.get_selected_sls(cut=clip_edges), + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, + return_idx, + fiber_groups, + img, ) - _add_bundle_to_meta(bundle, b_def, meta) + _add_bundle_to_meta(b_name, b_def, meta) return fiber_groups, meta @@ -278,10 +273,10 @@ def _return_empty(bundle_name, return_idx, fiber_groups, img): """ if return_idx: fiber_groups[bundle_name] = {} - fiber_groups[bundle_name]["sl"] = StatefulTractogram([], img, Space.VOX) + fiber_groups[bundle_name]["sl"] = StatefulTractogram([], img, Space.RASMM) fiber_groups[bundle_name]["idx"] = np.array([]) else: - fiber_groups[bundle_name] = StatefulTractogram([], img, Space.VOX) + fiber_groups[bundle_name] = StatefulTractogram([], img, Space.RASMM) def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip, return_idx, fiber_groups, img): @@ -290,7 +285,7 @@ def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip, return_idx, fiber_group """ sl = abu.flip_sls(sl, to_flip, in_place=False) - sl = StatefulTractogram(sl, img, Space.VOX) + sl = StatefulTractogram(sl, img, Space.RASMM) if return_idx: fiber_groups[b_name] = {} diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index d87062f5..4d26dbfc 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -2,59 +2,65 @@ from dipy.core.interpolation import interpolate_scalar_3d -def _interp3d(roi, sl): - return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] +def _interp_arr_with_affine(roi, arr, affine): + inv_affine = np.linalg.inv(affine) + lin_T = inv_affine[:3, :3].T + offset = inv_affine[:3, 3] + arr_xformd = np.dot(arr, lin_T) + offset + return np.array(interpolate_scalar_3d(roi, arr_xformd)[0]) -def check_sls_with_inclusion( - sls, include_rois, include_roi_tols): + +def check_sls_with_inclusion(sls, include_rois, 
include_roi_tols): inc_results = np.zeros(len(sls), dtype=tuple) - include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): closest = np.zeros(len(include_rois), dtype=np.int32) + dists = np.zeros(len(include_rois), dtype=np.float32) sl = np.asarray(sl) valid = True for ii, roi in enumerate(include_rois): - dist = interpolate_scalar_3d(roi, sl)[0] + dist = _interp_arr_with_affine(roi.get_fdata(), sl, roi.affine) closest[ii] = np.argmin(dist) + dists[ii] = dist[closest[ii]] if dist[closest[ii]] > include_roi_tols[ii]: # Too far from one of them: - inc_results[jj] = (False, []) + inc_results[jj] = (False, [], []) valid = False break # Checked all the ROIs and it was close to all of them if valid: - inc_results[jj] = (True, closest) + inc_results[jj] = (True, closest, dists) return inc_results -def check_sl_with_exclusion(sl, exclude_rois, - exclude_roi_tols): - """ Helper function to check that a streamline is not too close to a +def check_sl_with_exclusion(sl, exclude_rois, exclude_roi_tols): + """Helper function to check that a streamline is not too close to a list of exclusion ROIs. """ for ii, roi in enumerate(exclude_rois): # if any part of the streamline is near any exclusion ROI, # return False - if np.any(_interp3d(roi, sl) <= exclude_roi_tols[ii]): + if np.any( + _interp_arr_with_affine(roi.get_fdata(), sl, roi.affine) + <= exclude_roi_tols[ii] + ): return False # Either there are no exclusion ROIs, or you are not close to any: return True -def clean_by_endpoints(streamlines, target, target_idx, tol=0, - flip_sls=None, accepted_idxs=None): +def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): """ Clean a collection of streamlines based on an endpoint ROI. Filters down to only include items that have their start or end points close to the targets. 
Parameters ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. + fgarray : ndarray of shape (N, M, 3) + Where N is number of streamlines, M is number of nodes. + Assumed to be in RASMM space. target: Nifti1Image Nifti1Image containing a distance transform of the ROI. target_idx: int. @@ -67,24 +73,30 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0, the endpoint is exactly in the coordinate of the target ROI. flip_sls : 1d array, optional Length is len(streamlines), whether to flip the streamline. - accepted_idxs : 1d array, optional - Boolean array, where entries correspond to eachs streamline, - and streamlines that pass cleaning will be set to 1. Yields ------- boolean array of streamlines that survive cleaning. """ - if accepted_idxs is None: - accepted_idxs = np.zeros(len(streamlines), dtype=np.bool_) + if not isinstance(fgarray, np.ndarray): + raise ValueError( + ( + "fgarray must be a numpy ndarray, you can resample " + "your streamlines using resample_tg in AFQ.recognition.utils" + ) + ) + + n_sls, n_nodes, _ = fgarray.shape + + # handle target_idx negative values as wrapping around + effective_idx = target_idx if target_idx >= 0 else (n_nodes + target_idx) + indices = np.full(n_sls, effective_idx) - if flip_sls is None: - flip_sls = np.zeros(len(streamlines)) - flip_sls = flip_sls.astype(int) + if flip_sls is not None: + flipped_indices = n_nodes - 1 - effective_idx + indices = np.where(flip_sls.astype(bool), flipped_indices, indices) - for ii, sl in enumerate(streamlines): - this_idx = target_idx - if flip_sls[ii]: - this_idx = (len(sl) - this_idx - 1) % len(sl) - accepted_idxs[ii] = _interp3d(target, [sl[this_idx]])[0] <= tol + distances = _interp_arr_with_affine( + target.get_fdata(), fgarray[np.arange(n_sls), indices], target.affine + ) - return accepted_idxs + return distances <= tol diff --git a/AFQ/recognition/sparse_decisions.py 
b/AFQ/recognition/sparse_decisions.py new file mode 100644 index 00000000..586721d5 --- /dev/null +++ b/AFQ/recognition/sparse_decisions.py @@ -0,0 +1,116 @@ +import numpy as np +from scipy.sparse import csr_matrix + + +def compute_sparse_decisions(bundles_being_recognized, n_streamlines): + """ + Compute a sparse matrix of distances to ROIs for the streamlines that are + currently being recognized. This can be used to weight decisions by distance + to ROIs, without having to create a dense matrix of distances for all + streamlines and all bundles. + + Parameters + ---------- + bundles_being_recognized : dict + A dictionary of SlsBeingRecognized objects, keyed by bundle name. + n_streamlines : int + The total number of streamlines in the original tractogram. + + Returns + ------- + csr_matrix + A sparse matrix of shape (number of bundles being recognized, n_streamlines), + where the entry (i, j) is a score: + bundles with ROIs result in weights [2.0 to 3.0] with higher scores + for streamlines closer to ROIs + Non-ROI bundles result in weight 1.0 + Everything else is 0.0 (implicit in sparse matrices) + """ + rows, cols, data = [], [], [] + epsilon = 1e-6 + + global_max_dist = 0.0 + for b in bundles_being_recognized.values(): + if hasattr(b, "roi_dists"): + global_max_dist = max(global_max_dist, np.sum(b.roi_dists, axis=-1).max()) + + norm_factor = global_max_dist + 1.0 + + for b_idx, name in enumerate(bundles_being_recognized.keys()): + bundle = bundles_being_recognized[name] + indices = bundle.selected_fiber_idxs + + if hasattr(bundle, "roi_dists"): + dists = np.sum(bundle.roi_dists, axis=-1) + dists = np.maximum(dists, epsilon) + bundle_weights = dists / norm_factor + else: + bundle_weights = np.full(len(indices), 2.0, dtype=np.float32) + + rows.extend([b_idx] * len(indices)) + cols.extend(indices) + data.extend(bundle_weights) + + sparse_scores = csr_matrix( + (data, (rows, cols)), shape=(len(bundles_being_recognized), n_streamlines) + ) + + # Final Decision: 
3.0 - Score + # ROI bundles result in weights [2.0 to 3.0] + # No-ROI bundles result in weight 1.0 + sparse_scores.data = 3.0 - sparse_scores.data + + return sparse_scores + + +def get_conflict_count(sparse_scores): + """ + Count how many streamlines are being considered for more than one bundle + """ + sorted_indices = np.sort(sparse_scores.indices) + is_duplicate = np.diff(sorted_indices) == 0 + num_conflicts = np.sum(is_duplicate) + + return num_conflicts + + +def remove_conflicts(sparse_scores, bundles_being_recognized): + """ + Returns a dictionary of {bundle_name: np.array(accepted_indices)} + """ + coo = sparse_scores.tocoo() + + order = np.lexsort((-coo.data, coo.col)) + + mask = np.concatenate(([True], np.diff(coo.col[order]) != 0)) + winner_rows = coo.row[order][mask] + winner_cols = coo.col[order][mask] + + row_sort = np.argsort(winner_rows) + winner_rows = winner_rows[row_sort] + winner_cols = winner_cols[row_sort] + + num_bundles = len(bundles_being_recognized) + split_indices = np.searchsorted(winner_rows, np.arange(num_bundles + 1)) + + for i, b_name in enumerate(list(bundles_being_recognized.keys())): + b_sls = bundles_being_recognized[b_name] + if np.any(b_sls.selected_fiber_idxs[:-1] > b_sls.selected_fiber_idxs[1:]): + raise NotImplementedError( + f"Bundle '{b_name}' has unsorted selected_fiber_idxs. " + "The searchsorted optimization requires sorted indices." + "This is a bug in the implementation of the bundle " + "recognition procedure, please report it to the developers." 
+ ) + + accept_idx = b_sls.initiate_selection(f"{b_name} conflicts") + start, end = split_indices[i], split_indices[i + 1] + bundle_winners = winner_cols[start:end] + + if len(bundle_winners) > 0: + local_positions = np.searchsorted(b_sls.selected_fiber_idxs, bundle_winners) + accept_idx[local_positions] = True + b_sls.select(local_positions, "conflicts") + else: + b_sls.select(accept_idx, "conflicts") + bundles_being_recognized.pop(b_name) diff --git a/AFQ/recognition/tests/test_other_bundles.py b/AFQ/recognition/tests/test_other_bundles.py index 0724bf3e..f1a83b19 100644 --- a/AFQ/recognition/tests/test_other_bundles.py +++ b/AFQ/recognition/tests/test_other_bundles.py @@ -1,3 +1,4 @@ +import nibabel as nib import numpy as np import AFQ.recognition.other_bundles as abo @@ -9,7 +10,7 @@ other_bundle_sls_sample = np.array( [[[0, 1, 2], [1, 2, 3], [2, 2, 2]], [[1, 1, 1], [2, 2, 2], [3, 3, 3]]] ) -img_sample = np.zeros((5, 5, 5)) +img_sample = nib.Nifti1Image(np.zeros((5, 5, 5)), np.eye(4)) node_thresh_sample = 1 diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index ef513f14..5b24d470 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -9,9 +9,9 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.stats.analysis import afq_profile +import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd import AFQ.recognition.cleaning as abc -import AFQ.registration as reg from AFQ.recognition.recognize import recognize dpd.fetch_stanford_hardi() @@ -22,10 +22,9 @@ hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() reg_template = afd.read_mni_template() -mapping = reg.read_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) -streamlines = file_dict["tractography_subsampled.trk"] +mapping = file_dict["mapping"] +streamlines = file_dict["tractography_subsampled"] tg = 
StatefulTractogram(streamlines, hardi_img, Space.RASMM) -tg.to_vox() streamlines = tg.streamlines templates = afd.read_templates() cst_r_curve_ref = StatefulTractogram( @@ -63,9 +62,7 @@ def test_segment(): - fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 2 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 2) # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 2) @@ -75,7 +72,7 @@ def test_segment(): npt.assert_(len(CST_R_sl) > 0) # Calculate the tract profile for a volume of all-ones: tract_profile = afq_profile( - np.ones(nib.load(hardi_fdata).shape[:3]), CST_R_sl.streamlines, np.eye(4) + np.ones(hardi_img.shape[:3]), CST_R_sl.streamlines, hardi_img.affine ) npt.assert_almost_equal(tract_profile, np.ones(100)) @@ -83,6 +80,47 @@ def test_segment(): npt.assert_equal(len(clean_sl), len(CST_R_sl)) +def test_segment_mixed_roi(): + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + bundle_info = { + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed", + } + } + + with pytest.raises( + ValueError, + match=( + "When using mixed ROI bundle definitions, and subject space ROIs, " + "resample_subject_to cannot be False." + ), + ): + fiber_groups, _ = recognize( + tg, hardi_img, mapping, bundle_info, reg_template, 2 + ) + + bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) + fiber_groups, _ = recognize( + tg, + hardi_img, + mapping, + bundle_info, + reg_template, + 2, + dist_to_atlas=10, + ) + + # We asked for 2 fiber groups: + npt.assert_equal(len(fiber_groups), 1) + OR_LV1_sl = fiber_groups["OR LV1"] + npt.assert_(len(OR_LV1_sl) == 6) + + @pytest.mark.nightly def test_segment_no_prob(): # What if you don't have probability maps? 
@@ -98,7 +136,7 @@ def test_segment_no_prob(): } fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles_no_prob, reg_template, 1 + tg, hardi_img, mapping, bundles_no_prob, reg_template, 1 ) # This condition should still hold @@ -109,7 +147,7 @@ def test_segment_no_prob(): def test_segment_return_idx(): # Test with the return_idx kwarg set to True: fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1, return_idx=True + tg, hardi_img, mapping, bundles, reg_template, 1, return_idx=True ) npt.assert_equal(len(fiber_groups), 2) @@ -125,7 +163,7 @@ def test_segment_return_idx(): def test_segment_clip_edges_api(): # Test with the clip_edges kwarg set to True: fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1, clip_edges=True + tg, hardi_img, mapping, bundles, reg_template, 1, clip_edges=True ) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups["Right Corticospinal"]) > 0) @@ -134,7 +172,7 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ["CST_R", "CST_L"] + bundle_names = ["CCMid"] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) @@ -142,7 +180,7 @@ def test_segment_reco(): # Try recobundles method fiber_groups, _ = recognize( tg, - nib.load(hardi_fdata), + hardi_img, mapping, bundles_reco, reg_template, @@ -151,8 +189,8 @@ def test_segment_reco(): ) # This condition should still hold - npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_equal(len(fiber_groups), 1) + npt.assert_(len(fiber_groups["CCMid"]) > 0) def test_exclusion_ROI(): @@ -175,25 +213,19 @@ def test_exclusion_ROI(): hardi_img, Space.VOX, ) - fiber_groups, _ = recognize( - slf_tg, nib.load(hardi_fdata), mapping, slf_bundle, reg_template, 1 - ) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, 
slf_bundle, reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 2) slf_bundle["Left Superior Longitudinal"]["exclude"] = [templates["SLFt_roi2_L"]] - fiber_groups, _ = recognize( - slf_tg, nib.load(hardi_fdata), mapping, slf_bundle, reg_template, 1 - ) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) def test_segment_sampled_streamlines(): - fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 1) # Already using a subsampled tck # the Right Corticospinal has two streamlines and @@ -206,7 +238,7 @@ def test_segment_sampled_streamlines(): # sample and segment streamlines sampled_fiber_groups, _ = recognize( tg, - nib.load(hardi_fdata), + hardi_img, mapping, bundles, reg_template, diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index f5140660..67ef384e 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -4,6 +4,7 @@ from scipy.ndimage import distance_transform_edt import AFQ.recognition.roi as abr +import AFQ.recognition.utils as abu from AFQ.recognition.roi import check_sl_with_exclusion, check_sls_with_inclusion shape = (15, 15, 15) @@ -17,6 +18,7 @@ np.array([[1, 1, 1], [2, 1, 1], [3, 1, 1]]), np.array([[1, 1, 1], [2, 1, 1]]), ] +fgarray = np.array(abu.resample_tg(streamlines, 20)) roi1 = np.ones(shape, dtype=np.float32) roi1[1, 2, 3] = 0 @@ -43,15 +45,15 @@ def test_clean_by_endpoints(): - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1)) npt.assert_array_equal( np.logical_and(clean_idx_start, 
clean_idx_end), np.array([1, 1, 0, 0]) ) # If tol=1, the third streamline also gets included - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0, tol=1)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1, tol=1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0, tol=1)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1, tol=1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0]) ) @@ -63,22 +65,27 @@ def test_check_sls_with_inclusion(): assert result[0][0] is True assert np.allclose(result[0][1][0], 0) assert np.allclose(result[0][1][1], 2) + assert np.allclose(result[0][2][0], 0) + assert np.allclose(result[0][2][1], 0) assert result[1][0] is False def test_check_sl_with_inclusion_pass(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline1], include_rois, include_roi_tols )[0] assert result is True assert len(dists) == 2 + assert np.allclose(dist_idxs[0], 0) + assert np.allclose(dist_idxs[1], 2) def test_check_sl_with_inclusion_fail(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline2], include_rois, include_roi_tols )[0] assert result is False + assert dist_idxs == [] assert dists == [] diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index 59930cf0..c765ba27 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -16,9 +16,8 @@ hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz") hardi_img = nib.load(hardi_fdata) file_dict = afd.read_stanford_hardi_tractography() -streamlines = file_dict["tractography_subsampled.trk"] +streamlines = file_dict["tractography_subsampled"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) -tg.to_vox() streamlines = tg.streamlines @@ -83,21 +82,22 @@ def test_segment_clip_edges(): def 
test_segment_orientation(): cleaned_idx = abc.clean_by_orientation( - streamlines, primary_axis="P/A", affine=np.eye(4) + streamlines, + primary_axis="P/A", ) - npt.assert_equal(np.sum(cleaned_idx), 93) - cleaned_idx_tol = abc.clean_by_orientation( - streamlines, primary_axis="P/A", affine=np.eye(4), tol=50 - ) - npt.assert_(np.sum(cleaned_idx_tol) < np.sum(cleaned_idx)) + npt.assert_equal(np.sum(cleaned_idx), 79) cleaned_idx = abc.clean_by_orientation( - streamlines, primary_axis="I/S", affine=np.eye(4) + streamlines, + primary_axis="I/S", ) - cleaned_idx_tol = abc.clean_by_orientation( - streamlines, primary_axis="I/S", affine=np.eye(4), tol=33 + npt.assert_equal(np.sum(cleaned_idx), 58) + + cleaned_idx = abc.clean_by_orientation( + streamlines, + primary_axis="L/R", ) - npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) + npt.assert_equal(np.sum(cleaned_idx), 57) def test_clean_isolation_forest_basic(): diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index fd34aa18..217ae415 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -1,3 +1,4 @@ +import copy import logging import os.path as op from time import time @@ -9,11 +10,49 @@ from dipy.io.streamline import save_tractogram from dipy.tracking.distances import bundles_distances_mdf -from AFQ.definitions.mapping import ConformedFnirtMapping - logger = logging.getLogger("AFQ") +axes_dict = { + "L/R": 0, + "L": 0, + "R": 0, + "P/A": 1, + "P": 1, + "A": 1, + "I/S": 2, + "I": 2, + "S": 2, +} + + +def manual_orient_sls(fgarray): + """ + Helper function to manually orient streamlines by their endpoints, + according to LPI+ pyAFQ standard assuming streamlines are in RASMM + """ + endpoint_diff = fgarray[:, 0, :] - fgarray[:, -1, :] + primary_axis = np.argmax(np.abs(endpoint_diff), axis=1) + return endpoint_diff[np.arange(len(fgarray)), primary_axis] < 0 + + +def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): + # We need to calculate the size of a voxel, so we can 
transform + # from mm to voxel units: + R = img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + # Tolerance is set to the square of the distance to the corner + # because we are using the squared Euclidean distance in calls to + # `cdist` to make those calls faster. + if dist_to_waypoint is None: + tol = dts.dist_to_corner(img.affine) + else: + tol = dist_to_waypoint / vox_dim + dist_to_atlas = int(input_dist_to_atlas / vox_dim) + return tol, dist_to_atlas, vox_dim + + def flip_sls(select_sl, idx_to_flip, in_place=False): """ Helper function to flip streamlines @@ -91,47 +130,6 @@ def orient_by_streamline(sls, template_sl): return DM[:, 0] > DM[:, 1] -def move_streamlines(tg, to, mapping, img, save_intermediates=None): - """Move streamlines to or from template space. - - to : str - Either "template" or "subject". - mapping : ConformedMapping - Mapping to use to move streamlines. - img : Nifti1Image - Space to move streamlines to. - """ - tg_og_space = tg.space - if isinstance(mapping, ConformedFnirtMapping): - if to != "subject": - raise ValueError( - "Attempted to transform streamlines to template using " - "unsupported mapping. " - "Use something other than Fnirt." 
- ) - tg.to_vox() - moved_sl = [] - for sl in tg.streamlines: - moved_sl.append(mapping.transform_inverse_pts(sl)) - else: - tg.to_rasmm() - if to == "template": - volume = mapping.forward - else: - volume = mapping.backward - delta = dts.values_from_volume(volume, tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines([d + s for d, s in zip(delta, tg.streamlines)]) - moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) - if save_intermediates is not None: - save_tractogram( - moved_sft, - op.join(save_intermediates, f"sls_in_{to}.trk"), - bbox_valid_check=False, - ) - tg.to_space(tg_og_space) - return moved_sft - - def resample_tg(tg, n_points): # reformat for dipy's set_number_of_points if isinstance(tg, np.ndarray): @@ -169,14 +167,22 @@ def select(self, idx, clean_name, cut=False): self.sls_flipped = self.sls_flipped[idx] if hasattr(self, "roi_closest"): self.roi_closest = self.roi_closest[idx] + if hasattr(self, "roi_dists"): + self.roi_dists = self.roi_dists[idx] time_taken = time() - self.start_time self.logger.info( f"After filtering by {clean_name} (time: {time_taken}s), " f"{len(self)} streamlines remain." 
) - if self.save_intermediates is not None: + + # Only save intermediates after the 90% of the + # streamlines have been filtered out, + # otherwise its impractical + if self.save_intermediates is not None and len(self) < 0.1 * len(self.ref_sls): save_tractogram( - StatefulTractogram(self.get_selected_sls(cut=cut), self.ref, Space.VOX), + StatefulTractogram( + self.get_selected_sls(cut=cut), self.ref, Space.RASMM + ), op.join( self.save_intermediates, f"sls_after_{clean_name}_for_{self.b_name}.trk", @@ -212,3 +218,27 @@ def __bool__(self): def __len__(self): return len(self.selected_fiber_idxs) + + def copy(self, new_name, n_roi): + new_copy = copy.copy(self) + new_copy.b_name = new_name + if n_roi > 0: + if self.n_roi > 0: + raise NotImplementedError( + ( + "You cannot have includes in the original bundle and" + " subbundles; only one or the other." + ) + ) + else: + new_copy.n_roi = n_roi + + new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy() + new_copy.sls_flipped = self.sls_flipped.copy() + + if hasattr(self, "roi_closest"): + new_copy.roi_closest = self.roi_closest.copy() + if hasattr(self, "roi_dists"): + new_copy.roi_dists = self.roi_dists.copy() + + return new_copy diff --git a/AFQ/registration.py b/AFQ/registration.py index f63e784b..28b72799 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -4,11 +4,14 @@ import nibabel as nib import numpy as np -from dipy.align import syn_registration +from dipy.align.imaffine import AffineMap from dipy.align.imwarp import DiffeomorphicMap -from dipy.align.streamlinear import whole_brain_slr -__all__ = ["syn_register_dwi", "write_mapping", "read_mapping", "slr_registration"] +__all__ = [ + "read_affine_mapping", + "read_syn_mapping", + "read_old_mapping", +] def reduce_shape(shape): @@ -21,78 +24,55 @@ def reduce_shape(shape): return shape -def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs): +def read_syn_mapping(disp, codisp): """ - Register DWI data to a template. 
+ Read a syn registration mapping from a nifti file Parameters - ----------- - dwi : nifti image or str - Image containing DWI data, or full path to a nifti file with DWI. - gtab : GradientTable - The gradients associated with the DWI data - template : nifti image or str, optional + ---------- + disp : str or Nifti1Image + If string, file must be of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from subject to template - syn_kwargs : key-word arguments for :func:`syn_registration` + codisp : str or Nifti1Image + If string, file must be of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from template to subject Returns ------- - DiffeomorphicMap object + A :class:`DiffeomorphicMap` object """ - if template is None: - import AFQ.data.fetch as afd - - template = afd.read_mni_template() - if isinstance(template, str): - template = nib.load(template) - - template_data = template.get_fdata() - template_affine = template.affine - - if isinstance(dwi, str): - dwi = nib.load(dwi) - - dwi_affine = dwi.affine - dwi_data = dwi.get_fdata() - mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1) - warped_b0, mapping = syn_registration( - mean_b0, - template_data, - moving_affine=dwi_affine, - static_affine=template_affine, - **syn_kwargs, + if isinstance(disp, str): + disp = nib.load(disp) + + if isinstance(codisp, str): + codisp = nib.load(codisp) + + mapping = DiffeomorphicMap( + dim=3, + disp_shape=codisp.get_fdata().shape[:3], + disp_grid2world=None, + domain_shape=disp.get_fdata().shape[:3], + domain_grid2world=None, + codomain_shape=codisp.get_fdata().shape[:3], + codomain_grid2world=None, ) - return warped_b0, mapping - - -def write_mapping(mapping, fname): - """ - Write out a syn registration mapping to file + mapping.forward = disp.get_fdata().astype(np.float32) + mapping.backward = codisp.get_fdata().astype(np.float32) - Parameters - ---------- - mapping : a DiffeomorphicMap object derived from 
:func:`syn_registration` - fname : str - Full path to the nifti file storing the mapping - - """ - if isinstance(mapping, DiffeomorphicMap): - mapping_imap = np.array([mapping.forward.T, mapping.backward.T]).T - nib.save(nib.Nifti1Image(mapping_imap, mapping.codomain_world2grid), fname) - else: - np.save(fname, mapping.affine) + return mapping -def read_mapping(disp, domain_img, codomain_img, prealign=None): +def read_affine_mapping(affine, domain_img, codomain_img): """ Read a syn registration mapping from a nifti file Parameters ---------- - disp : str, Nifti1Image, or ndarray - If string, file must of an image or ndarray. - If image, contains the mapping displacement field in each voxel - Shape (x, y, z, 3, 2) + affine : str or ndarray + If string, file must be of an ndarray. If ndarray, contains affine transformation used for mapping domain_img : str or Nifti1Image @@ -101,13 +81,10 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): Returns ------- - A :class:`DiffeomorphicMap` object + A :class:`AffineMap` object """ - if isinstance(disp, str): - if "nii.gz" in disp: - disp = nib.load(disp) - else: - disp = np.load(disp) + if isinstance(affine, str): + affine = np.load(affine) if isinstance(domain_img, str): domain_img = nib.load(domain_img) @@ -115,79 +92,60 @@ if isinstance(codomain_img, str): codomain_img = nib.load(codomain_img) - if isinstance(disp, nib.Nifti1Image): - mapping = DiffeomorphicMap( - 3, - disp.shape[:3], - disp_grid2world=np.linalg.inv(disp.affine), - domain_shape=domain_img.shape[:3], - domain_grid2world=domain_img.affine, - codomain_shape=codomain_img.shape, - codomain_grid2world=codomain_img.affine, - prealign=prealign, - ) - - disp_data = disp.get_fdata().astype(np.float32) - mapping.forward = disp_data[..., 0] - mapping.backward = disp_data[..., 1] - mapping.is_inverse = True - else: - from AFQ.definitions.mapping import ConformedAffineMapping - - mapping = 
ConformedAffineMapping( - disp, - domain_grid_shape=reduce_shape(domain_img.shape), - domain_grid2world=domain_img.affine, - codomain_grid_shape=reduce_shape(codomain_img.shape), - codomain_grid2world=codomain_img.affine, - ) + mapping = AffineMap( + affine, + domain_grid_shape=reduce_shape(domain_img.shape), + domain_grid2world=domain_img.affine, + codomain_grid_shape=reduce_shape(codomain_img.shape), + codomain_grid2world=codomain_img.affine, + ) return mapping -def slr_registration( - moving_data, - static_data, - moving_affine=None, - static_affine=None, - moving_shape=None, - static_shape=None, - **kwargs, -): - """Register a source image (moving) to a target image (static). +def read_old_mapping(disp, domain_img, codomain_img, prealign=None): + """ + Warning: This is only used for pyAFQ tests and backwards compatibility. + Read old-style registration mapping from a nifti file. Parameters ---------- - moving : ndarray - The source tractography data to be registered - moving_affine : ndarray - The affine associated with the moving (source) data. - moving_shape : ndarray - The shape of the space associated with the static (target) data. - static : ndarray - The target tractography data for registration - static_affine : ndarray - The affine associated with the static (target) data. - static_shape : ndarray - The shape of the space associated with the static (target) data. - - **kwargs: - kwargs are passed into whole_brain_slr + disp : str or Nifti1Image + If string, file must be of an image or ndarray. 
+ If image, contains the mapping displacement field in each voxel + Shape (x, y, z, 3, 2) + + domain_img : str or Nifti1Image + + codomain_img : str or Nifti1Image Returns ------- - AffineMap + A :class:`DiffeomorphicMap` object """ - from AFQ.definitions.mapping import ConformedAffineMapping + if isinstance(disp, str): + disp = nib.load(disp) - _, transform, _, _ = whole_brain_slr( - static_data, moving_data, x0="affine", verbose=False, **kwargs - ) + if isinstance(domain_img, str): + domain_img = nib.load(domain_img) + + if isinstance(codomain_img, str): + codomain_img = nib.load(codomain_img) - return ConformedAffineMapping( - transform, - codomain_grid_shape=reduce_shape(static_shape), - codomain_grid2world=static_affine, - domain_grid_shape=reduce_shape(moving_shape), - domain_grid2world=moving_affine, + mapping = DiffeomorphicMap( + 3, + disp.shape[:3], + disp_grid2world=np.linalg.inv(disp.affine), + domain_shape=domain_img.shape[:3], + domain_grid2world=domain_img.affine, + codomain_shape=codomain_img.shape, + codomain_grid2world=codomain_img.affine, + prealign=prealign, ) + + disp_data = disp.get_fdata().astype(np.float32) + mapping.forward = disp_data[..., 0] + mapping.backward = disp_data[..., 1] + mapping.is_inverse = False + + return mapping diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index 66dfbfc4..1f2e1e35 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -1,5 +1,4 @@ import logging -import multiprocessing import dipy.core.gradients as dpg import dipy.reconst.dki as dpy_dki @@ -16,7 +15,6 @@ from dipy.reconst.dki_micro import axonal_water_fraction from dipy.reconst.gqi import GeneralizedQSamplingModel from dipy.reconst.rumba import RumbaSDModel -from numba import get_num_threads import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd @@ -100,40 +98,6 @@ def get_data_gtab( return data, gtab, img, img.affine -@immlib.calc("n_cpus", "n_threads", "low_mem") -def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, 
low_memory=False): - """ - Configure the number of CPUs to use for parallel processing with Ray, - the number of threads to use for Numba, - and whether to use low-memory versions of algorithms - where available - - Parameters - ---------- - ray_n_cpus : int, optional - The number of CPUs to use for parallel processing with Ray. - If None, uses the number of available CPUs minus one. - Tractography, Recognition, and MSMT use Ray. - Default: None - numba_n_threads : int, optional - The number of threads to use for Numba. - If None, uses the number of available CPUs minus one, - but with a maximum of 16. - ASYM fit uses Numba. - Default: None - low_memory : bool, optional - Whether to use low-memory versions of algorithms - where available. - Default: False - """ - if ray_n_cpus is None: - ray_n_cpus = max(multiprocessing.cpu_count() - 1, 1) - if numba_n_threads is None: - numba_n_threads = min(max(get_num_threads() - 1, 1), 16) - - return ray_n_cpus, numba_n_threads, low_memory - - @immlib.calc("b0") @as_file("_b0ref.nii.gz") @as_img @@ -519,7 +483,7 @@ def csd_params( @immlib.calc("csd_aodf_params") @as_file(suffix="_model-csd_param-aodf_dwimap.nii.gz", subfolder="models") @as_img -def csd_aodf(csd_params, n_threads, low_mem, citations): +def csd_aodf(structural_imap, csd_params, citations): """ full path to a nifti file containing SSST CSD ODFs filtered by unified filtering [1] @@ -536,7 +500,10 @@ def csd_aodf(csd_params, n_threads, low_mem, citations): logger.info("Applying unified filtering to generate asymmetric CSD ODFs...") aodf = unified_filtering( - sh_coeff, get_sphere(name="repulsion724"), n_threads=n_threads, low_mem=low_mem + sh_coeff, + get_sphere(name="repulsion724"), + n_threads=structural_imap["n_threads"], + low_mem=structural_imap["low_mem"], ) return aodf, dict( @@ -1132,6 +1099,17 @@ def dki_lt(dki_tf, dwi_affine): return dki_lt_dict +@immlib.calc("dki_cfa") +@as_file(suffix="_model-kurtosis_param-cfa_dwimap.nii.gz", subfolder="models") 
+@as_fit_deriv("DKI") +def dki_cfa(dki_tf): + """ + full path to a nifti file containing + the DKI color fractional anisotropy + """ + return dki_tf.color_fa, {"Description": "Color Fractional Anisotropy"} + + @immlib.calc("dki_fa") @as_file("_model-kurtosis_param-fa_dwimap.nii.gz", subfolder="models") @as_fit_deriv("DKI") @@ -1403,7 +1381,6 @@ def get_data_plan(kwargs): b0, b0_mask, brain_mask, - configure_ncpus_nthreads, dti_fit, dki_fit, fwdti_fit, @@ -1438,6 +1415,7 @@ def get_data_plan(kwargs): dki_awf, dki_mk, dki_kfa, + dki_cfa, dki_ga, dki_rd, dti_ga, diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index bb1d750c..689e1426 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -26,7 +26,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) def get_new_signature(og_func, needed_args): @@ -164,13 +163,21 @@ def as_fit_deriv(tf_name): """ def _as_fit_deriv(func): + module_name = func.__module__ + is_data_module = "data" in module_name + dependency = "dwi_affine" if is_data_module else "data_imap" + new_signature, new_params = get_new_signature( - func, ["dwi_affine", f"{tf_name.lower()}_params"] + func, [dependency, f"{tf_name.lower()}_params"] ) @functools.wraps(func) def wrapper_as_fit_deriv(*args, **kwargs): - dwi_affine = get_param(kwargs, new_params, "dwi_affine") + if is_data_module: + dwi_affine = get_param(kwargs, new_params, "dwi_affine") + else: + data_imap = get_param(kwargs, new_params, "data_imap") + dwi_affine = data_imap["dwi_affine"] params = get_param(kwargs, new_params, f"{tf_name.lower()}_params") params_meta = read_json(drop_extension(params) + ".json") img_meta = {} @@ -203,11 +210,19 @@ def as_img(func): Decorator to convert function output (ndarray, meta) into (Nifti1Image, meta). Supports functions returning a single tuple or a list of tuples. 
""" - new_signature, new_params = get_new_signature(func, ["dwi_affine"]) + module_name = func.__module__ + is_data_module = "data" in module_name + dependency = "dwi_affine" if is_data_module else "data_imap" + + new_signature, new_params = get_new_signature(func, [dependency]) @functools.wraps(func) def wrapper_as_img(*args, **kwargs): - dwi_affine = get_param(kwargs, new_params, "dwi_affine") + if is_data_module: + dwi_affine = get_param(kwargs, new_params, "dwi_affine") + else: + data_imap = get_param(kwargs, new_params, "data_imap") + dwi_affine = data_imap["dwi_affine"] start_time = time() results = func(*args, **kwargs) diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 33006030..4dc62973 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -30,7 +30,7 @@ def export_registered_b0(base_fname, data_imap, mapping): ) if not op.exists(warped_b0_fname): mean_b0 = nib.load(data_imap["b0"]).get_fdata() - warped_b0 = mapping.transform(mean_b0) + warped_b0 = mapping.transform_inverse(mean_b0) warped_b0 = nib.Nifti1Image(warped_b0, data_imap["reg_template"].affine) logger.info(f"Saving {warped_b0_fname}") nib.save(warped_b0, warped_b0_fname) @@ -54,9 +54,7 @@ def template_xform(base_fname, dwi_data_file, data_imap, mapping): base_fname, f"_space-{subject_space}_desc-template_anat.nii.gz" ) if not op.exists(template_xform_fname): - template_xform = mapping.transform_inverse( - data_imap["reg_template"].get_fdata() - ) + template_xform = mapping.transform(data_imap["reg_template"].get_fdata()) template_xform = nib.Nifti1Image(template_xform, data_imap["dwi_affine"]) logger.info(f"Saving {template_xform_fname}") nib.save(template_xform, template_xform_fname) @@ -85,7 +83,7 @@ def export_rois(base_fname, output_dir, dwi_data_file, data_imap, mapping): *bundle_dict.transform_rois( bundle_name, mapping, - data_imap["dwi_affine"], + data_imap["dwi"], base_fname=base_roi_fname, to_space=to_space, ) @@ -175,7 +173,7 @@ def sls_mapping( ) streamlines_file 
= tractography_imap["streamlines"] tg = load_tractogram( - streamlines_file, reg_subject, Space.VOX, bbox_valid_check=False + streamlines_file, reg_subject, Space.RASMM, bbox_valid_check=False ) tg.to_rasmm() diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 2781778e..5e577082 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -32,6 +32,7 @@ from dipy.io.streamline import load_tractogram, save_tractogram from dipy.stats.analysis import afq_profile from dipy.tracking.streamline import set_number_of_points, values_from_volume +from dipy.tracking.utils import length from nibabel.affines import voxel_sizes from nibabel.orientations import aff2axcodes @@ -40,7 +41,9 @@ @immlib.calc("bundles") @as_file("_desc-bundles_tractography") -def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): +def segment( + structural_imap, data_imap, mapping_imap, tractography_imap, segmentation_params +): """ full path to a trk/trx file containing containing segmented streamlines, labeled by bundle @@ -59,9 +62,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): or streamlines.endswith(".tck") or streamlines.endswith(".vtk") ): - tg = load_tractogram( - streamlines, data_imap["dwi"], Space.VOX, bbox_valid_check=False - ) + tg = load_tractogram(streamlines, data_imap["dwi"], bbox_valid_check=False) is_trx = False elif streamlines.endswith(".trx"): is_trx = True @@ -85,9 +86,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): with open(temp_tck, "wb") as f_out: shutil.copyfileobj(f_in, f_out) # initialize stateful tractogram from tck file: - tg = load_tractogram( - temp_tck, data_imap["dwi"], Space.VOX, bbox_valid_check=False - ) + tg = load_tractogram(temp_tck, data_imap["dwi"], bbox_valid_check=False) is_trx = False if len(tg.streamlines) == 0: raise ValueError( @@ -109,17 +108,17 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): 
mapping_imap["mapping"], bundle_dict, reg_template, - data_imap["n_cpus"], + structural_imap["n_cpus"], **segmentation_params, ) - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) if len(seg_sft.sft) < 1: raise ValueError("Fatal: No bundles recognized.") if is_trx: - seg_sft.sft.dtype_dict = {"positions": np.float16, "offsets": np.uint32} + seg_sft.sft.dtype_dict = {"positions": np.float32, "offsets": np.uint32} tgram = TrxFile.from_sft(seg_sft.sft) tgram.groups = seg_sft.bundle_idxs @@ -175,7 +174,7 @@ def export_bundles(base_fname, output_dir, bundles, tracking_params): logger.info(f"Saving {fname}") if is_trx: seg_sft.sft.dtype_dict = { - "positions": np.float16, + "positions": np.float32, "offsets": np.uint32, } trxfile = TrxFile.from_sft(bundle_sft) @@ -210,26 +209,38 @@ def export_sl_counts(bundles): return counts_df, dict(source=bundles) -@immlib.calc("median_bundle_lengths") +@immlib.calc("bundle_lengths") @as_file("_desc-medianBundleLengths_tractography.csv", subfolder="stats") def export_bundle_lengths(bundles): """ full path to a JSON file containing median bundle lengths """ - med_len_counts = [] + len_data = {} seg_sft = aus.SegmentedSFT.fromfile(bundles) for bundle in seg_sft.bundle_names: - these_lengths = seg_sft.get_bundle(bundle)._tractogram._streamlines._lengths + these_lengths = list(length(seg_sft.get_bundle(bundle).streamlines)) if len(these_lengths) > 0: - med_len_counts.append(np.median(these_lengths)) + len_data[f"{bundle} Median"] = np.median(these_lengths) + len_data[f"{bundle} Min"] = np.min(these_lengths) + len_data[f"{bundle} Max"] = np.max(these_lengths) else: - med_len_counts.append(0) - med_len_counts.append(np.median(seg_sft.sft._tractogram._streamlines._lengths)) + len_data[f"{bundle} Median"] = 0 + len_data[f"{bundle} Min"] = 0 + len_data[f"{bundle} Max"] = 0 + len_data["Total Recognized Median"] = np.median( + seg_sft.sft._tractogram._streamlines._lengths + ) + len_data["Total Recognized Min"] 
= np.min( + seg_sft.sft._tractogram._streamlines._lengths + ) + len_data["Total Recognized Max"] = np.max( + seg_sft.sft._tractogram._streamlines._lengths + ) counts_df = pd.DataFrame( - data=dict(median_len=med_len_counts), - index=seg_sft.bundle_names + ["Total Recognized"], + data=len_data, + index=[0], ) return counts_df, dict(source=bundles) @@ -258,7 +269,7 @@ def export_density_maps(bundles, data_imap): @immlib.calc("profiles") @as_file("_desc-profiles_tractography.csv") def tract_profiles( - bundles, scalar_dict, dwi_affine, profile_weights="gauss", n_points_profile=100 + bundles, scalar_dict, data_imap, profile_weights="gauss", n_points_profile=100 ): """ full path to a CSV file containing tract profiles @@ -327,7 +338,11 @@ def tract_profiles( def _median_weight(bundle): fgarray = set_number_of_points(bundle, n_points_profile) values = np.array( - values_from_volume(scalar_data, fgarray, dwi_affine) # noqa B023 + values_from_volume( + scalar_data, # noqa B023 + fgarray, + data_imap["dwi_affine"], + ) ) weights = np.zeros(values.shape) for ii, jj in enumerate( @@ -355,7 +370,7 @@ def _median_weight(bundle): this_profile[ii] = afq_profile( scalar_data, this_sl, - dwi_affine, + data_imap["dwi_affine"], weights=this_prof_weights, n_points=n_points_profile, ) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 302d3984..3a563555 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -1,7 +1,9 @@ import logging +import multiprocessing import immlib import nibabel as nib +from numba import get_num_threads from AFQ.definitions.utils import Definition from AFQ.nn.brainchop import run_brainchop @@ -13,9 +15,93 @@ logger = logging.getLogger("AFQ") +@immlib.calc("n_cpus", "n_threads", "low_mem") +def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=False): + """ + Configure the number of CPUs to use for parallel processing with Ray, + the number of threads to use for Numba, + and whether to use low-memory 
versions of algorithms + where available + + Parameters + ---------- + ray_n_cpus : int, optional + The number of CPUs to use for parallel processing with Ray. + If None, uses the number of available CPUs minus one. + Tractography and MSMT use Ray. + Default: None + numba_n_threads : int, optional + The number of threads to use for Numba. + If None, uses the number of available CPUs minus one, + but with a maximum of 16. + ASYM fit uses Numba. + Default: None + low_memory : bool, optional + Whether to use low-memory versions of algorithms + where available. + Default: False + """ + if ray_n_cpus is None: + ray_n_cpus = max(multiprocessing.cpu_count() - 1, 1) + if numba_n_threads is None: + numba_n_threads = min(max(get_num_threads() - 1, 1), 16) + + return ray_n_cpus, numba_n_threads, low_memory + + +@immlib.calc("onnx_kwargs") +def onnx_kwargs( + low_mem, onnx_execution_provider="CPUExecutionProvider", onnx_inter_threads=1 +): + """ + The execution provider to use for onnx models + + Parameters + ---------- + onnx_execution_provider : str, optional + The execution provider to use for onnx models. + By default this is set to CPUExecutionProvider + which should work on all systems. If you have a + compatible GPU and the appropriate onnxruntime installed + you can set this to "CUDAExecutionProvider" or + "OpenVINOExecutionProvider" for potentially faster + inference. + Default: "CPUExecutionProvider" + onnx_inter_threads : int, optional + The number of inter threads to use for onnx models. + Increasing will increase memory usage significantly. + Default: 1 + + """ + try: + import onnxruntime as ort + except ImportError: + # In this case, we can throw a more informative error + # when the user tries to run a model + # that requires onnxruntime + return onnx_execution_provider + if onnx_execution_provider not in ort.get_available_providers(): + logger.warning( + f"{onnx_execution_provider} is not available. 
" + f"Available providers are: {ort.get_available_providers()}. " + "Falling back to CPUExecutionProvider." + ) + onnx_execution_provider = "CPUExecutionProvider" + options = ort.SessionOptions() + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + options.inter_op_num_threads = onnx_inter_threads + if low_mem: + options.enable_cpu_mem_arena = False + options.enable_mem_pattern = False + + onnx_kwargs = {"providers": [onnx_execution_provider], "sess_options": options} + + return {"onnx_kwargs": onnx_kwargs} + + @immlib.calc("synthseg_model") @as_file(suffix="_model-synthseg2_probseg.nii.gz", subfolder="nn") -def synthseg_model(t1_masked, citations): +def synthseg_model(t1_masked, citations, onnx_kwargs): """ full path to the synthseg2 model segmentations @@ -36,13 +122,13 @@ def synthseg_model(t1_masked, citations): "Or, provide your own segmentations using PVEImage or PVEImages.", ) t1_img = nib.load(t1_masked) - predictions = run_synthseg(ort, t1_img, "synthseg2") + predictions = run_synthseg(ort, t1_img, "synthseg2", onnx_kwargs) return predictions, dict(T1w=t1_masked) @immlib.calc("mx_model") @as_file(suffix="_model-multiaxial_probseg.nii.gz", subfolder="nn") -def mx_model(t1_file, t1w_brain_mask, citations): +def mx_model(t1_file, t1w_brain_mask, citations, onnx_kwargs): """ full path to the multi-axial model for brain extraction outputs @@ -59,7 +145,7 @@ def mx_model(t1_file, t1w_brain_mask, citations): ) t1_img = nib.load(t1_file) t1_mask = nib.load(t1w_brain_mask) - predictions = run_multiaxial(ort, t1_img) + predictions = run_multiaxial(ort, t1_img, onnx_kwargs) predictions = nib.Nifti1Image( predictions.get_fdata() * t1_mask.get_fdata(), t1_img.affine ) @@ -68,7 +154,7 @@ def mx_model(t1_file, t1w_brain_mask, citations): @immlib.calc("t1w_brain_mask") @as_file(suffix="_desc-T1w_mask.nii.gz") -def t1w_brain_mask(t1_file, citations, brain_mask_definition=None): +def t1w_brain_mask(t1_file, citations, onnx_kwargs, brain_mask_definition=None): """ 
full path to a nifti file containing brain mask from T1w image @@ -98,7 +184,7 @@ def t1w_brain_mask(t1_file, citations, brain_mask_definition=None): ort = check_onnxruntime( "Mindgrab", "Or, provide your own brain mask using brain_mask_definition." ) - return run_brainchop(ort, nib.load(t1_file), "mindgrab"), dict( + return run_brainchop(ort, nib.load(t1_file), "mindgrab", onnx_kwargs), dict( T1w=t1_file, model="mindgrab" ) @@ -119,7 +205,7 @@ def t1_masked(t1_file, t1w_brain_mask): @immlib.calc("t1_subcortex") @as_file(suffix="_desc-subcortex_probseg.nii.gz", subfolder="nn") -def t1_subcortex(t1_masked, citations): +def t1_subcortex(t1_masked, citations, onnx_kwargs): """ full path to a nifti file containing segmentation of subcortical structures from T1w image using Brainchop @@ -140,7 +226,7 @@ def t1_subcortex(t1_masked, citations): t1_img_masked = nib.load(t1_masked) - subcortical_img = run_brainchop(ort, t1_img_masked, "subcortical") + subcortical_img = run_brainchop(ort, t1_img_masked, "subcortical", onnx_kwargs) meta = dict( T1w=t1_masked, @@ -172,7 +258,15 @@ def t1_subcortex(t1_masked, citations): def get_structural_plan(kwargs): structural_tasks = with_name( - [mx_model, synthseg_model, t1w_brain_mask, t1_subcortex, t1_masked] + [ + mx_model, + synthseg_model, + t1w_brain_mask, + t1_subcortex, + t1_masked, + onnx_kwargs, + configure_ncpus_nthreads, + ] ) bm_def = kwargs.get("brain_mask_definition", None) diff --git a/AFQ/tasks/tissue.py b/AFQ/tasks/tissue.py index c890fa7d..3b3b5bd3 100644 --- a/AFQ/tasks/tissue.py +++ b/AFQ/tasks/tissue.py @@ -127,7 +127,14 @@ def pve_internal(structural_imap, pve="synthseg"): subfolder="models", ) @as_img -def msmt_params(data_imap, pve_internal, citations, msmt_sh_order=8, msmt_fa_thr=0.7): +def msmt_params( + structural_imap, + data_imap, + pve_internal, + citations, + msmt_sh_order=8, + msmt_fa_thr=0.7, +): """ full path to a nifti file containing parameters for the MSMT CSD white matter fit, @@ -201,7 +208,7 @@ 
def msmt_params(data_imap, pve_internal, citations, msmt_sh_order=8, msmt_fa_thr mcsd_model = MultiShellDeconvModel(data_imap["gtab"], response_mcsd) logger.info("Fitting Multi-Shell CSD model...") - mcsd_fit = mcsd_model.fit(data_imap["data"], mask, n_cpus=data_imap["n_cpus"]) + mcsd_fit = mcsd_model.fit(data_imap["data"], mask, n_cpus=structural_imap["n_cpus"]) def _get_meta(desc, sh_order, response): return dict( @@ -268,7 +275,7 @@ def msmt_apm(msmtcsd_params): @immlib.calc("msmt_aodf_params") @as_file(suffix="_model-msmtcsd_param-aodf_dwimap.nii.gz", subfolder="models") @as_img -def msmt_aodf(msmtcsd_params, data_imap, citations): +def msmt_aodf(msmtcsd_params, structural_imap, citations): """ full path to a nifti file containing MSMT CSD ODFs filtered by unified filtering [1] @@ -288,8 +295,8 @@ def msmt_aodf(msmtcsd_params, data_imap, citations): aodf = unified_filtering( sh_coeff, get_sphere(name="repulsion724"), - n_threads=data_imap["n_threads"], - low_mem=data_imap["low_mem"], + n_threads=structural_imap["n_threads"], + low_mem=structural_imap["low_mem"], ) return aodf, dict( diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 58e43619..21e9c64f 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -58,9 +58,27 @@ def _meta_from_tracking_params(tracking_params, start_time, n_streamlines, seed, return meta +def _fiber_odf(data_imap, tissue_imap, tracking_params): + odf_model = tracking_params["odf_model"] + if isinstance(odf_model, str): + calc_name = f"{odf_model.lower()}_params" + if calc_name in data_imap: + params_file = data_imap[calc_name] + elif calc_name in tissue_imap: + params_file = tissue_imap[calc_name] + else: + raise ValueError((f"Could not find {odf_model}")) + else: + raise TypeError(("odf_model must be a string or Definition")) + + return params_file + + @immlib.calc("streamlines") @as_file("_tractography", subfolder="tractography") -def streamlines(data_imap, seed, tissue_imap, fodf, citations, 
tracking_params): +def streamlines( + structural_imap, data_imap, seed, tissue_imap, citations, tracking_params +): """ full path to the complete, unsegmented tractography file @@ -77,6 +95,7 @@ def streamlines(data_imap, seed, tissue_imap, fodf, citations, tracking_params): citations.add("smith2012anatomically") this_tracking_params = tracking_params.copy() + fodf = _fiber_odf(data_imap, tissue_imap, tracking_params) # get masks this_tracking_params["seed_mask"] = nib.load(seed).get_fdata() @@ -84,11 +103,11 @@ def streamlines(data_imap, seed, tissue_imap, fodf, citations, tracking_params): is_trx = this_tracking_params.get("trx", False) - num_chunks = data_imap["n_cpus"] + num_chunks = structural_imap["n_cpus"] if is_trx: start_time = time() - dtype_dict = {"positions": np.float16, "offsets": np.uint32} + dtype_dict = {"positions": np.float32, "offsets": np.uint32} if num_chunks and num_chunks > 1: if not has_ray: raise ImportError( @@ -206,26 +225,6 @@ def delete_lazyt(self, id): ) -@immlib.calc("fodf") -def fiber_odf(data_imap, tissue_imap, tracking_params): - """ - Nifti Image containing the fiber orientation distribution function - """ - odf_model = tracking_params["odf_model"] - if isinstance(odf_model, str): - calc_name = f"{odf_model.lower()}_params" - if calc_name in data_imap: - params_file = data_imap[calc_name] - elif calc_name in tissue_imap: - params_file = tissue_imap[calc_name] - else: - raise ValueError((f"Could not find {odf_model}")) - else: - raise TypeError(("odf_model must be a string or Definition")) - - return params_file - - @immlib.calc("streamlines") def custom_tractography(import_tract=None): """ @@ -249,11 +248,11 @@ def custom_tractography(import_tract=None): def gpu_tractography( data_imap, tracking_params, - fodf, seed, tissue_imap, tractography_ngpus=0, - chunk_size=100000, + gpu_backend="auto", + chunk_size=25000, ): """ full path to the complete, unsegmented tractography file @@ -267,11 +266,17 @@ def gpu_tractography( PTT, Prob 
can be used with any SHM model. Bootstrapped can be done with CSA/OPDT. Default: 0 + gpu_backend : str, optional + GPU backend to use for tractography. + One of {"auto", "cuda", "metal", "webgpu"}. + Default: "auto" chunk_size : int, optional Chunk size for GPU tracking. - Default: 100000 + Default: 25000 """ start_time = time() + fodf = _fiber_odf(data_imap, tissue_imap, tracking_params) + if tracking_params["directions"] == "boot": data = data_imap["data"] else: @@ -297,12 +302,15 @@ def gpu_tractography( tracking_params["thresholds_as_percentages"], tracking_params["max_angle"], tracking_params["step_size"], + tracking_params["minlen"], + tracking_params["maxlen"], tracking_params["n_seeds"], tracking_params["random_seeds"], tracking_params["rng_seed"], tracking_params["trx"], tractography_ngpus, chunk_size, + gpu_backend, ) return sft, _meta_from_tracking_params(tracking_params, start_time, sft, seed, pve) @@ -312,7 +320,7 @@ def get_tractography_plan(kwargs): if "tracking_params" in kwargs and not isinstance(kwargs["tracking_params"], dict): raise TypeError("tracking_params a dict") - tractography_tasks = with_name([streamlines, fiber_odf]) + tractography_tasks = with_name([streamlines]) # use GPU accelerated tractography if asked for if "tractography_ngpus" in kwargs and kwargs["tractography_ngpus"] != 0: diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 128fd157..3cafc325 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -29,7 +29,12 @@ def get_base_fname(output_dir, dwi_data_file): key = key_val_pair.split("-")[0] if key not in used_key_list: fname = fname + key_val_pair + "_" - fname = fname[:-1] + if fname[-1] == "_": + fname = fname[:-1] + else: + # if no key value pairs found, + # have some default base file name + fname = fname + "subject" return fname @@ -55,7 +60,11 @@ def check_onnxruntime(model_name, alternative_text): f"When we tried to import onnxruntime, we got the " f"following error:\n{e}\n" f"Please install onnxruntime to 
use this feature, " - f"by doing `pip install onnxruntime` or `pip install pyAFQ[nn]`. " + f"by doing `pip install onnxruntime` or " + f"`pip install onnxruntime-gpu` if you have a compatible GPU or " + f"`pip install pyAFQ[nn]` or " + "`pip install pyAFQ[gpu]` if you have a compatible GPU " + "and want GPUStreamlines. " f"{alternative_text}\n" "If there are still issues, post an issue on " "https://github.com/tractometry/pyAFQ/issues" diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 145566c6..cf1a8909 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -21,7 +21,7 @@ logger = logging.getLogger("AFQ") -def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): +def _viz_prepare_vol(vol, scalar_dict, ref): if vol in scalar_dict.keys(): vol = scalar_dict[vol] @@ -31,8 +31,6 @@ def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): vol = resample(vol, ref) vol = vol.get_fdata() - if xform: - vol = mapping.transform_inverse(vol) vol[np.isnan(vol)] = 0 return vol @@ -81,15 +79,12 @@ def viz_bundles( """ if sbv_lims_bundles is None: sbv_lims_bundles = [None, None] - mapping = mapping_imap["mapping"] scalar_dict = segmentation_imap["scalar_dict"] profiles_file = segmentation_imap["profiles"] t1_img = nib.load(structural_imap["t1_masked"]) shade_by_volume = get_tp(best_scalar, structural_imap, data_imap, tissue_imap) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, t1_img - ) - volume = _viz_prepare_vol(t1_img, False, mapping, scalar_dict, t1_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, t1_img) + volume = _viz_prepare_vol(t1_img, scalar_dict, t1_img) flip_axes = [False, False, False] for i in range(3): @@ -183,7 +178,6 @@ def viz_indivBundle( """ if sbv_lims_indiv is None: sbv_lims_indiv = [None, None] - mapping = mapping_imap["mapping"] bundle_dict = data_imap["bundle_dict"] scalar_dict = segmentation_imap["scalar_dict"] volume_img = nib.load(structural_imap["t1_masked"]) @@ -191,10 +185,8 
@@ def viz_indivBundle( profiles = pd.read_csv(segmentation_imap["profiles"]) start_time = time() - volume = _viz_prepare_vol(volume_img, False, mapping, scalar_dict, volume_img) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, volume_img - ) + volume = _viz_prepare_vol(volume_img, scalar_dict, volume_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, volume_img) flip_axes = [False, False, False] for i in range(3): @@ -211,6 +203,9 @@ def viz_indivBundle( if "bundlesection" in b_info: for sb_name in b_info["bundlesection"]: segmented_bname_to_roi_bname[sb_name] = b_name + elif "ORG_spectral_subbundles" in b_info: + for sb_name in b_info["ORG_spectral_subbundles"]: + segmented_bname_to_roi_bname[sb_name] = b_name else: segmented_bname_to_roi_bname[b_name] = b_name diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index c5bfdc8c..e6b6577f 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -697,7 +697,8 @@ def test_AFQ_data_waypoint(): tmpdir, bids_path, _ = get_temp_hardi() t1_path = op.join(tmpdir, "T1.nii.gz") t1_path_other = op.join(tmpdir, "T1-untransformed.nii.gz") - nib.save(afd.read_mni_template(mask=True, weight="T1w"), t1_path) + reg_template = afd.read_mni_template(mask=True, weight="T1w") + nib.save(reg_template, t1_path) shutil.copy(t1_path, t1_path_other) vista_folder = op.join(bids_path, "derivatives/vistasoft/sub-01/ses-01/dwi") @@ -775,8 +776,8 @@ def test_AFQ_data_waypoint(): # Replace the mapping and streamlines with precomputed: file_dict = afd.read_stanford_hardi_tractography() - mapping = file_dict["mapping.nii.gz"] - streamlines = file_dict["tractography_subsampled.trk"] + mapping = file_dict["mapping"] + streamlines = file_dict["tractography_subsampled"] dwi_affine = myafq.export("dwi_affine") streamlines = dts.Streamlines( dtu.transform_tracking_output( @@ -784,16 +785,23 @@ def test_AFQ_data_waypoint(): ) ) - mapping_file = op.join( + mapping_file_forward = 
op.join( myafq.export("output_dir"), "sub-01_ses-01_desc-mapping_from-subject_to-mni_xform.nii.gz", ) - nib.save(mapping, mapping_file) - reg_prealign_file = op.join( + nib.save( + nib.Nifti1Image(mapping.forward, dwi_affine), + mapping_file_forward, + ) + + mapping_file_backward = op.join( myafq.export("output_dir"), - "sub-01_ses-01_desc-prealign_from-subject_to-mni_xform.npy", + "sub-01_ses-01_desc-mapping_from-mni_to-subject_xform.nii.gz", + ) + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, ) - np.save(reg_prealign_file, np.eye(4)) # Test ROI exporting: myafq.export("rois") @@ -814,7 +822,7 @@ def test_AFQ_data_waypoint(): op.join( myafq.export("output_dir"), "bundles", - "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trk", + "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trx", ) ) # noqa @@ -831,7 +839,7 @@ def test_AFQ_data_waypoint(): all_sl = load_tractogram( op.join( - myafq.export("output_dir"), "tractography", "sub-01_ses-01_tractography.trk" + myafq.export("output_dir"), "tractography", "sub-01_ses-01_tractography.trx" ), reference="same", ).streamlines @@ -844,7 +852,7 @@ def test_AFQ_data_waypoint(): op.join( myafq.export("output_dir"), "bundles", - "sub-01_ses-01_desc-LeftArcuate_tractography.trk", + "sub-01_ses-01_desc-LeftArcuate_tractography.trx", ), reference="same", ).streamlines @@ -972,6 +980,6 @@ def test_AFQ_data_waypoint(): op.join( output_dir, "bundles", - "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trk", + "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trx", ) ) # noqa diff --git a/AFQ/tests/test_bundle_dict.py b/AFQ/tests/test_bundle_dict.py index 5d9d57c5..acff27ce 100644 --- a/AFQ/tests/test_bundle_dict.py +++ b/AFQ/tests/test_bundle_dict.py @@ -20,7 +20,7 @@ def test_BundleDict(): # test defaults afq_bundles = abd.default_bd() - assert len(afq_bundles) == 18 + assert len(afq_bundles) == 20 # Arcuate Fasciculus afq_bundles = 
abd.default_bd()["Left Arcuate", "Right Arcuate"] diff --git a/AFQ/tests/test_fixes.py b/AFQ/tests/test_fixes.py index 3e9dc0d2..437cb8ef 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -34,7 +34,7 @@ def test_GQI_fix(): def test_gaussian_weights(): file_dict = afd.read_stanford_hardi_tractography() - streamlines = file_dict["tractography_subsampled.trk"] + streamlines = file_dict["tractography_subsampled"] assert not np.any(np.isnan(gaussian_weights(streamlines[76:92]))) diff --git a/AFQ/tests/test_nn.py b/AFQ/tests/test_nn.py index 0d999194..a410af51 100644 --- a/AFQ/tests/test_nn.py +++ b/AFQ/tests/test_nn.py @@ -30,12 +30,13 @@ def test_run_brainchop(): "sub-01/ses-01/anat/sub-01_ses-01_T1w.nii.gz" ), ) - chopped_brain = run_brainchop(ort, nib.load(t1_path), "mindgrab") + chopped_brain = run_brainchop(ort, nib.load(t1_path), "mindgrab", {}) npt.assert_(chopped_brain.get_fdata().sum() > 200000) @pytest.mark.skipif(not has_onnx, reason="onnxruntime is not installed") +@pytest.mark.nightly def test_run_multiaxial(): tmpdir = tempfile.mkdtemp() afd.organize_stanford_data(path=tmpdir) @@ -47,6 +48,6 @@ def test_run_multiaxial(): "sub-01/ses-01/anat/sub-01_ses-01_T1w.nii.gz" ), ) - chopped_brain = run_multiaxial(ort, nib.load(t1_path)) + chopped_brain = run_multiaxial(ort, nib.load(t1_path), {}) npt.assert_(chopped_brain.get_fdata().sum() > 200000) diff --git a/AFQ/tests/test_registration.py b/AFQ/tests/test_registration.py index c1b394c1..b2b77d9c 100644 --- a/AFQ/tests/test_registration.py +++ b/AFQ/tests/test_registration.py @@ -5,16 +5,12 @@ import nibabel.tmpdirs as nbtmp import numpy as np import numpy.testing as npt -from dipy.align.imwarp import DiffeomorphicMap +from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr from dipy.io.streamline import load_tractogram import AFQ.data.fetch as afd -from AFQ.registration import ( - read_mapping, - slr_registration, - syn_register_dwi, - write_mapping, -) 
+from AFQ.registration import read_affine_mapping, reduce_shape MNI_T2 = afd.read_mni_template() hardi_img, gtab = dpd.read_stanford_hardi() @@ -36,7 +32,7 @@ def test_slr_registration(): # have to import subject sls file_dict = afd.read_stanford_hardi_tractography() - streamlines = file_dict["tractography_subsampled.trk"] + streamlines = file_dict["tractography_subsampled"] # have to import sls atlas afd.fetch_hcp_atlas_16_bundles() @@ -50,27 +46,34 @@ def test_slr_registration(): hcp_atlas = load_tractogram(atlas_fname, "same", bbox_valid_check=False) with nbtmp.InTemporaryDirectory() as tmpdir: - mapping = slr_registration( + _, transform, _, _ = whole_brain_slr( streamlines, hcp_atlas.streamlines, - moving_affine=subset_b0_img.affine, - static_affine=subset_t2_img.affine, - moving_shape=subset_b0_img.shape, - static_shape=subset_t2_img.shape, + x0="affine", + verbose=False, progressive=False, greater_than=10, rm_small_clusters=1, rng=np.random.RandomState(seed=8), ) - warped_moving = mapping.transform(subset_b0) + + mapping = AffineMap( + transform, + domain_grid_shape=reduce_shape(subset_b0_img.shape), + domain_grid2world=subset_b0_img.affine, + codomain_grid_shape=reduce_shape(subset_t2_img.shape), + codomain_grid2world=subset_t2_img.affine, + ) + + warped_moving = mapping.transform_inverse(subset_b0) npt.assert_equal(warped_moving.shape, subset_t2.shape) mapping_fname = op.join(tmpdir, "mapping.npy") - write_mapping(mapping, mapping_fname) - file_mapping = read_mapping(mapping_fname, subset_b0_img, subset_t2_img) + np.save(mapping_fname, transform) + file_mapping = read_affine_mapping(mapping_fname, subset_b0_img, subset_t2_img) # Test that it has the same effect on the data: - warped_from_file = file_mapping.transform(subset_b0) + warped_from_file = file_mapping.transform_inverse(subset_b0) npt.assert_equal(warped_from_file, warped_moving) # Test that it is, attribute by attribute, identical: @@ -78,11 +81,3 @@ def test_slr_registration(): assert np.all( 
mapping.__getattribute__(k) == file_mapping.__getattribute__(k) ) - - -def test_syn_register_dwi(): - warped_b0, mapping = syn_register_dwi( - subset_dwi_data, gtab, template=subset_t2_img, radius=1 - ) - npt.assert_equal(isinstance(mapping, DiffeomorphicMap), True) - npt.assert_equal(warped_b0.shape, subset_t2_img.shape) diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index c8087d22..2ebd2608 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -3,12 +3,6 @@ import nibabel as nib import numpy as np -from cuslines import ( - BootDirectionGetter, - GPUTracker, - ProbDirectionGetter, - PttDirectionGetter, -) from dipy.align import resample from dipy.reconst import shm @@ -30,12 +24,15 @@ def gpu_track( thresholds_as_percentages, max_angle, step_size, + minlen, + maxlen, n_seeds, random_seeds, rng_seed, use_trx, ngpus, chunk_size, + gpu_backend, ): """ Perform GPU tractography on DWI data. @@ -70,6 +67,10 @@ def gpu_track( array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds to generate within the mask. Default: 1 + minlen: int, optional + The minimal length (mm) in a streamline + maxlen: int, optional + The maximum length (mm) in a streamline random_seeds : bool If True, n_seeds is total number of random seeds to generate. If False, n_seeds encodes the density of seeds to generate. @@ -82,14 +83,71 @@ def gpu_track( Number of GPUs to use. chunk_size : int Chunk size for GPU tracking. + gpu_backend : str, optional + GPU backend to use for tractography. + One of {"auto", "cuda", "metal", "webgpu"}. 
Returns ------- """ + gpu_backend = gpu_backend.lower() + if gpu_backend == "auto": + from cuslines import ( + BootDirectionGetter, + GPUTracker, + ProbDirectionGetter, + PttDirectionGetter, + ) + elif gpu_backend == "cuda": + from cuslines.cuda_python import ( + BootDirectionGetter, + GPUTracker, + ProbDirectionGetter, + PttDirectionGetter, + ) + elif gpu_backend == "metal": + from cuslines.metal import ( + MetalBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.metal import ( + MetalGPUTracker as GPUTracker, + ) + from cuslines.metal import ( + MetalProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.metal import ( + MetalPttDirectionGetter as PttDirectionGetter, + ) + elif gpu_backend == "webgpu": + from cuslines.webgpu import ( + WebGPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUPttDirectionGetter as PttDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUTracker as GPUTracker, + ) + else: + raise ValueError( + "gpu_backend must be one of 'auto', 'cuda', " + "'metal', or 'webgpu', not {gpu_backend}" + ) + seed_img = nib.load(seed_path) directions = directions.lower() + minlen = int(minlen / step_size) + maxlen = int(maxlen / step_size) + + R = seed_img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + step_size = step_size / vox_dim + # Roughly handle ACT/CMC for now - wm_threshold = 0.01 + wm_threshold = 0.5 pve_img = nib.load(pve_path) @@ -165,6 +223,8 @@ def gpu_track( sphere.edges, max_angle=radians(max_angle), step_size=step_size, + min_pts=minlen, + max_pts=maxlen, ngpus=ngpus, rng_seed=rng_seed, chunk_size=chunk_size, diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 1624986f..25474ec2 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -30,7 +30,7 @@ def track( seed_mask=None, 
seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=2000000, + n_seeds=5000000, random_seeds=True, rng_seed=None, step_size=0.5, @@ -40,7 +40,7 @@ def track( basis_type="descoteaux07", legacy=True, tracker="pft", - trx=False, + trx=True, ): """ Tractography @@ -75,7 +75,7 @@ def track( voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 2000000 + to generate within the mask. Default: 5000000 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. Default: True @@ -93,7 +93,7 @@ def track( minlen: int, optional The minimal length (mm) in a streamline. Default: 20 maxlen: int, optional - The minimal length (mm) in a streamline. Default: 250 + The maximum length (mm) in a streamline. Default: 250 odf_model : str or Definition, optional Can be either a string or Definition. If a string, it must be one of {"DTI", "CSD", "DKI", "GQ", "RUMBA", "MSMT_AODF", "CSD_AODF", "MSMTCSD"}. @@ -113,7 +113,7 @@ def track( trx : bool, optional Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). 
- Default: False + Default: True Returns ------- diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 8c37801c..6d64003f 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -2,7 +2,7 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram -from dipy.io.streamline import load_tractogram +from dipy.io.streamline import load_tractogram, save_tractogram try: from trx.io import load as load_trx @@ -11,14 +11,14 @@ except ModuleNotFoundError: has_trx = False +from AFQ.definitions.mapping import ConformedFnirtMapping from AFQ.utils.path import drop_extension, read_json class SegmentedSFT: - def __init__(self, bundles, space, sidecar_info=None): + def __init__(self, bundles, sidecar_info=None): if sidecar_info is None: sidecar_info = {} - reference = None self.bundle_names = [] sls = [] idxs = {} @@ -26,20 +26,17 @@ def __init__(self, bundles, space, sidecar_info=None): idx_count = 0 for b_name in bundles: if isinstance(bundles[b_name], dict): - this_sls = bundles[b_name]["sl"] + this_sft = bundles[b_name]["sl"] this_tracking_idxs[b_name] = bundles[b_name]["idx"] else: - this_sls = bundles[b_name] - if reference is None: - reference = this_sls - this_sls = list(this_sls.streamlines) + this_sft = bundles[b_name] + this_sls = list(this_sft.streamlines) sls.extend(this_sls) new_idx_count = idx_count + len(this_sls) idxs[b_name] = np.arange(idx_count, new_idx_count, dtype=np.uint32) idx_count = new_idx_count self.bundle_names.append(b_name) - self.sft = StatefulTractogram(sls, reference, space) self.bundle_idxs = idxs if len(this_tracking_idxs) > 1: self.this_tracking_idxs = this_tracking_idxs @@ -48,12 +45,13 @@ def __init__(self, bundles, space, sidecar_info=None): self.sidecar_info = sidecar_info self.sidecar_info["bundle_ids"] = {} - dps = np.zeros(len(self.sft.streamlines)) + dps = np.zeros(len(sls)) for ii, bundle_name in enumerate(self.bundle_names): self.sidecar_info["bundle_ids"][f"{bundle_name}"] = 
ii + 1 dps[self.bundle_idxs[bundle_name]] = ii + 1 - dps = {"bundle": dps} - self.sft.data_per_streamline = dps + self.sft = StatefulTractogram.from_sft( + sls, this_sft, data_per_streamline={"bundle": dps} + ) if self.this_tracking_idxs is not None: for kk, _vv in self.this_tracking_idxs.items(): self.this_tracking_idxs[kk] = ( @@ -108,7 +106,7 @@ def fromfile(cls, trk_or_trx_file, reference="same", sidecar_file=None): else: bundles["whole_brain"] = sft - return cls(bundles, Space.RASMM, sidecar_info) + return cls(bundles, sidecar_info) def split_streamline(streamlines, sl_to_split, split_idx): @@ -140,3 +138,57 @@ def split_streamline(streamlines, sl_to_split, split_idx): ) return streamlines + + +def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=None): + """Move streamlines to or from template space. + + to : str + Either "template" or "subject". This determines + whether we will use the forward or backwards displacement field. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + img : Nifti1Image + Image defining reference for where the streamlines move to. + to_space : Space or None + If not None, space to move streamlines to after moving them to the + template or subject space. If None, streamlines will be moved back to + their original space. + Default: None. + save_intermediates : str or None + If not None, path to save intermediate tractogram after moving to template + or subject space. + Default: None. + """ + tg_og_space = tg.space + if isinstance(mapping, ConformedFnirtMapping): + if to != "subject": + raise ValueError( + "Attempted to transform streamlines to template using " + "unsupported mapping. " + "Use something other than Fnirt." 
+ ) + tg.to_vox() + moved_sl = [] + for sl in tg.streamlines: + moved_sl.append(mapping.transform_pts(sl)) + moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) + else: + tg.to_vox() + if to == "template": + moved_sl = mapping.transform_points(tg.streamlines) + else: + moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) + + if save_intermediates is not None: + save_tractogram( + moved_sft, + op.join(save_intermediates, f"sls_in_{to}.trk"), + bbox_valid_check=False, + ) + if to_space is None: + moved_sft.to_space(tg_og_space) + else: + moved_sft.to_space(to_space) + return moved_sft diff --git a/AFQ/utils/tests/test_streamlines.py b/AFQ/utils/tests/test_streamlines.py index 39c624c7..8931670f 100644 --- a/AFQ/utils/tests/test_streamlines.py +++ b/AFQ/utils/tests/test_streamlines.py @@ -37,7 +37,7 @@ def test_SegmentedSFT(): ), } - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) for k1 in bundles.keys(): for sl1, sl2 in zip( bundles[k1].streamlines, seg_sft.get_bundle(k1).streamlines diff --git a/AFQ/utils/tests/test_volume.py b/AFQ/utils/tests/test_volume.py index e8be2b38..c5c437a7 100644 --- a/AFQ/utils/tests/test_volume.py +++ b/AFQ/utils/tests/test_volume.py @@ -22,12 +22,10 @@ def test_density_map(): file_dict = afd.read_stanford_hardi_tractography() # subsample even more - subsampled_tractography = file_dict["tractography_subsampled.trk"][441:444] - sft = StatefulTractogram( - subsampled_tractography, file_dict["mapping.nii.gz"], Space.VOX - ) + subsampled_tractography = file_dict["tractography_subsampled"][441:444] + sft = StatefulTractogram(subsampled_tractography, file_dict["dwi"], Space.RASMM) density_map = afv.density_map(sft) - npt.assert_equal(int(np.sum(density_map.get_fdata())), 69) + npt.assert_equal(int(np.sum(density_map.get_fdata())), 36) density_map = afv.density_map(sft, normalize=True) npt.assert_equal(density_map.get_fdata().max(), 1) 
diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index cea03655..572ce1f1 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -7,12 +7,12 @@ from dipy.io.utils import create_nifti_header, get_reference_info from dipy.tracking.streamline import select_random_set_of_streamlines from scipy.spatial.distance import dice -from skimage.morphology import binary_dilation +from skimage.morphology import dilation logger = logging.getLogger("AFQ") -def transform_inverse_roi(roi, mapping, bundle_name="ROI"): +def transform_roi(roi, mapping, bundle_name="ROI"): """ After being non-linearly transformed, ROIs tend to have holes in them. We perform a couple of computational geometry operations on the ROI to @@ -40,12 +40,20 @@ def transform_inverse_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform_inverse(roi, interpolation="linear") + # dilate binary images to avoid losing small ROIs + if np.unique(roi).size < 3: + scale_factor = max( + np.asarray(mapping.codomain_shape) / np.asarray(mapping.domain_shape) + ) + for _ in range(max(np.ceil(scale_factor) - 1, 0).astype(int)): + roi = dilation(roi) + + _roi = mapping.transform((roi.astype(float)), interpolation="linear") if np.sum(_roi) == 0: - logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") - _roi = binary_dilation(roi) - _roi = mapping.transform_inverse(_roi, interpolation="linear") + logger.warning(f"Lost ROI {bundle_name}, performing automatic dilation") + _roi = dilation(roi) + _roi = mapping.transform(_roi.astype(float), interpolation="linear") _roi = patch_up_roi(_roi > 0, bundle_name=bundle_name).astype(np.int32) diff --git a/AFQ/viz/fury_backend.py b/AFQ/viz/fury_backend.py index 5040fbc9..4e73f22e 100644 --- a/AFQ/viz/fury_backend.py +++ b/AFQ/viz/fury_backend.py @@ -194,11 +194,7 @@ def create_gif( def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - 
static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, color=None, @@ -215,22 +211,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -266,9 +248,7 @@ def visualize_roi( """ if color is None: color = np.array([1, 0, 0]) - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) for i, flip in enumerate(flip_axes): if flip: roi = np.flip(roi, axis=i) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 0076c1b9..8b2f8745 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -40,8 +40,6 @@ def _inline_interact(figure, show, show_inline): def _to_color_range(num): - if num < 0: - num = 0 if num >= 0.999: num = 0.999 if num <= 0.001: @@ -232,9 +230,10 @@ def _draw_streamlines( def _plot_profiles(profiles, bundle_name, color, fig, scalar): if isinstance(profiles, pd.DataFrame): - sc_max = np.max(profiles[scalar].to_numpy()) - sc_90 = np.percentile(profiles[scalar].to_numpy(), 10) - sc_1 = np.percentile(profiles[scalar].to_numpy(), 99) + all_tp = profiles[scalar].to_numpy() + all_tp = np.max(all_tp) - all_tp + lim_0 = np.percentile(all_tp, 1) + lim_1 = np.percentile(all_tp, 90) profiles = profiles[profiles.tractID == bundle_name] x = profiles["nodeID"] @@ -242,10 +241,14 @@ def _plot_profiles(profiles, bundle_name, 
color, fig, scalar): line_color = [] for scalar_val in profiles[scalar].to_numpy(): - xformed_scalar = np.minimum( - (sc_max - scalar_val) / (sc_1 - sc_90) + sc_90 + 0.1, 0.999 + brightness = np.minimum( + np.maximum( + scalar_val - lim_0, + 0, + ), + lim_1, ) - line_color.append(_color_arr2str(xformed_scalar * color)) + line_color.append(_color_arr2str(brightness * color)) else: x = np.arange(len(profiles)) y = profiles @@ -512,7 +515,7 @@ def create_gif(figure, file_name, n_frames=30, zoom=2.5, z_offset=0.5, size=(600 def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): - roi = np.where(roi == 1) + roi = np.where(roi > 0) pts = [] for i, flip in enumerate(flip_axes): if flip: @@ -535,11 +538,7 @@ def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, flip_axes=None, @@ -556,22 +555,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. 
Default: None name: str, optional @@ -612,9 +597,7 @@ def visualize_roi( color = np.array([0.9999, 0, 0]) if flip_axes is None: flip_axes = [False, False, False] - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) if figure is None: figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]]) diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index def69151..ec3bbd76 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -1,3 +1,4 @@ +import colorsys import logging import os.path as op from collections import OrderedDict @@ -13,12 +14,27 @@ from dipy.tracking.streamline import transform_streamlines from PIL import Image, ImageChops -import AFQ.registration as reg import AFQ.utils.streamlines as aus -import AFQ.utils.volume as auv __all__ = ["Viz"] + +def get_distinct_shades(base_rgb, n_steps, hue_shift): + """ + Creates distinct shades by shifting Hue + """ + hh, ll, ss = colorsys.rgb_to_hls(*base_rgb) + shades = [] + + for i in range(n_steps): + offset = i - (n_steps - 1) / 2 + + new_h = (hh + (offset * hue_shift)) % 1.0 + + shades.append(colorsys.hls_to_rgb(new_h, ll, ss)) + return shades + + viz_logger = logging.getLogger("AFQ") tableau_20 = [ (0.12156862745098039, 0.4666666666666667, 0.7058823529411765), @@ -53,6 +69,18 @@ small_font = 20 marker_size = 200 +slf_l_base = tableau_extension[0] +slf_r_base = tableau_extension[1] + +vof_l_base = tableau_20[6] +vof_r_base = tableau_20[7] + +slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) +slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) + +vof_l_shades = get_distinct_shades(vof_l_base, 3, hue_shift=0.15) +vof_r_shades = get_distinct_shades(vof_r_base, 3, hue_shift=0.15) + COLOR_DICT = OrderedDict( { "Left Anterior Thalamic": tableau_20[0], @@ -77,8 +105,14 @@ "F_L": tableau_20[12], "Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13], - "Left Superior Longitudinal": 
tableau_20[14], - "Right Superior Longitudinal": tableau_20[15], + "Left Superior Longitudinal": slf_l_base, + "Right Superior Longitudinal": slf_r_base, + "Left Superior Longitudinal I": slf_l_shades[0], + "Left Superior Longitudinal II": slf_l_shades[1], + "Left Superior Longitudinal III": slf_l_shades[2], + "Right Superior Longitudinal I": slf_r_shades[0], + "Right Superior Longitudinal II": slf_r_shades[1], + "Right Superior Longitudinal III": slf_r_shades[2], "Left Uncinate": tableau_20[16], "UF_L": tableau_20[16], "Right Uncinate": tableau_20[17], @@ -87,10 +121,16 @@ "AF_L": tableau_20[18], "Right Arcuate": tableau_20[19], "AF_R": tableau_20[19], - "Left Posterior Arcuate": tableau_20[6], - "Right Posterior Arcuate": tableau_20[7], - "Left Vertical Occipital": tableau_extension[0], - "Right Vertical Occipital": tableau_extension[1], + "Left Posterior Arcuate": tableau_20[14], + "Right Posterior Arcuate": tableau_20[15], + "Left Vertical Occipital": vof_l_base, + "Right Vertical Occipital": vof_r_base, + "Left V1V3": vof_l_shades[0], + "Left Posterior Vertical Occipital": vof_l_shades[1], + "Left Anterior Vertical Occipital": vof_l_shades[2], + "Right V1V3": vof_r_shades[0], + "Right Posterior Vertical Occipital": vof_r_shades[1], + "Right Anterior Vertical Occipital": vof_r_shades[2], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), @@ -150,28 +190,28 @@ RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"] BEST_BUNDLE_ORIENTATIONS = { - "Left Anterior Thalamic": ("Sagittal", "Left"), - "Right Anterior Thalamic": ("Sagittal", "Right"), - "Left Corticospinal": ("Sagittal", "Left"), - "Right Corticospinal": ("Sagittal", "Right"), - "Left Cingulum Cingulate": ("Sagittal", "Left"), - "Right Cingulum Cingulate": ("Sagittal", "Right"), - "Forceps Minor": ("Axial", "Top"), - "Forceps Major": ("Axial", "Top"), - "Left Inferior Fronto-occipital": ("Sagittal", "Left"), - "Right Inferior Fronto-occipital": 
("Sagittal", "Right"), - "Left Inferior Longitudinal": ("Sagittal", "Left"), - "Right Inferior Longitudinal": ("Sagittal", "Right"), - "Left Superior Longitudinal": ("Axial", "Top"), - "Right Superior Longitudinal": ("Axial", "Top"), - "Left Uncinate": ("Axial", "Bottom"), - "Right Uncinate": ("Axial", "Bottom"), - "Left Arcuate": ("Sagittal", "Left"), - "Right Arcuate": ("Sagittal", "Right"), - "Left Vertical Occipital": ("Coronal", "Back"), - "Right Vertical Occipital": ("Coronal", "Back"), - "Left Posterior Arcuate": ("Coronal", "Back"), - "Right Posterior Arcuate": ("Coronal", "Back"), + "Left Anterior Thalamic": ("Left", "Front", "Top"), + "Right Anterior Thalamic": ("Right", "Front", "Top"), + "Left Corticospinal": ("Left", "Front", "Top"), + "Right Corticospinal": ("Right", "Front", "Top"), + "Left Cingulum Cingulate": ("Left", "Front", "Top"), + "Right Cingulum Cingulate": ("Right", "Front", "Top"), + "Forceps Minor": ("Left", "Front", "Top"), + "Forceps Major": ("Left", "Back", "Top"), + "Left Inferior Fronto-occipital": ("Left", "Front", "Bottom"), + "Right Inferior Fronto-occipital": ("Right", "Front", "Bottom"), + "Left Inferior Longitudinal": ("Left", "Front", "Bottom"), + "Right Inferior Longitudinal": ("Right", "Front", "Bottom"), + "Left Superior Longitudinal": ("Left", "Front", "Top"), + "Right Superior Longitudinal": ("Right", "Front", "Top"), + "Left Uncinate": ("Left", "Front", "Bottom"), + "Right Uncinate": ("Right", "Front", "Bottom"), + "Left Arcuate": ("Left", "Front", "Top"), + "Right Arcuate": ("Right", "Front", "Top"), + "Left Vertical Occipital": ("Left", "Back", "Top"), + "Right Vertical Occipital": ("Right", "Back", "Top"), + "Left Posterior Arcuate": ("Left", "Back", "Top"), + "Right Posterior Arcuate": ("Right", "Back", "Top"), } @@ -512,7 +552,7 @@ def tract_generator( else: if bundle is None: # No selection: visualize all of them: - for bundle_name in seg_sft.bundle_names: + for bundle_name in sorted(seg_sft.bundle_names): idx = 
seg_sft.bundle_idxs[bundle_name] if len(idx) == 0: continue @@ -591,9 +631,7 @@ def gif_from_pngs(tdir, gif_fname, n_frames, png_fname="tgif", add_zeros=False): io.mimsave(gif_fname, angles) -def prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template -): +def prepare_roi(roi, resample_to=None): """ Load the ROI Possibly perform a transformation on an ROI @@ -605,60 +643,27 @@ def prepare_roi( The ROI information. If str, ROI will be loaded using the str as a path. - affine_or_mapping : ndarray, Nifti1Image, or str - An affine transformation or mapping to apply to the ROI before - visualization. Default: no transform. - - static_img: str or Nifti1Image - Template to resample roi to. - - roi_affine: ndarray - - static_affine: ndarray - - reg_template: str or Nifti1Image - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Returns ------- ndarray """ viz_logger.info("Preparing ROI...") + if isinstance(roi, str): + roi = nib.load(roi) + + if resample_to is not None: + if not isinstance(roi, nib.Nifti1Image): + raise ValueError( + ("If resampling, roi must be a Nifti1Image or a path to one.") + ) + roi = resample(roi, resample_to) + if not isinstance(roi, np.ndarray): - if isinstance(roi, str): - roi = nib.load(roi).get_fdata() - else: - roi = roi.get_fdata() - - if affine_or_mapping is not None: - if isinstance(affine_or_mapping, np.ndarray): - # This is an affine: - if static_img is None or roi_affine is None or static_affine is None: - raise ValueError( - "If using an affine to transform an ROI, " - "need to also specify all of the following", - "inputs: `static_img`, `roi_affine`, ", - "`static_affine`", - ) - roi = resample( - roi, static_img, moving_affine=roi_affine, static_affine=static_affine - ).get_fdata() - else: - # Assume it is a mapping: - if isinstance(affine_or_mapping, str) or isinstance( - affine_or_mapping, nib.Nifti1Image - 
): - if reg_template is None or static_img is None: - raise ValueError( - "If using a mapping to transform an ROI, need to ", - "also specify all of the following inputs: ", - "`reg_template`, `static_img`", - ) - affine_or_mapping = reg.read_mapping( - affine_or_mapping, static_img, reg_template - ) + roi = roi.get_fdata() - roi = auv.transform_inverse_roi(roi, affine_or_mapping).astype(bool) return roi diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 5927e818..581e662b 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -86,6 +86,46 @@ relation to the Left Arcuate and Inferior Longitudinal fasciculi: 'Left Inferior Longitudinal': {'core': 'Left'}, } + +Mixed space ROIs +================ +Everywhere in the bundle dictionary where an ROI is specified as a path, +be it start, end, include, exclude, or probability map, you can in fact input +a dictionary instead. This dictionary should have two keys: +- 'roi' : path to the ROI Nifti file +- 'space' : either 'template' or 'subject', describing the space the ROI + is currently in. + +Then, for the whole bundle, set "space" to 'mixed'. This allows you to +specify some ROIs in template space and some in subject space for the same +bundle. For example: + +.. 
code-block:: python + + import os.path as op + import AFQ.api.bundle_dict as abd + import AFQ.data.fetch as afd + + # First, organize the data + afd.organize_stanford_data() + bids_path = op.join(op.expanduser('~'), 'AFQ_data', 'stanford_hardi') + sub_path = op.join(bids_path, 'derivatives', 'vistasoft', 'sub-01', 'ses-01') + dwi_path = op.join(sub_path, 'dwi', 'sub-01_ses-01_dwi.nii.gz') + + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + # Then, prepare the bundle dictionary + bundle_info = abd.BundleDict({ + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed" + } + }, resample_subject_to=dwi_path) + + Filtering Order =============== When doing bundle recognition, streamlines are filtered out from the whole diff --git a/docs/source/reference/kwargs.rst b/docs/source/reference/kwargs.rst index 914c1297..8709543c 100644 --- a/docs/source/reference/kwargs.rst +++ b/docs/source/reference/kwargs.rst @@ -149,7 +149,7 @@ tractography_ngpus: int Number of GPUs to use in tractography. If non-0, this algorithm is used for tractography, https://github.com/dipy/GPUStreamlines PTT, Prob can be used with any SHM model. Bootstrapped can be done with CSA/OPDT. Default: 0 chunk_size: int - Chunk size for GPU tracking. Default: 100000 + Chunk size for GPU tracking. Default: 25000 ========================================================== @@ -173,9 +173,6 @@ volume_opacity_indiv: float n_points_indiv: int or None n_points to resample streamlines to before plotting. If None, no resampling is done. Default: 40 -virtual_frame_buffer: bool - Whether to use a virtual frame buffer. This is if generating GIFs in a headless environment. Default: False - viz_backend_spec: str Which visualization backend to use. 
See Visualization Backends page in documentation for details https://tractometry.org/pyAFQ/reference/viz_backend.html One of {"fury", "plotly", "plotly_no_gif"}. Default: "plotly_no_gif" diff --git a/docs/source/reference/methods.rst b/docs/source/reference/methods.rst index 50c94406..bde820f3 100644 --- a/docs/source/reference/methods.rst +++ b/docs/source/reference/methods.rst @@ -111,6 +111,10 @@ dti_params: full path to a nifti file containing parameters for the DTI fit +dti_s0: + s0 values of DTI fit + + fwdti_tf: Free-water DTI TensorFit object @@ -127,6 +131,10 @@ dki_params: full path to a nifti file containing parameters for the DKI fit +dki_s0: + s0 values of DKI fit + + msdki_tf: Mean Signal DKI DiffusionKurtosisFit object @@ -135,6 +143,10 @@ msdki_params: full path to a nifti file containing parameters for the Mean Signal DKI fit +msdki_s0: + s0 values of Mean Signal DKI fit + + msdki_msd: full path to a nifti file containing the MSDKI mean signal diffusivity @@ -168,47 +180,27 @@ csd_ai: gq_params: - full path to a nifti file containing parameters for the Generalized Q-Sampling shm_coeff + full path to a nifti file containing ODF for the Generalized Q-Sampling gq_iso: full path to a nifti file containing isotropic diffusion component -gq_aso: - full path to a nifti file containing anisotropic diffusion component - - -gq_pmap: - full path to a nifti file containing the anisotropic power map from GQ - - -gq_ai: - full path to a nifti file containing the anisotropic index from GQ - - -rumba_model: - fit for RUMBA-SD model as documented on dipy reconstruction options - - rumba_params: - Takes the fitted RUMBA-SD model as input and returns the spherical harmonics coefficients (SHM). - - -rumba_fit: - RUMBA FIT + ODF for the RUMBA-SD model rumba_f_csf: - full path to a nifti file containing the CSF volume fraction for each voxel. 
+ full path to a nifti file containing the CSF volume fraction for each voxel rumba_f_gm: - full path to a nifti file containing the GM volume fraction for each voxel. + full path to a nifti file containing the GM volume fraction for each voxel rumba_f_wm: - full path to a nifti file containing the white matter volume fraction for each voxel. + full path to a nifti file containing the white matter volume fraction for each voxel opdt_params: @@ -391,6 +383,10 @@ dki_lt5: Image of sixth element in the DTI tensor from DKI +dki_cfa: + full path to a nifti file containing the DKI color fractional anisotropy + + dki_fa: full path to a nifti file containing the DKI fractional anisotropy @@ -468,7 +464,15 @@ pve_internal: msmtcsd_params: - full path to a nifti file containing parameters for the MSMT CSD fit + full path to a nifti file containing parameters for the MSMT CSD white matter fit + + +msmtcsd_gm: + full path to a nifti file containing parameters for the MSMT CSD gray matter fit + + +msmtcsd_csf: + full path to a nifti file containing parameters for the MSMT CSD cerebrospinal fluid fit msmt_apm: @@ -527,7 +531,7 @@ sl_counts: full path to a JSON file containing streamline counts -median_bundle_lengths: +bundle_lengths: full path to a JSON file containing median bundle lengths @@ -565,3 +569,7 @@ tract_profile_plots: viz_backend: An instance of the `AFQ.viz.utils.viz_backend` class. + + +citations: + Export Bibtex citation file for methods used by pyAFQ. 
diff --git a/docs/source/references.bib b/docs/source/references.bib index 3dba804f..58a778bb 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -9,6 +9,27 @@ @article{Grotheer2022 publisher={Nature Publishing Group} } +@article{leong2016white, + title={White-matter tract connecting anterior insula to nucleus accumbens correlates with reduced preference for positively skewed gambles}, + author={Leong, Josiah K and Pestilli, Franco and Wu, Charlene C and Samanez-Larkin, Gregory R and Knutson, Brian}, + journal={Neuron}, + volume={89}, + number={1}, + pages={63--69}, + year={2016}, + publisher={Elsevier} +} + +@article{alkemade2020amsterdam, + title={The Amsterdam Ultra-high field adult lifespan database (AHEAD): A freely available multimodal 7 Tesla submillimeter magnetic resonance imaging database}, + author={Alkemade, Anneke and Mulder, Martijn J and Groot, Josephine M and Isaacs, Bethany R and van Berendonk, Nikita and Lute, Nicky and Isherwood, Scott JS and Bazin, Pierre-Louis and Forstmann, Birte U}, + journal={NeuroImage}, + volume={221}, + pages={117200}, + year={2020}, + publisher={Elsevier} +} + @article{grotheer2023human, title={Human white matter myelinates faster in utero than ex utero}, author={Grotheer, Mareike and Bloom, David and Kruper, John and Richie-Halford, Adam and Zika, Stephanie and Aguilera González, Vicente A and Yeatman, Jason D and Grill-Spector, Kalanit and Rokem, Ariel}, @@ -201,6 +222,28 @@ @article{takemura2017occipital publisher={Oxford University Press} } +@article{takemura2016major, + title={A major human white matter pathway between dorsal and ventral visual cortex}, + author={Takemura, Hiromasa and Rokem, Ariel and Winawer, Jonathan and Yeatman, Jason D and Wandell, Brian A and Pestilli, Franco}, + journal={Cerebral cortex}, + volume={26}, + number={5}, + pages={2205--2214}, + year={2016}, + publisher={Oxford University Press} +} + +@article{wang2013rethinking, + title={Rethinking the role of the middle 
longitudinal fascicle in language and auditory pathways}, + author={Wang, Yibao and Fern{\'a}ndez-Miranda, Juan C and Verstynen, Timothy and Pathak, Sudhir and Schneider, Walter and Yeh, Fang-Cheng}, + journal={Cerebral cortex}, + volume={23}, + number={10}, + pages={2347--2356}, + year={2013}, + publisher={Oxford University Press} +} + @ARTICLE{Dougherty2007, title = "Temporal-callosal pathway diffusivity predicts phonological skills in children", @@ -517,6 +560,16 @@ @InProceedings{Kruper2023 isbn="978-3-031-47292-3" } +@article{zhang2018anatomically, + title={An anatomically curated fiber clustering white matter atlas for consistent white matter tract parcellation across the lifespan}, + author={Zhang, Fan and Wu, Ye and Norton, Isaiah and Rigolo, Laura and Rathi, Yogesh and Makris, Nikos and O'Donnell, Lauren J}, + journal={Neuroimage}, + volume={179}, + pages={429--447}, + year={2018}, + publisher={Elsevier} +} + @ARTICLE{Hua2008, title = "Tract probability maps in stereotaxic spaces: analyses of white matter anatomy and tract-specific quantification", diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 22fc7542..3f2dbd00 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -269,8 +269,10 @@ "NDARAA948VFH"]["HBNsiteRU"], index_col=[0]) for ind in bundle_counts.index: if ind == "Total Recognized": - threshold = 1000 - elif "Fronto-occipital" in ind or "Orbital" in ind: + threshold = 3000 + elif "Fronto-occipital" in ind: + threshold = 10 + elif "Vertical Occipital" in ind: threshold = 5 else: threshold = 15 diff --git a/setup.cfg b/setup.cfg index 6d66c246..eb0dd268 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,6 +95,7 @@ nn = onnxruntime gpu = cuslines==2.0.0 + onnxruntime-gpu all = %(dev)s %(fury)s