From 1c2eb044d784b530b253de88ae8852d49b10b4ff Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 27 Jan 2026 15:12:58 +0900 Subject: [PATCH 01/86] start allowing mixed bundle defs --- AFQ/api/bundle_dict.py | 48 ++++++++++++++++++++++++++---- AFQ/api/group.py | 3 +- AFQ/data/fetch.py | 59 +++++++++++++++++++++++++++++++++++++ AFQ/definitions/image.py | 2 +- AFQ/nn/synthseg.py | 4 +-- AFQ/recognition/criteria.py | 4 +-- AFQ/tasks/decorators.py | 1 - AFQ/tasks/mapping.py | 2 +- docs/source/references.bib | 21 +++++++++++++ 9 files changed, 128 insertions(+), 16 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 22c8e96e..9168da7b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -10,7 +10,7 @@ from AFQ.definitions.utils import find_file from AFQ.tasks.utils import get_fname, str_to_desc -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") __all__ = [ @@ -1111,7 +1111,7 @@ def update_max_includes(self, new_max): self.max_includes = new_max def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): - if isinstance(roi_or_sl, dict): + if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") roi_or_sl = find_file( bids_layout, bids_path, roi_or_sl, suffix, session, subject @@ -1124,6 +1124,33 @@ def _cond_load(self, roi_or_sl, resample_to): """ Load ROI or streamline if not already loaded """ + if isinstance(roi_or_sl, dict): + space = roi_or_sl.get("space", None) + roi_or_sl = roi_or_sl.get("roi", None) + if roi_or_sl is None or space is None: + raise ValueError( + ( + f"Unclear ROI definition for {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." + ) + ) + if space == "template": + resample_to = self.resample_to + elif space == "subject": + resample_to = self.resample_subject_to + else: + raise ValueError( + ( + f"Unknown space {space} for ROI definition {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." + ) + ) + + logger.debug(f"Loading ROI or streamlines: {roi_or_sl}") + logger.debug(f"Loading ROI or streamlines from space: {resample_to}") + if isinstance(roi_or_sl, str): if ".nii" in roi_or_sl: return afd.read_resample_roi(roi_or_sl, resample_to=resample_to) @@ -1261,11 +1288,20 @@ def is_bundle_in_template(self, bundle_name): return ( "space" not in self._dict[bundle_name] or self._dict[bundle_name]["space"] == "template" + or self._dict[bundle_name]["space"] == "mixed" ) - def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): + def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): roi_or_sl = self._cond_load(roi_or_sl, self.resample_to) if isinstance(roi_or_sl, nib.Nifti1Image): + if ( + np.allclose(roi_or_sl.affine, new_img.affine) + and roi_or_sl.shape == new_img.shape + ): + # This is the case of a mixed bundle definition, where + # some ROIs need transformed and others do not + return roi_or_sl + fdata = roi_or_sl.get_fdata() if len(np.unique(fdata)) <= 2: boolean_ = True @@ -1278,7 +1314,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): if boolean_: warped_img = warped_img.astype(np.uint8) - warped_img = nib.Nifti1Image(warped_img, new_affine) + warped_img = nib.Nifti1Image(warped_img, new_img.affine) return warped_img else: return roi_or_sl @@ -1287,7 +1323,7 @@ def transform_rois( self, bundle_name, mapping, - new_affine, + new_img, base_fname=None, to_space="subject", apply_to_recobundles=False, @@ -1333,7 +1369,7 @@ def transform_rois( bundle_name, self._roi_transform_helper, mapping, - new_affine, + new_img, bundle_name, dry_run=True, apply_to_recobundles=apply_to_recobundles, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 2c225a9d..e2b19707 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -50,7 +50,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) warnings.simplefilter(action="ignore", category=FutureWarning) @@ -142,7 +141,7 @@ def __init__( api.GroupAFQ(my_path, csd_sh_order_max=4) api.GroupAFQ( my_path, - reg_template_spec="mni_t2", reg_subject_spec="b0") + _spec="mni_t2", reg_subject_spec="b0") """ if bids_layout_kwargs is None: bids_layout_kwargs = {} diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9400509e..e9da8a99 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1089,6 +1089,65 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +massp_fnames = [ + "left_VTA.nii.gz", + "right_VTA.nii.gz", +] + +massp_remote_fnames = [ + "34892325", + "34892319", +] + +massp_md5_hashes = [ + "03d65d85abb161ea25501c343c136e40", + "440874b899d2c1057e5fd77b8b350bc4", +] + +fetch_massp_templates = _make_reusable_fetcher( + "fetch_massp_templates", + op.join(afq_home, "massp_templates"), + baseurl, + massp_remote_fnames, + massp_fnames, + md5_list=massp_md5_hashes, + doc="Download AFQ MassP templates", +) + + +def read_massp_templates(as_img=True, resample_to=False): + """Load AFQ MASSSP templates from file + + Parameters + ---------- + as_img : bool, optional + If True, values are `Nifti1Image`. Otherwise, values are + paths to Nifti files. Default: True + resample_to : str or nibabel image class instance, optional + A template image to resample to. Typically, this should be the + template to which individual-level data are registered. Defaults to + the MNI template. Default: False + + Returns + ------- + dict with: keys: names of template ROIs and values: nibabel Nifti1Image + objects from each of the ROI nifti files. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading oton templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template( + fetch_massp_templates, as_img=as_img, resample_to=resample_to + ) + + toc = time.perf_counter() + logger.debug(f"MASSSP templates loaded in {toc - tic:0.4f} seconds") + + return template_dict + + cp_fnames = [ "ICP_L_inferior_prob.nii.gz", "ICP_L_superior_prob.nii.gz", diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 4208056c..985d270b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -384,7 +384,7 @@ def _image_getter_helper( for bundle_name in bundle_dict: bundle_entry = bundle_dict.transform_rois( - bundle_name, mapping_imap["mapping"], data_imap["dwi_affine"] + bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) rois = [] if self.use_endpoints: diff --git a/AFQ/nn/synthseg.py b/AFQ/nn/synthseg.py index 45f80f52..6ea9d0c2 100644 --- a/AFQ/nn/synthseg.py +++ b/AFQ/nn/synthseg.py @@ -86,8 +86,8 @@ def pve_from_synthseg(synthseg_data): PVE data with CSF, GM, and WM segmentations. """ - CSF_labels = [0, 3, 4, 11, 12, 21, 22, 17] - GM_labels = [2, 7, 8, 9, 10, 14, 15, 16, 20, 25, 26, 27, 28, 29, 30, 31] + CSF_labels = [0, 3, 4, 11, 12, 21, 22, 16] + GM_labels = [2, 7, 8, 9, 10, 14, 15, 17, 20, 25, 26, 27, 28, 29, 30, 31] WM_labels = [1, 5, 19, 23] mixed_labels = [13, 18, 32] diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index e68e6251..1e0bdccb 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -392,9 +392,7 @@ def run_bundle_rec_plan( start_time = time() bundle_def = dict(bundle_dict.get_b_info(bundle_name)) bundle_def.update( - bundle_dict.transform_rois( - bundle_name, mapping, img.affine, apply_to_recobundles=True - ) + bundle_dict.transform_rois(bundle_name, mapping, img, apply_to_recobundles=True) ) def check_space(roi): diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index bb1d750c..3c041287 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -26,7 +26,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) def get_new_signature(og_func, needed_args): diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 33006030..c593371a 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -85,7 +85,7 @@ def export_rois(base_fname, output_dir, dwi_data_file, data_imap, mapping): *bundle_dict.transform_rois( bundle_name, mapping, - data_imap["dwi_affine"], + data_imap["dwi"], base_fname=base_roi_fname, to_space=to_space, ) diff --git a/docs/source/references.bib b/docs/source/references.bib index 3dba804f..ec7b1eb7 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -9,6 +9,27 @@ @article{Grotheer2022 publisher={Nature Publishing Group} } +@article{leong2016white, + title={White-matter tract connecting anterior insula to nucleus accumbens correlates with reduced preference for positively skewed gambles}, + author={Leong, Josiah K and Pestilli, Franco and Wu, Charlene C and Samanez-Larkin, Gregory R and Knutson, Brian}, + journal={Neuron}, + volume={89}, + number={1}, + pages={63--69}, + year={2016}, + publisher={Elsevier} +} + +@article{alkemade2020amsterdam, + title={The Amsterdam Ultra-high field adult lifespan database (AHEAD): A freely available multimodal 7 Tesla submillimeter magnetic resonance imaging database}, + author={Alkemade, Anneke and Mulder, Martijn J and Groot, Josephine M and Isaacs, Bethany R and van Berendonk, Nikita and Lute, Nicky and Isherwood, Scott JS and Bazin, Pierre-Louis and Forstmann, Birte U}, + journal={NeuroImage}, + volume={221}, + pages={117200}, + year={2020}, + publisher={Elsevier} +} + @article{grotheer2023human, title={Human white matter myelinates faster in utero than ex utero}, author={Grotheer, Mareike and Bloom, David and Kruper, John and Richie-Halford, Adam and Zika, Stephanie and Aguilera González, Vicente A and Yeatman, Jason D and Grill-Spector, Kalanit and Rokem, Ariel}, From 6e2bb56f563b3087106fe38522349d21c269cb5c Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 10:34:52 +0900 Subject: [PATCH 02/86] finish up wmgmi seeding changes --- AFQ/definitions/image.py | 52 ++++++++++++++++++++++++++++++----- AFQ/recognition/preprocess.py | 16 +---------- AFQ/recognition/utils.py | 17 ++++++++++++ AFQ/tasks/utils.py | 3 +- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 985d270b..85454908 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -3,8 +3,10 @@ import nibabel as nib import numpy as np from dipy.align import resample +from scipy.ndimage import distance_transform_edt from AFQ.definitions.utils import Definition, find_file, name_from_path +from AFQ.recognition.utils import tolerance_mm_to_vox from AFQ.tasks.utils import get_tp __all__ = [ @@ -324,6 +326,13 @@ class RoiImage(ImageDefinition): use_endpoints : bool Whether to use the endpoints ("start" and "end") to generate the image. + only_wmgmi : bool + Whether to only include portion of ROIs in the WM-GM interface. + only_wm : bool + Whether to only include portion of ROIs in the white matter. + dilate : bool + Whether to dilate the ROIs before combining them, according to the + tolerance that will be used during bundle recognition. tissue_property : str or None Tissue property from `scalars` to multiply the ROI image with. Can be useful to limit seed mask to the core white matter. @@ -350,14 +359,20 @@ def __init__( use_presegment=False, use_endpoints=False, only_wmgmi=False, + only_wm=False, + dilate=True, tissue_property=None, tissue_property_n_voxel=None, tissue_property_threshold=None, ): + if only_wmgmi and only_wm: + raise ValueError("only_wmgmi and only_wm cannot both be True") self.use_waypoints = use_waypoints self.use_presegment = use_presegment self.use_endpoints = use_endpoints self.only_wmgmi = only_wmgmi + self.only_wm = only_wm + self.dilate = dilate self.tissue_property = tissue_property self.tissue_property_n_voxel = tissue_property_n_voxel self.tissue_property_threshold = tissue_property_threshold @@ -386,21 +401,40 @@ def _image_getter_helper( bundle_entry = bundle_dict.transform_rois( bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) - rois = [] + rois = {} if self.use_endpoints: - rois.extend( - [ - bundle_entry[end_type] + rois.update( + { + bundle_entry[end_type]: end_type for end_type in ["start", "end"] if end_type in bundle_entry - ] + } ) if self.use_waypoints: - rois.extend(bundle_entry.get("include", [])) - for roi in rois: + rois.update( + dict.fromkeys(bundle_entry.get("include", []), "waypoint") + ) + for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: image_data = np.zeros(warped_roi.shape) + if self.dilate: + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) + edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) + if roi_type == "waypoint": + warped_roi = edt <= dist_to_waypoint + else: + warped_roi = edt <= dist_to_atlas + warped_roi = ( + edt <= dist_to_waypoint + if roi_type == "waypoint" + else edt <= dist_to_atlas + ) + image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: tp = nib.load( @@ -456,6 +490,10 @@ def _image_getter_helper( ) ) + if self.only_wm: + wm = nib.load(tissue_imap["pve_internal"]).get_fdata()[..., 2] >= 0.5 + image_data = np.logical_and(image_data, wm) + return nib.Nifti1Image( image_data.astype(np.float32), data_imap["dwi_affine"] ), dict(source="ROIs") diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8bb41e67..8b3ce657 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -1,7 +1,6 @@ import logging from time import time -import dipy.tracking.streamline as dts import immlib import nibabel as nib import numpy as np @@ -13,20 +12,7 @@ @immlib.calc("tol", "dist_to_atlas", "vox_dim") def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): - # We need to calculate the size of a voxel, so we can transform - # from mm to voxel units: - R = img.affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - - # Tolerance is set to the square of the distance to the corner - # because we are using the squared Euclidean distance in calls to - # `cdist` to make those calls faster. - if dist_to_waypoint is None: - tol = dts.dist_to_corner(img.affine) - else: - tol = dist_to_waypoint / vox_dim - dist_to_atlas = int(input_dist_to_atlas / vox_dim) - return tol, dist_to_atlas, vox_dim + return abu.tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas) @immlib.calc("fgarray") diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index fd34aa18..505f8300 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -14,6 +14,23 @@ logger = logging.getLogger("AFQ") +def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): + # We need to calculate the size of a voxel, so we can transform + # from mm to voxel units: + R = img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + # Tolerance is set to the square of the distance to the corner + # because we are using the squared Euclidean distance in calls to + # `cdist` to make those calls faster. + if dist_to_waypoint is None: + tol = dts.dist_to_corner(img.affine) + else: + tol = dist_to_waypoint / vox_dim + dist_to_atlas = int(input_dist_to_atlas / vox_dim) + return tol, dist_to_atlas, vox_dim + + def flip_sls(select_sl, idx_to_flip, in_place=False): """ Helper function to flip streamlines diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 128fd157..e3cfb1fe 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -29,7 +29,8 @@ def get_base_fname(output_dir, dwi_data_file): key = key_val_pair.split("-")[0] if key not in used_key_list: fname = fname + key_val_pair + "_" - fname = fname[:-1] + if fname[-1] == "_": + fname = fname[:-1] return fname From 0545fa7954f15a845b7fc0680229c3127a4a1407 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 13:58:40 +0900 Subject: [PATCH 03/86] more mixed ROI fixes --- AFQ/api/bundle_dict.py | 15 +++++++++++++-- AFQ/api/group.py | 2 ++ AFQ/api/participant.py | 4 ++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 9168da7b..b9735357 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -11,6 +11,7 @@ from AFQ.tasks.utils import get_fname, str_to_desc logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) __all__ = [ @@ -1139,6 +1140,14 @@ def _cond_load(self, roi_or_sl, resample_to): resample_to = self.resample_to elif space == "subject": resample_to = self.resample_subject_to + if resample_to is False: + raise ValueError( + ( + "When using mixed ROI bundle definitions, " + "and subject space ROIs, " + "resample_subject_to cannot be False." + ) + ) else: raise ValueError( ( @@ -1296,7 +1305,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): if isinstance(roi_or_sl, nib.Nifti1Image): if ( np.allclose(roi_or_sl.affine, new_img.affine) - and roi_or_sl.shape == new_img.shape + and roi_or_sl.shape == new_img.shape[:3] ): # This is the case of a mixed bundle definition, where # some ROIs need transformed and others do not @@ -1418,7 +1427,9 @@ def transform_rois( def __add__(self, other): for resample in ["resample_to", "resample_subject_to"]: - if ( + if getattr(self, resample) == getattr(other, resample): + pass + elif ( not getattr(self, resample) or not getattr(other, resample) or getattr(self, resample) is None diff --git a/AFQ/api/group.py b/AFQ/api/group.py index e2b19707..3faef912 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -50,6 +50,8 @@ logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + warnings.simplefilter(action="ignore", category=FutureWarning) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 827b38b5..a99d4645 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -33,6 +33,10 @@ __all__ = ["ParticipantAFQ"] +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + + class ParticipantAFQ(object): f"""{AFQclass_doc}""" From 643395d4cf82455769f6acc8c0a9d795c57a3bee Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:07:08 +0900 Subject: [PATCH 04/86] BFs --- AFQ/api/bundle_dict.py | 2 +- AFQ/tasks/utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index b9735357..ed37dd16 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -302,7 +302,7 @@ def default_bd(): "primary_axis_percentage": 40, }, }, - citations={"Yeatman2012", "takemura2017occipital"}, + citations={"Yeatman2012", "takemura2017occipital", "Tzourio-Mazoyer2002"}, ) diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index e3cfb1fe..e57eb393 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -31,6 +31,10 @@ def get_base_fname(output_dir, dwi_data_file): fname = fname + key_val_pair + "_" if fname[-1] == "_": fname = fname[:-1] + else: + # if no key value pairs found, + # have some default base file name + fname = fname + "subject" return fname From 03dc8b00b6317c6eda6e6d7c415991a55c7ea136 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:29:26 +0900 Subject: [PATCH 05/86] return to setting logging to info --- AFQ/api/group.py | 1 + AFQ/api/participant.py | 1 + 2 files changed, 2 insertions(+) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 3faef912..c1b5a0e0 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -49,6 +49,7 @@ __all__ = ["GroupAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index a99d4645..87f6b08c 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -33,6 +33,7 @@ __all__ = ["ParticipantAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) From b0ce1a05ea52cec405294b5c9d80f2cca40557cd Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:53:34 +0900 Subject: [PATCH 06/86] minor docs fixes from copilot --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/api/group.py | 2 +- AFQ/data/fetch.py | 6 +++--- AFQ/definitions/image.py | 16 ++++++---------- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index ed37dd16..273bd72e 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -1351,8 +1351,8 @@ def transform_rois( Name of the bundle to be transformed. mapping : DiffeomorphicMap object A mapping between DWI space and a template. - new_affine : array - Affine of space transformed into. + new_img : Nifti1Image + Image of space transformed into. base_fname : str, optional Base file path to construct file path from. Additional BIDS descriptors will be added to this file path. If None, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index c1b5a0e0..1e4c7c49 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -144,7 +144,7 @@ def __init__( api.GroupAFQ(my_path, csd_sh_order_max=4) api.GroupAFQ( my_path, - _spec="mni_t2", reg_subject_spec="b0") + reg_template_spec="mni_t2", reg_subject_spec="b0") """ if bids_layout_kwargs is None: bids_layout_kwargs = {} diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index e9da8a99..c88fc74e 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1116,7 +1116,7 @@ def read_oton_templates(as_img=True, resample_to=False): def read_massp_templates(as_img=True, resample_to=False): - """Load AFQ MASSSP templates from file + """Load AFQ MASSP templates from file Parameters ---------- @@ -1135,7 +1135,7 @@ def read_massp_templates(as_img=True, resample_to=False): """ logger = logging.getLogger("AFQ") - logger.debug("loading oton templates") + logger.debug("loading MASSP templates") tic = time.perf_counter() template_dict = _fetcher_to_template( @@ -1143,7 +1143,7 @@ def read_massp_templates(as_img=True, resample_to=False): ) toc = time.perf_counter() - logger.debug(f"MASSSP templates loaded in {toc - tic:0.4f} seconds") + logger.debug(f"MASSP templates loaded in {toc - tic:0.4f} seconds") return template_dict diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 85454908..df5df2b4 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -414,26 +414,22 @@ def _image_getter_helper( rois.update( dict.fromkeys(bundle_entry.get("include", []), "waypoint") ) + + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: image_data = np.zeros(warped_roi.shape) if self.dilate: - dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( - data_imap["dwi"], - segmentation_params["dist_to_waypoint"], - segmentation_params["dist_to_atlas"], - ) edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) if roi_type == "waypoint": warped_roi = edt <= dist_to_waypoint else: warped_roi = edt <= dist_to_atlas - warped_roi = ( - edt <= dist_to_waypoint - if roi_type == "waypoint" - else edt <= dist_to_atlas - ) image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: From bdf4b581502507d5ee3beaf1428e2a3ad52af77e Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 29 Jan 2026 13:38:06 +0900 Subject: [PATCH 07/86] toy with vof --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 273bd72e..8408bb2b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -279,7 +279,7 @@ def default_bd(): "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25}, "isolation_forest": {}, "primary_axis": "I/S", @@ -295,7 +295,7 @@ def default_bd(): "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25}, "isolation_forest": {}, "primary_axis": "I/S", From 806e53daab5be4b38ce12b33a41cfa5c513a4d63 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 29 Jan 2026 15:20:07 +0900 Subject: [PATCH 08/86] much needed speedups in bundle recognition --- AFQ/api/bundle_dict.py | 12 +++++++++-- AFQ/recognition/criteria.py | 33 ++++++++++++++++------------- AFQ/recognition/roi.py | 42 ++++++++++++++++--------------------- 3 files changed, 46 insertions(+), 41 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 8408bb2b..be14a633 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -185,6 +185,7 @@ def default_bd(): "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, @@ -194,6 +195,7 @@ def default_bd(): "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -221,6 +223,7 @@ def default_bd(): "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Right Arcuate": { "cross_midline": False, @@ -230,6 +233,7 @@ def default_bd(): "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Left Uncinate": { "cross_midline": False, @@ -254,8 +258,10 @@ def default_bd(): "include": [templates["SLFt_roi2_L"]], "exclude": [templates["SLF_roi1_L"]], "space": "template", + "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], "Left Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, }, @@ -264,8 +270,10 @@ def default_bd(): "include": [templates["SLFt_roi2_R"]], "exclude": [templates["SLF_roi1_R"]], "space": "template", + "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], "Right Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, }, @@ -280,7 +288,7 @@ def default_bd(): }, "Left Inferior Fronto-occipital": {"core": "Right"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, - "length": {"min_len": 25}, + "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -296,7 +304,7 @@ def default_bd(): }, "Right Inferior Fronto-occipital": {"core": "Left"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, - "length": {"min_len": 25}, + "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", "primary_axis_percentage": 40, diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 1e0bdccb..5e078006 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -24,9 +24,9 @@ criteria_order_pre_other_bundles = [ "prob_map", "cross_midline", + "length", "start", "end", - "length", "primary_axis", "include", "exclude", @@ -69,18 +69,17 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("Startpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("Startpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], 0, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], -1, tol=preproc_imap["dist_to_atlas"], @@ -91,18 +90,17 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): def end(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("endpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("endpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], -1, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], 0, tol=preproc_imap["dist_to_atlas"], @@ -116,10 +114,15 @@ def length(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] - for idx, sl in enumerate(b_sls.get_selected_sls()): - sl_len = np.sum(np.linalg.norm(np.diff(sl, axis=0), axis=1)) - if sl_len >= min_len and sl_len <= max_len: - accept_idx[idx] = 1 + + # Using resampled fgarray biases lengths to be lower. However, + # this is not meant to be a precise selection requirement, and + # is more meant for efficiency. + segments = np.diff(preproc_imap["fgarray"][b_sls.selected_fiber_idxs], axis=1) + segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) + sl_lens = np.sum(segment_lengths, axis=1) + + accept_idx = (sl_lens >= min_len) & (sl_lens <= max_len) b_sls.select(accept_idx, "length") diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index d87062f5..83e1cf53 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -6,8 +6,7 @@ def _interp3d(roi, sl): return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] -def check_sls_with_inclusion( - sls, include_rois, include_roi_tols): +def check_sls_with_inclusion(sls, include_rois, include_roi_tols): inc_results = np.zeros(len(sls), dtype=tuple) include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): @@ -30,9 +29,8 @@ def check_sls_with_inclusion( return inc_results -def check_sl_with_exclusion(sl, exclude_rois, - exclude_roi_tols): - """ Helper function to check that a streamline is not too close to a +def check_sl_with_exclusion(sl, exclude_rois, exclude_roi_tols): + """Helper function to check that a streamline is not too close to a list of exclusion ROIs. """ for ii, roi in enumerate(exclude_rois): @@ -44,17 +42,15 @@ def check_sl_with_exclusion(sl, exclude_rois, return True -def clean_by_endpoints(streamlines, target, target_idx, tol=0, - flip_sls=None, accepted_idxs=None): +def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): """ Clean a collection of streamlines based on an endpoint ROI. Filters down to only include items that have their start or end points close to the targets. Parameters ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. + fgarray : ndarray of shape (N, M, 3) + Where N is number of streamlines, M is number of nodes. target: Nifti1Image Nifti1Image containing a distance transform of the ROI. target_idx: int. @@ -67,24 +63,22 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0, the endpoint is exactly in the coordinate of the target ROI. flip_sls : 1d array, optional Length is len(streamlines), whether to flip the streamline. - accepted_idxs : 1d array, optional - Boolean array, where entries correspond to eachs streamline, - and streamlines that pass cleaning will be set to 1. Yields ------- boolean array of streamlines that survive cleaning. """ - if accepted_idxs is None: - accepted_idxs = np.zeros(len(streamlines), dtype=np.bool_) + n_sls, n_nodes, _ = fgarray.shape - if flip_sls is None: - flip_sls = np.zeros(len(streamlines)) - flip_sls = flip_sls.astype(int) + # handle target_idx negative values as wrapping around + effective_idx = target_idx if target_idx >= 0 else (n_nodes + target_idx) + indices = np.full(n_sls, effective_idx) - for ii, sl in enumerate(streamlines): - this_idx = target_idx - if flip_sls[ii]: - this_idx = (len(sl) - this_idx - 1) % len(sl) - accepted_idxs[ii] = _interp3d(target, [sl[this_idx]])[0] <= tol + if flip_sls is not None: + flipped_indices = n_nodes - 1 - effective_idx + indices = np.where(flip_sls.astype(bool), flipped_indices, indices) - return accepted_idxs + distances = interpolate_scalar_3d( + target.get_fdata(), fgarray[np.arange(n_sls), indices] + )[0] + + return distances <= tol From cc725ca50283ce99d468c720a8126ea639361468 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 30 Jan 2026 15:52:21 +0900 Subject: [PATCH 09/86] tighter node thresh --- AFQ/api/bundle_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index be14a633..5044d2e4 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -281,9 +281,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 20}, + "Left Arcuate": {"node_thresh": 10}, "Left Posterior Arcuate": { - "node_thresh": 20, + "node_thresh": 10, "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, @@ -297,9 +297,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 20}, + "Right Arcuate": {"node_thresh": 10}, "Right Posterior Arcuate": { - "node_thresh": 20, + "node_thresh": 10, "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, From 77635cea4e770d9dbf2b621442de93a438e69fbc Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 30 Jan 2026 17:06:08 +0900 Subject: [PATCH 10/86] bring back parietal endpoint ROIs --- AFQ/api/bundle_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 5044d2e4..d670086b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -280,6 +280,7 @@ def default_bd(): "Left Vertical Occipital": { "cross_midline": False, "space": "template", + "start": templates["VOF_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"node_thresh": 10}, "Left Posterior Arcuate": { @@ -296,6 +297,7 @@ def default_bd(): "Right Vertical Occipital": { "cross_midline": False, "space": "template", + "start": templates["VOF_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"node_thresh": 10}, "Right Posterior Arcuate": { From bb9b9b7fed0078b6724788c641080563599bcde2 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 10:17:07 +0900 Subject: [PATCH 11/86] Add projection to node threshold --- AFQ/api/bundle_dict.py | 8 ++++---- AFQ/recognition/criteria.py | 6 ++++-- AFQ/recognition/other_bundles.py | 22 +++++++++++++++++++++- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index d670086b..2dceec74 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -280,11 +280,11 @@ def default_bd(): "Left Vertical Occipital": { "cross_midline": False, "space": "template", - "start": templates["VOF_L_start"], "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 10}, + "Left Arcuate": {"node_thresh": 10, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 10, + "project": "L/R", "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, @@ -297,11 +297,11 @@ def default_bd(): "Right Vertical Occipital": { "cross_midline": False, "space": "template", - "start": templates["VOF_R_start"], "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 10}, + "Right Arcuate": {"node_thresh": 10, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 10, + "project": "L/R", "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 5e078006..0514e445 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -307,7 +307,8 @@ def clean_by_other_bundle( other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, - False, + remove=False, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap) @@ -317,7 +318,8 @@ def clean_by_other_bundle( other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, - True, + remove=True, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh) diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 1e94106b..9eaa4694 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -9,7 +9,9 @@ logger = logging.getLogger("AFQ") -def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=False): +def clean_by_overlap( + this_bundle_sls, other_bundle_sls, overlap, img, remove=False, project=None +): """ Cleans a set of streamlines by only keeping (or removing) those with significant overlap with another set of streamlines. @@ -32,6 +34,11 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal removed. If False, streamlines that overlap in more than `overlap` nodes are removed. Default: False. + project : {'A/P', 'I/S', 'L/R', None}, optional + If specified, the overlap calculation is projected along the given axis + before cleaning. For example, 'A/P' projects the streamlines along the + anterior-posterior axis. + Default: None. Returns ------- @@ -56,6 +63,19 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal other_bundle_density_map = dtu.density_map( other_bundle_sls, np.eye(4), img.shape[:3] ) + + if project is not None: + orientation = nib.orientations.aff2axcodes(img.affine) + core_axis = next( + idx for idx, label in enumerate(orientation) if label in project.upper() + ) + + projection = np.sum(other_bundle_density_map, axis=core_axis) + + other_bundle_density_map = np.broadcast_to( + np.expand_dims(projection, axis=core_axis), other_bundle_density_map.shape + ) + fiber_probabilities = dts.values_from_volume( other_bundle_density_map, this_bundle_sls, np.eye(4) ) From bc800835e1f0e34f853752f1bceb60fef9bf98e6 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 15:01:41 +0900 Subject: [PATCH 12/86] for large tractographies, this 5 percent rule may be necessary --- AFQ/api/bundle_dict.py | 8 ++++---- AFQ/recognition/other_bundles.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 2dceec74..6e69995a 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -281,9 +281,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 10, "project": "L/R"}, + "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { - "node_thresh": 10, + "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, @@ -298,9 +298,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 10, "project": "L/R"}, + "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { - "node_thresh": 10, + "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 9eaa4694..e975765a 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -10,7 +10,13 @@ def clean_by_overlap( - this_bundle_sls, other_bundle_sls, overlap, img, remove=False, project=None + this_bundle_sls, + other_bundle_sls, + overlap, + img, + remove=False, + project=None, + other_bundle_min_density=0.05, ): """ Cleans a set of streamlines by only keeping (or removing) those with @@ -39,6 +45,11 @@ def clean_by_overlap( before cleaning. For example, 'A/P' projects the streamlines along the anterior-posterior axis. Default: None. + other_bundle_min_density : float, optional + A threshold to binarize the density map of `other_bundle_sls`. Voxels + with density values above this threshold (as a fraction of the maximum + density) are considered occupied. + Default: 0.05. Returns ------- @@ -64,6 +75,10 @@ def clean_by_overlap( other_bundle_sls, np.eye(4), img.shape[:3] ) + other_bundle_density_map = ( + other_bundle_density_map / other_bundle_density_map.max() + ) > other_bundle_min_density + if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) core_axis = next( From b0e33f2a737e2c9a485358fc91da3280403ec1c0 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 15:52:40 +0900 Subject: [PATCH 13/86] add exclude ROI to pAF --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/recognition/other_bundles.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 6e69995a..e5aedad7 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -256,7 +256,7 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"]], + "exclude": [templates["SLF_roi1_L"], templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], @@ -268,7 +268,7 @@ def default_bd(): "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"]], + "exclude": [templates["SLF_roi1_R"], templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index e975765a..3b04cee0 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -75,9 +75,10 @@ def clean_by_overlap( other_bundle_sls, np.eye(4), img.shape[:3] ) - other_bundle_density_map = ( - other_bundle_density_map / other_bundle_density_map.max() - ) > other_bundle_min_density + if remove: + other_bundle_density_map = ( + other_bundle_density_map / other_bundle_density_map.max() + ) > other_bundle_min_density if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) From bf37ca4f006dc2dfae64260954700ddcca06d882 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Feb 2026 16:58:11 +0900 Subject: [PATCH 14/86] update montage code --- AFQ/api/participant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 87f6b08c..d28672b3 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -277,7 +277,7 @@ def participant_montage(self, images_per_row=2): bundle_dict = self.export("bundle_dict") self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.export(self.export("best_scalar")) + best_scalar = self.kwargs["best_scalar"] t1 = nib.load(self.export("t1_masked")) size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): From 13740dd36923b18c9097517fbf8c23349fa74fbd Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 09:35:16 +0900 Subject: [PATCH 15/86] fix participant montage --- AFQ/api/participant.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index d28672b3..95c73e4e 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -5,6 +5,8 @@ from time import time import nibabel as nib +import numpy as np +from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm @@ -277,8 +279,9 @@ def participant_montage(self, images_per_row=2): bundle_dict = self.export("bundle_dict") self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.kwargs["best_scalar"] t1 = nib.load(self.export("t1_masked")) + best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) + best_scalar = resample(best_scalar, t1) size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): flip_axes = [False, False, False] @@ -286,12 +289,12 @@ def participant_montage(self, images_per_row=2): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 figure = viz_backend.visualize_volume( - t1, flip_axes=flip_axes, interact=False, inline=False + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False ) figure = viz_backend.visualize_bundles( self.export("bundles"), - affine=t1.affine, - shade_by_volume=best_scalar, + img=t1, + shade_by_volume=best_scalar.get_fdata(), color_by_direction=True, flip_axes=flip_axes, bundle=bundle_name, From 78f8c176c744324b6d6a01b4d98e8d906c500ebb Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 09:52:05 +0900 Subject: [PATCH 16/86] improve participant montage --- AFQ/api/participant.py | 119 ++++++++++++++++++++++------------------- AFQ/viz/utils.py | 44 +++++++-------- 2 files changed, 85 insertions(+), 78 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 95c73e4e..4a10818f 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -5,7 +5,7 @@ from time import time import nibabel as nib -import numpy as np +from math import radians from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm @@ -299,40 +299,46 @@ def participant_montage(self, images_per_row=2): flip_axes=flip_axes, bundle=bundle_name, figure=figure, + n_points=40, interact=False, inline=False, ) - view, direc = BEST_BUNDLE_ORIENTATIONS.get(bundle_name, ("Axial", "Top")) - eye = get_eye(view, direc) - - this_fname = tdir + f"/t{ii}.png" - if "plotly" in viz_backend.backend: - figure.update_layout( - scene_camera=dict( - projection=dict(type="orthographic"), - up={"x": 0, "y": 0, "z": 1}, - eye=eye, - center=dict(x=0, y=0, z=0), - ), - showlegend=False, - ) - figure.write_image(this_fname, scale=4) - - # temporary fix for memory leak - import plotly.io as pio - - pio.kaleido.scope._shutdown_kaleido() - else: - from fury import window - - from AFQ.viz.fury_backend import scene_rotate_forward + for jj, view in enumerate(["Sagittal", "Coronal", "Axial"]): + direc = BEST_BUNDLE_ORIENTATIONS.get( + bundle_name, ("Left", "Front", "Top") + )[jj] + + eye = get_eye(view, direc) + + this_fname = tdir + f"/t{ii}_{view}.png" + if "plotly" in viz_backend.backend: + figure.update_layout( + scene_camera=dict( + projection=dict(type="orthographic"), + up={"x": 0, "y": 0, "z": 1}, + eye=eye, + center=dict(x=0, y=0, z=0), + ), + showlegend=False, + ) + figure.write_image(this_fname, scale=4) + else: + from fury import window - show_m = window.ShowManager( - scene=figure, window_type="offscreen", size=(600, 600) - ) - scene_rotate_forward(show_m, figure) - show_m.snapshot(this_fname) + show_m = window.ShowManager( + scene=figure, window_type="offscreen", size=(600, 600) + ) + window.update_camera(show_m.screens[0].camera, None, figure) + if view == "Coronal": + show_m.screens[0].controller.rotate((0, radians(-eye["y"] * 90)), None) + elif view == "Axial": + show_m.screens[0].controller.rotate((radians(eye["z"] * 90), 0, 0), None) + elif view == "Sagittal": + pass + show_m.render() + show_m.window.draw() + show_m.snapshot(this_fname) def _save_file(curr_img): save_path = op.abspath( @@ -345,33 +351,34 @@ def _save_file(curr_img): max_height = 0 max_width = 0 for ii, bundle_name in enumerate(bundle_dict): - this_img = Image.open(tdir + f"/t{ii}.png") - try: - this_img_trimmed[ii] = trim(this_img) - except IndexError: # this_img is a picture of nothing - this_img_trimmed[ii] = this_img - - text_sz = 70 - width, height = this_img_trimmed[ii].size - height = height + text_sz - result = Image.new( - this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) - ) - result.paste(this_img_trimmed[ii], (0, text_sz)) - this_img_trimmed[ii] = result - - draw = ImageDraw.Draw(this_img_trimmed[ii]) - draw.text( - (0, 0), - bundle_name, - (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), - ) + for view in ["Axial", "Coronal", "Sagittal"]: + this_img = Image.open(tdir + f"/t{ii}_{view}.png") + try: + this_img_trimmed[ii] = trim(this_img) + except IndexError: # this_img is a picture of nothing + this_img_trimmed[ii] = this_img + + text_sz = 70 + width, height = this_img_trimmed[ii].size + height = height + text_sz + result = Image.new( + this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) + ) + result.paste(this_img_trimmed[ii], (0, text_sz)) + this_img_trimmed[ii] = result + + draw = ImageDraw.Draw(this_img_trimmed[ii]) + draw.text( + (0, 0), + bundle_name, + (0, 0, 0), + font=ImageFont.truetype("Arial", text_sz), + ) - if this_img_trimmed[ii].size[0] > max_width: - max_width = this_img_trimmed[ii].size[0] - if this_img_trimmed[ii].size[1] > max_height: - max_height = this_img_trimmed[ii].size[1] + if this_img_trimmed[ii].size[0] > max_width: + max_width = this_img_trimmed[ii].size[0] + if this_img_trimmed[ii].size[1] > max_height: + max_height = this_img_trimmed[ii].size[1] curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index def69151..4ffe316a 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -150,28 +150,28 @@ RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"] BEST_BUNDLE_ORIENTATIONS = { - "Left Anterior Thalamic": ("Sagittal", "Left"), - "Right Anterior Thalamic": ("Sagittal", "Right"), - "Left Corticospinal": ("Sagittal", "Left"), - "Right Corticospinal": ("Sagittal", "Right"), - "Left Cingulum Cingulate": ("Sagittal", "Left"), - "Right Cingulum Cingulate": ("Sagittal", "Right"), - "Forceps Minor": ("Axial", "Top"), - "Forceps Major": ("Axial", "Top"), - "Left Inferior Fronto-occipital": ("Sagittal", "Left"), - "Right Inferior Fronto-occipital": ("Sagittal", "Right"), - "Left Inferior Longitudinal": ("Sagittal", "Left"), - "Right Inferior Longitudinal": ("Sagittal", "Right"), - "Left Superior Longitudinal": ("Axial", "Top"), - "Right Superior Longitudinal": ("Axial", "Top"), - "Left Uncinate": ("Axial", "Bottom"), - "Right Uncinate": ("Axial", "Bottom"), - "Left Arcuate": ("Sagittal", "Left"), - "Right Arcuate": ("Sagittal", "Right"), - "Left Vertical Occipital": ("Coronal", "Back"), - "Right Vertical Occipital": ("Coronal", "Back"), - "Left Posterior Arcuate": ("Coronal", "Back"), - "Right Posterior Arcuate": ("Coronal", "Back"), + "Left Anterior Thalamic": ("Left", "Front", "Top"), + "Right Anterior Thalamic": ("Right", "Front", "Top"), + "Left Corticospinal": ("Left", "Front", "Top"), + "Right Corticospinal": ("Right", "Front", "Top"), + "Left Cingulum Cingulate": ("Left", "Front", "Top"), + "Right Cingulum Cingulate": ("Right", "Front", "Top"), + "Forceps Minor": ("Left", "Front", "Top"), + "Forceps Major": ("Left", "Back", "Top"), + "Left Inferior Fronto-occipital": ("Left", "Front", "Bottom"), + "Right Inferior Fronto-occipital": ("Right", "Front", "Bottom"), + "Left Inferior Longitudinal": ("Left", "Front", "Bottom"), + "Right Inferior Longitudinal": ("Right", "Front", "Bottom"), + "Left Superior Longitudinal": ("Left", "Front", "Top"), + "Right Superior Longitudinal": ("Right", "Front", "Top"), + "Left Uncinate": ("Left", "Front", "Bottom"), + "Right Uncinate": ("Right", "Front", "Bottom"), + "Left Arcuate": ("Left", "Front", "Top"), + "Right Arcuate": ("Right", "Front", "Top"), + "Left Vertical Occipital": ("Left", "Back", "Top"), + "Right Vertical Occipital": ("Right", "Back", "Top"), + "Left Posterior Arcuate": ("Left", "Back", "Top"), + "Right Posterior Arcuate": ("Right", "Back", "Top"), } From 3d582c66c577304f4fdeae1897d7fc39e7658791 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 10:13:44 +0900 Subject: [PATCH 17/86] More participant montage improvements --- AFQ/api/participant.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 4a10818f..1aed1b8d 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -259,7 +259,7 @@ def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=2): + def participant_montage(self, images_per_row=3): """ Generate montage of all bundles for a given subject. @@ -267,7 +267,7 @@ def participant_montage(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. - Default: 2 + Default: 3 Returns ------- @@ -282,7 +282,7 @@ def participant_montage(self, images_per_row=2): t1 = nib.load(self.export("t1_masked")) best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) best_scalar = resample(best_scalar, t1) - size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) + size = (images_per_row, math.ceil(3 * len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): flip_axes = [False, False, False] for i in range(3): @@ -350,15 +350,16 @@ def _save_file(curr_img): this_img_trimmed = {} max_height = 0 max_width = 0 - for ii, bundle_name in enumerate(bundle_dict): + ii = 0 + for b_idx, bundle_name in enumerate(bundle_dict): for view in ["Axial", "Coronal", "Sagittal"]: - this_img = Image.open(tdir + f"/t{ii}_{view}.png") + this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") try: this_img_trimmed[ii] = trim(this_img) except IndexError: # this_img is a picture of nothing this_img_trimmed[ii] = this_img - text_sz = 70 + text_sz = 40 width, height = this_img_trimmed[ii].size height = height + text_sz result = Image.new( @@ -370,26 +371,27 @@ def _save_file(curr_img): draw = ImageDraw.Draw(this_img_trimmed[ii]) draw.text( (0, 0), - bundle_name, + f"{bundle_name} - {view}", (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), + font=ImageFont.load_default(text_sz), ) if this_img_trimmed[ii].size[0] > max_width: max_width = this_img_trimmed[ii].size[0] if this_img_trimmed[ii].size[1] > max_height: max_height = this_img_trimmed[ii].size[1] + ii += 1 curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" ) - for ii in range(len(bundle_dict)): - x_pos = ii % size[0] - _ii = ii // size[0] + for jj in range(ii): + x_pos = jj % size[0] + _ii = jj // size[0] y_pos = _ii % size[1] _ii = _ii // size[1] - this_img = this_img_trimmed[ii].resize((max_width, max_height)) + this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) _save_file(curr_img) From 2e8ca2e30a0a2316fee96185f70298eaeeb7d4c7 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 10:29:19 +0900 Subject: [PATCH 18/86] add more options to p montage --- AFQ/api/participant.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 1aed1b8d..8b34030b 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -259,7 +259,7 @@ def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=3): + def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None): """ Generate montage of all bundles for a given subject. @@ -269,6 +269,14 @@ def participant_montage(self, images_per_row=3): Number of bundle images per row in output file. Default: 3 + anatomy : bool + Whether to include anatomical images in the montage. + Default: True + + bundle_names : list of str or None + List of bundle names to include in the montage. + Default: None (includes all bundles) + Returns ------- filename of montage images @@ -276,21 +284,26 @@ def participant_montage(self, images_per_row=3): tdir = tempfile.gettempdir() all_fnames = [] - bundle_dict = self.export("bundle_dict") + if bundle_names is None: + bundle_dict = self.export("bundle_dict") + bundle_names = list(bundle_dict.keys()) self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") t1 = nib.load(self.export("t1_masked")) best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) best_scalar = resample(best_scalar, t1) - size = (images_per_row, math.ceil(3 * len(bundle_dict) / images_per_row)) - for ii, bundle_name in enumerate(tqdm(bundle_dict)): + size = (images_per_row, math.ceil(3 * len(bundle_names) / images_per_row)) + for ii, bundle_name in enumerate(tqdm(bundle_names)): flip_axes = [False, False, False] for i in range(3): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 - figure = viz_backend.visualize_volume( - t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False - ) + if anatomy: + figure = viz_backend.visualize_volume( + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False + ) + else: + figure = None figure = viz_backend.visualize_bundles( self.export("bundles"), img=t1, @@ -351,7 +364,7 @@ def _save_file(curr_img): max_height = 0 max_width = 0 ii = 0 - for b_idx, bundle_name in enumerate(bundle_dict): + for b_idx, bundle_name in enumerate(bundle_names): for view in ["Axial", "Coronal", "Sagittal"]: this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") try: From 944509b0f3506ac3e7e10165434277bcb8c872ac Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 6 Feb 2026 22:02:03 +0900 Subject: [PATCH 19/86] remove warnings from segmentedsft --- AFQ/tasks/segmentation.py | 2 +- AFQ/utils/streamlines.py | 19 ++++++++----------- AFQ/utils/tests/test_streamlines.py | 2 +- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 2781778e..79cd64cb 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -113,7 +113,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): **segmentation_params, ) - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) if len(seg_sft.sft) < 1: raise ValueError("Fatal: No bundles recognized.") diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 8c37801c..1bb027c0 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -15,10 +15,9 @@ class SegmentedSFT: - def __init__(self, bundles, space, sidecar_info=None): + def __init__(self, bundles, sidecar_info=None): if sidecar_info is None: sidecar_info = {} - reference = None self.bundle_names = [] sls = [] idxs = {} @@ -26,20 +25,17 @@ def __init__(self, bundles, space, sidecar_info=None): idx_count = 0 for b_name in bundles: if isinstance(bundles[b_name], dict): - this_sls = bundles[b_name]["sl"] + this_sft = bundles[b_name]["sl"] this_tracking_idxs[b_name] = bundles[b_name]["idx"] else: - this_sls = bundles[b_name] - if reference is None: - reference = this_sls - this_sls = list(this_sls.streamlines) + this_sft = bundles[b_name] + this_sls = list(this_sft.streamlines) sls.extend(this_sls) new_idx_count = idx_count + len(this_sls) idxs[b_name] = np.arange(idx_count, new_idx_count, dtype=np.uint32) idx_count = new_idx_count self.bundle_names.append(b_name) - self.sft = StatefulTractogram(sls, reference, space) self.bundle_idxs = idxs if len(this_tracking_idxs) > 1: self.this_tracking_idxs = this_tracking_idxs @@ -48,12 +44,13 @@ def __init__(self, bundles, space, sidecar_info=None): self.sidecar_info = sidecar_info self.sidecar_info["bundle_ids"] = {} - dps = np.zeros(len(self.sft.streamlines)) + dps = np.zeros(len(sls)) for ii, bundle_name in enumerate(self.bundle_names): self.sidecar_info["bundle_ids"][f"{bundle_name}"] = ii + 1 dps[self.bundle_idxs[bundle_name]] = ii + 1 - dps = {"bundle": dps} - self.sft.data_per_streamline = dps + self.sft = StatefulTractogram.from_sft( + sls, this_sft, data_per_streamline={"bundle": dps} + ) if self.this_tracking_idxs is not None: for kk, _vv in self.this_tracking_idxs.items(): self.this_tracking_idxs[kk] = ( diff --git a/AFQ/utils/tests/test_streamlines.py b/AFQ/utils/tests/test_streamlines.py index 39c624c7..8931670f 100644 --- a/AFQ/utils/tests/test_streamlines.py +++ b/AFQ/utils/tests/test_streamlines.py @@ -37,7 +37,7 @@ def test_SegmentedSFT(): ), } - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) for k1 in bundles.keys(): for sl1, sl2 in zip( bundles[k1].streamlines, seg_sft.get_bundle(k1).streamlines From 5d82e9e037614b097a674946f20498f670ab1b94 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 09:54:18 +0900 Subject: [PATCH 20/86] bf --- AFQ/utils/streamlines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 1bb027c0..571534a1 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -105,7 +105,7 @@ def fromfile(cls, trk_or_trx_file, reference="same", sidecar_file=None): else: bundles["whole_brain"] = sft - return cls(bundles, Space.RASMM, sidecar_info) + return cls(bundles, sidecar_info) def split_streamline(streamlines, sl_to_split, split_idx): From 4a6bfe4e6d94b80a63c91666be8ee2f985469637 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 10:36:10 +0900 Subject: [PATCH 21/86] viz bug fix --- AFQ/viz/plotly_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 0076c1b9..1af6acb4 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -512,7 +512,7 @@ def create_gif(figure, file_name, n_frames=30, zoom=2.5, z_offset=0.5, size=(600 def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): - roi = np.where(roi == 1) + roi = np.where(roi > 0) pts = [] for i, flip in enumerate(flip_axes): if flip: From e500d0c561d9599c33598381dad7aa4dd41c628a Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 10:36:36 +0900 Subject: [PATCH 22/86] further restrict pAF --- AFQ/api/bundle_dict.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index e5aedad7..de2b444f 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -256,7 +256,11 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"], templates["IFO_roi1_L"]], + "exclude": [ + templates["SLF_roi1_L"], + templates["IFO_roi1_L"], + templates["ILF_L_end"], + ], "space": "template", "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], @@ -268,7 +272,11 @@ def default_bd(): "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"], templates["IFO_roi1_R"]], + "exclude": [ + templates["SLF_roi1_R"], + templates["IFO_roi1_R"], + templates["ILF_R_end"], + ], "space": "template", "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], From df40e59fee1c8b0bc0e20a90755d7d3e45085631 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 11:34:14 +0900 Subject: [PATCH 23/86] try more constrained pAF/ARC defs --- AFQ/api/bundle_dict.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index de2b444f..aad19247 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -218,7 +218,7 @@ def default_bd(): "Left Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_L"], templates["SLFt_roi2_L"]], - "exclude": [], + "exclude": [templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], @@ -228,7 +228,7 @@ def default_bd(): "Right Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_R"], templates["SLFt_roi2_R"]], - "exclude": [], + "exclude": [templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], @@ -259,11 +259,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_L"], templates["IFO_roi1_L"], - templates["ILF_L_end"], + templates["ILF_roi2_L"], + templates["HCC_roi2_L"], ], "space": "template", - "prob_map": templates["ARC_L_prob_map"], # Better than nothing + "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], + "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", @@ -275,11 +277,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_R"], templates["IFO_roi1_R"], - templates["ILF_R_end"], + templates["ILF_roi2_R"], + templates["HCC_roi2_R"], ], "space": "template", - "prob_map": templates["ARC_R_prob_map"], # Better than nothing + "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], + "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", From 6d80a65a8b56b3b51a01dff76b9591838d390366 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 12:04:29 +0900 Subject: [PATCH 24/86] try this --- AFQ/api/bundle_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index aad19247..8b253ae0 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -267,6 +267,7 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, + "Left Inferior Longitudinal": {"node_thresh": 40}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -285,6 +286,7 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, + "Right Inferior Longitudinal": {"node_thresh": 40}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, From 24c810fa9d9389eb46cc3cf4abbb46eb71416ad2 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 12:21:39 +0900 Subject: [PATCH 25/86] tighten ILF constraint --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 8b253ae0..47e8ce00 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -267,7 +267,7 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, - "Left Inferior Longitudinal": {"node_thresh": 40}, + "Left Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -286,7 +286,7 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, - "Right Inferior Longitudinal": {"node_thresh": 40}, + "Right Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, From 1ba3f614b7f8e5cc04f966d7a4302a2c854092c1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 14:26:44 +0900 Subject: [PATCH 26/86] solve pAF issues with new exclusion ROI --- AFQ/api/bundle_dict.py | 8 ++------ AFQ/data/fetch.py | 6 ++++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 47e8ce00..553ede71 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -259,15 +259,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_L"], templates["IFO_roi1_L"], - templates["ILF_roi2_L"], - templates["HCC_roi2_L"], + templates["pARC_xroi1_L"], ], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, - "Left Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -278,15 +276,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_R"], templates["IFO_roi1_R"], - templates["ILF_roi2_R"], - templates["HCC_roi2_R"], + templates["pARC_xroi1_R"], ], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, - "Right Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index c88fc74e..eae03481 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -759,6 +759,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_R_start.nii.gz", "ATR_L_end.nii.gz", "ATR_L_start.nii.gz", + "pARC_xroi1_L.nii.gz", + "pARC_xroi1_R.nii.gz", ] @@ -861,6 +863,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944074", "40944077", "40944080", + "61737616", + "61737619", ] @@ -964,6 +968,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ffc157e9f73a43eff23821f2cfca614a", "a8d308a93b26242c04b878c733cb252f", "1c0b570bb2d622718b01ee2c429a5d15", + "51c8a6b5fbb0834b03986093b9ee4fa3", + "7cf5800a4efa6bac7e70d84095bc259b", ] fetch_templates = _make_reusable_fetcher( From 1a892abd9bcc15f75942836c0277c017c4d0be6f Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 16:34:22 +0900 Subject: [PATCH 27/86] return to strict VOF seg --- AFQ/api/bundle_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 553ede71..dc96824a 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,8 +297,8 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Left Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, + "Left Inferior Longitudinal": {"core": "Right"}, + "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", @@ -314,8 +314,8 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, + "Right Inferior Longitudinal": {"core": "Left"}, + "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", From 43fddd2d521ed1d082827f0791b7d758a971b224 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 10:18:51 +0900 Subject: [PATCH 28/86] return to stricter cleaning --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index dc96824a..917f29ee 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -298,7 +298,7 @@ def default_bd(): "entire_core": "Anterior", }, "Left Inferior Longitudinal": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", @@ -315,7 +315,7 @@ def default_bd(): "entire_core": "Anterior", }, "Right Inferior Longitudinal": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", From 2746f1b69ef1e43b30865096c74b97ad82218a45 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 11:40:41 +0900 Subject: [PATCH 29/86] cleaning by other core requires higher levels of precision --- AFQ/recognition/criteria.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 0514e445..c0086c60 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -327,7 +327,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, ) @@ -337,7 +337,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) From d40464b874e0b4e1f073854781e4e946bdf345e7 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 13:47:44 +0900 Subject: [PATCH 30/86] bf --- AFQ/recognition/criteria.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c0086c60..eaa224b3 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -327,6 +327,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + # the extra specificity of 100 points is needed np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, @@ -337,7 +338,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 20)), img.affine, True, ) From ee49598f1d921cc3fdadafaaa47641184eb80842 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 15:53:26 +0900 Subject: [PATCH 31/86] maybe we can do this after clustering --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 917f29ee..c5102f50 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,7 +297,7 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Left Inferior Longitudinal": {"core": "Right"}, + # "Left Inferior Longitudinal": {"core": "Right"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, @@ -314,7 +314,7 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Right Inferior Longitudinal": {"core": "Left"}, + # "Right Inferior Longitudinal": {"core": "Left"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, From adb53109a389b5a835889365dd6101c79877c9ba Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 12:12:44 +0900 Subject: [PATCH 32/86] Major registration overhaul --- AFQ/_fixes.py | 71 ++++++ AFQ/api/bundle_dict.py | 34 +-- AFQ/api/group.py | 11 +- AFQ/data/fetch.py | 109 ++++++++ AFQ/definitions/image.py | 2 +- AFQ/definitions/mapping.py | 232 ++++++++---------- AFQ/recognition/recognize.py | 16 +- AFQ/recognition/tests/test_recognition.py | 2 +- AFQ/recognition/utils.py | 8 +- AFQ/registration.py | 209 +++++++--------- AFQ/tasks/mapping.py | 6 +- AFQ/tasks/viz.py | 18 +- AFQ/tests/test_api.py | 5 - AFQ/tests/test_registration.py | 43 ++-- AFQ/utils/volume.py | 6 +- AFQ/viz/fury_backend.py | 28 +-- AFQ/viz/plotly_backend.py | 28 +-- AFQ/viz/utils.py | 65 ++--- .../plot_001_group_afq_api.py | 6 +- 19 files changed, 439 insertions(+), 460 deletions(-) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 9a62f20a..ab610cbf 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -4,6 +4,8 @@ from math import radians import numpy as np +from dipy.align import vector_fields as vfu +from dipy.align.imwarp import DiffeomorphicMap, mult_aff from dipy.data import default_sphere from dipy.reconst.gqi import squared_radial_component from dipy.tracking.streamline import set_number_of_points @@ -15,6 +17,75 @@ logger = logging.getLogger("AFQ") +def get_simplified_transform(self): + """Constructs a simplified version of this Diffeomorhic Map + + The simplified version incorporates the pre-align transform, as well as + the domain and codomain affine transforms into the displacement field. + The resulting transformation may be regarded as operating on the + image spaces given by the domain and codomain discretization. As a + result, self.prealign, self.disp_grid2world, self.domain_grid2world and + self.codomain affine will be None (denoting Identity) in the resulting + diffeomorphic map. + """ + if self.dim == 2: + simplify_f = vfu.simplify_warp_function_2d + else: + simplify_f = vfu.simplify_warp_function_3d + # Simplify the forward transform + D = self.domain_grid2world + P = self.prealign + Rinv = self.disp_world2grid + Cinv = self.codomain_world2grid + + # this is the matrix which we need to multiply the voxel coordinates + # to interpolate on the forward displacement field ("in"side the + # 'forward' brackets in the expression above) + affine_idx_in = mult_aff(Rinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the voxel coordinates + # to add to the displacement ("out"side the 'forward' brackets in the + # expression above) + affine_idx_out = mult_aff(Cinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the displacement vector + # prior to adding to the transformed input point + affine_disp = Cinv + + new_forward = simplify_f( + self.forward, affine_idx_in, affine_idx_out, affine_disp, self.domain_shape + ) + + # Simplify the backward transform + C = self.codomain_grid2world + Pinv = self.prealign_inv + Dinv = self.domain_world2grid + + affine_idx_in = mult_aff(Rinv, C) + affine_idx_out = mult_aff(Dinv, mult_aff(Pinv, C)) + affine_disp = mult_aff(Dinv, Pinv) + new_backward = simplify_f( + self.backward, + affine_idx_in, + affine_idx_out, + affine_disp, + self.codomain_shape, + ) + simplified = DiffeomorphicMap( + dim=self.dim, + disp_shape=self.disp_shape, + disp_grid2world=None, + domain_shape=self.domain_shape, + domain_grid2world=None, + codomain_shape=self.codomain_shape, + codomain_grid2world=None, + prealign=None, + ) + simplified.forward = new_forward + simplified.backward = new_backward + return simplified + + def gwi_odf(gqmodel, data): gqi_vector = np.real( squared_radial_component( diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index c5102f50..2c7497bb 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -334,61 +334,31 @@ def slf_bd(): "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, }, citations={"Sagi2024"}, @@ -1337,9 +1307,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): else: boolean_ = False - warped_img = auv.transform_inverse_roi( - fdata, mapping, bundle_name=bundle_name - ) + warped_img = auv.transform_roi(fdata, mapping, bundle_name=bundle_name) if boolean_: warped_img = warped_img.astype(np.uint8) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 1e4c7c49..d1f3bef8 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -8,7 +8,6 @@ import warnings from time import time -import dipy.tracking.streamline as dts import dipy.tracking.streamlinespeed as dps import nibabel as nib import numpy as np @@ -584,12 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - delta = dts.values_from_volume( - mapping.forward, tg.streamlines, np.eye(4) - ) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)] - ) + moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1026,8 +1020,7 @@ def combine_bundle(self, bundle_name): mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - delta = dts.values_from_volume(mapping.forward, sls, np.eye(4)) - sls_mni.extend([d + s for d, s in zip(delta, sls)]) + sls_mni = mapping.tranform_points(sls) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index eae03481..b40f34ad 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1095,6 +1095,115 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +org800_fnames = [ + "ORG_atlas_tracks_reoriented.trx", + "ORG800_atlas_centroids.npy", + "ORG800_atlas_e_val.npy", + "ORG800_atlas_e_vec_norm.npy", + "ORG800_atlas_e_vec.npy", + "ORG800_atlas_number_of_eigenvectors.npy", + "ORG800_atlas_row_sum_1.npy", + "ORG800_atlas_row_sum_matrix.npy", + "ORG800_atlas_sigma.npy", +] + + +org800_remote_fnames = [ + "61762231", + "61762267", + "61762270", + "61762273", + "61762276", + "61762279", + "61762282", + "61762285", + "61762288", +] + + +org800_md5_hashes = [ + "9022799a73359209080ea832b22ec09b", + "09bfa384f5c44801dfa382d31392a979", + "bab61eb26cb21035e38b5f68b5fdad3e", + "9325f4cb168624d4f275785b18c9f859", + "12d426c5a6fcfbe3b8146bc335bdac96", + "db74e055c47c5b6354c3cb7bbf165f2c", + "e7e51f53b30764f104b93f50d94b6c3c", + "7e894c57a820cd7604a1db6b7ab8cce6", + "1e195b7055e98eb473bbb5af05d48f7d", +] + +fetch_org800_templates = _make_reusable_fetcher( + "fetch_org800_templates", + op.join(afq_home, "org800_templates"), + baseurl, + org800_remote_fnames, + org800_fnames, + md5_list=org800_md5_hashes, + doc="Download AFQ org800 templates", +) + + +def read_org800_templates(load_npy=True, load_trx=True): + """ + Load O'Donnell Research Group (ORG) Fiber Clustering White + Matter Atlas 800 modified for pyAFQ templates from file + + Parameters + ---------- + load_npy : bool, optional + If True, values are loaded as numpy arrays. Otherwise, values are + paths to npy files. Default: True + load_trx : bool, optional + If True, the tractogram is loaded as a StatefulTractogram. Otherwise, + the value is the path to the trx file. Default: True + + Returns + ------- + dict with: keys: names of atlas info + values: Floats, arrays, and StatefulTractogram for the atlas. + Any unloaded will instead be paths. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading org800 templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template(fetch_org800_templates) + + if load_trx: + template_dict["tracks_reoriented"] = load_tractogram( + template_dict.pop("ORG_atlas_tracks_reoriented"), + "same", + ) + if load_npy: + template_dict["centroids"] = np.load( + template_dict.pop("ORG800_atlas_centroids") + ) + template_dict["e_val"] = np.load(template_dict.pop("ORG800_atlas_e_val")) + template_dict["e_vec_norm"] = np.load( + template_dict.pop("ORG800_atlas_e_vec_norm") + ) + template_dict["e_vec"] = np.load(template_dict.pop("ORG800_atlas_e_vec")) + template_dict["number_of_eigenvectors"] = float( + np.load(template_dict.pop("ORG800_atlas_number_of_eigenvectors")) + ) + template_dict["row_sum_1"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_1") + ) + template_dict["row_sum_matrix"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_matrix") + ) + template_dict["sigma"] = float(np.load(template_dict.pop("ORG800_atlas_sigma"))) + + toc = time.perf_counter() + logger.debug( + f"O'Donnell Research Group 800 templates loaded in {toc - tic:0.4f} seconds" + ) + + return template_dict + + massp_fnames = [ "left_VTA.nii.gz", "right_VTA.nii.gz", diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index df5df2b4..4d68fb8b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -932,7 +932,7 @@ def _image_getter_helper(mapping, reg_template, reg_subject): static_affine=reg_template.affine, ).get_fdata() - scalar_data = mapping.transform_inverse(img_data, interpolation="nearest") + scalar_data = mapping.transform(img_data, interpolation="nearest") return nib.Nifti1Image( scalar_data.astype(np.float32), reg_subject.affine ), dict(source=self.path) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index dfcedef7..17171faf 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -5,9 +5,10 @@ import nibabel as nib import numpy as np from dipy.align import affine_registration, syn_registration -from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr import AFQ.registration as reg +from AFQ._fixes import get_simplified_transform from AFQ.definitions.utils import Definition, find_file from AFQ.tasks.utils import get_fname from AFQ.utils.path import space_from_fname, write_json @@ -177,11 +178,11 @@ def __init__(self, warp, ref_affine): self.ref_affine = ref_affine self.warp = warp - def transform_inverse(self, data, **kwargs): + def transform(self, data, **kwargs): data_img = Image(nib.Nifti1Image(data.astype(np.float32), self.ref_affine)) return np.asarray(applyDeformation(data_img, self.warp).data) - def transform_inverse_pts(self, pts): + def transform_pts(self, pts): # This should only be used for curvature analysis, # Because I think the results still need to be shifted pts = nib.affines.apply_affine(self.warp.src.getAffine("voxel", "world"), pts) @@ -189,39 +190,13 @@ def transform_inverse_pts(self, pts): pts = self.warp.transform(pts, "fsl", "world") return pts - def transform(self, data, **kwargs): + def transform_inverse(self, data, **kwargs): raise NotImplementedError( "Fnirt based mappings can currently" + " only transform from template to subject space" ) -class IdentityMap(Definition): - """ - Does not perform any transformations from MNI to subject where - pyAFQ normally would. - - Examples - -------- - my_example_mapping = IdentityMap() - api.GroupAFQ(mapping=my_example_mapping) - """ - - def __init__(self): - pass - - def get_for_subses( - self, base_fname, dwi, dwi_data_file, reg_subject, reg_template, tmpl_name - ): - return ConformedAffineMapping( - np.identity(4), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) - - class GeneratedMapMixin(object): """ Helper Class @@ -236,24 +211,16 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): mapping_file = mapping_file + extension return mapping_file, meta_fname - def prealign( - self, base_fname, sub_name, tmpl_name, reg_subject, reg_template, save=True - ): - prealign_file_desc = f"_desc-prealign_from-{sub_name}_to-{tmpl_name}_xform" - prealign_file = get_fname(base_fname, f"{prealign_file_desc}.npy") - if not op.exists(prealign_file): - start_time = time() - _, aff = affine_registration( - reg_subject, reg_template, **self.affine_kwargs - ) - meta = dict(type="rigid", dependent="dwi", timing=time() - start_time) - if not save: - return aff - logger.info(f"Saving {prealign_file}") - np.save(prealign_file, aff) - meta_fname = get_fname(base_fname, f"{prealign_file_desc}.json") - write_json(meta_fname, meta) - return prealign_file if save else np.load(prealign_file) + def prealign(self, reg_subject, reg_template): + _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) + return aff + + +class AffineMapMixin(GeneratedMapMixin): + """ + Helper Class + Useful for maps that are generated by pyAFQ + """ def get_for_subses( self, @@ -268,34 +235,22 @@ def get_for_subses( ): sub_space = space_from_fname(dwi_data_file) mapping_file, meta_fname = self.get_fnames( - self.extension, base_fname, sub_space, tmpl_name + ".npy", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - reg_prealign = np.load( - self.prealign( - base_fname, sub_space, tmpl_name, reg_subject, reg_template - ) - ) - else: - reg_prealign = None if not op.exists(mapping_file): start_time = time() - mapping = self.gen_mapping( - base_fname, - sub_space, - tmpl_name, - reg_subject, + self.gen_mapping( reg_template, + reg_subject, subject_sls, template_sls, - reg_prealign, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - reg.write_mapping(mapping, mapping_file) - meta = dict(type="displacementfield", timing=total_time) + np.save(mapping_file, mapping.affine) + meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" else: @@ -305,10 +260,7 @@ def get_for_subses( if isinstance(reg_template, str): meta["reg_template"] = reg_template write_json(meta_fname, meta) - reg_prealign_inv = np.linalg.inv(reg_prealign) if self.use_prealign else None - mapping = reg.read_mapping( - mapping_file, dwi, reg_template, prealign=reg_prealign_inv - ) + mapping = reg.read_affine_mapping(mapping_file, dwi, reg_template) return mapping @@ -353,33 +305,69 @@ def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None): self.use_prealign = use_prealign self.affine_kwargs = affine_kwargs self.syn_kwargs = syn_kwargs - self.extension = ".nii.gz" - def gen_mapping( + def get_for_subses( self, base_fname, - sub_space, - tmpl_name, + dwi, + dwi_data_file, reg_subject, reg_template, - subject_sls, - template_sls, - reg_prealign, + tmpl_name, + subject_sls=None, + template_sls=None, ): - _, mapping = syn_registration( - reg_subject.get_fdata(), - reg_template.get_fdata(), - moving_affine=reg_subject.affine, - static_affine=reg_template.affine, - prealign=reg_prealign, - **self.syn_kwargs, + sub_space = space_from_fname(dwi_data_file) + mapping_file_forward, meta_forward_fname = self.get_fnames( + ".nii.gz", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - mapping.codomain_world2grid = np.linalg.inv(reg_prealign) + mapping_file_backward, meta_backward_fname = self.get_fnames( + ".nii.gz", base_fname, tmpl_name, sub_space + ) + + if not op.exists(mapping_file_forward) or not op.exists(mapping_file_backward): + meta = dict(type="displacementfield") + meta["dependent"] = "dwi" + if isinstance(reg_subject, str): + meta["reg_subject"] = reg_subject + if isinstance(reg_template, str): + meta["reg_template"] = reg_template + + start_time = time() + if self.use_prealign: + reg_prealign = self.prealign(reg_subject, reg_template) + else: + reg_prealign = None + _, mapping = syn_registration( + reg_subject.get_fdata(), + reg_template.get_fdata(), + moving_affine=reg_subject.affine, + static_affine=reg_template.affine, + prealign=np.linalg.inv(reg_prealign), + **self.syn_kwargs, + ) + mapping = get_simplified_transform(mapping) + + total_time = time() - start_time + meta["total_time"] = total_time + + logger.info(f"Saving {mapping_file_forward}") + nib.save( + nib.Nifti1Image(mapping.forward, reg_subject.affine), + mapping_file_forward, + ) + write_json(meta_forward_fname, meta) + logger.info(f"Saving {mapping_file_backward}") + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, + ) + write_json(meta_backward_fname, meta) + mapping = reg.read_syn_mapping(mapping_file_forward, mapping_file_backward) return mapping -class SlrMap(GeneratedMapMixin, Definition): +class SlrMap(AffineMapMixin, Definition): """ Calculate a SLR registration for each subject/session using reg_subject and reg_template. @@ -407,33 +395,23 @@ class SlrMap(GeneratedMapMixin, Definition): def __init__(self, slr_kwargs=None): if slr_kwargs is None: slr_kwargs = {} - self.slr_kwargs = {} - self.use_prealign = False - self.extension = ".npy" + self.slr_kwargs = slr_kwargs def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_template, reg_subject, subject_sls, template_sls, - reg_prealign, ): - return reg.slr_registration( - subject_sls, - template_sls, - moving_affine=reg_subject.affine, - moving_shape=reg_subject.shape, - static_affine=reg_template.affine, - static_shape=reg_template.shape, - **self.slr_kwargs, + _, transform, _, _ = whole_brain_slr( + subject_sls, template_sls, x0="affine", verbose=False, **self.slr_kwargs ) + return transform + -class AffMap(GeneratedMapMixin, Definition): +class AffMap(AffineMapMixin, Definition): """ Calculate an affine registration for each subject/session using reg_subject and reg_template. @@ -457,47 +435,39 @@ class AffMap(GeneratedMapMixin, Definition): def __init__(self, affine_kwargs=None): if affine_kwargs is None: affine_kwargs = {} - self.use_prealign = False self.affine_kwargs = affine_kwargs - self.extension = ".npy" def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ): - return ConformedAffineMapping( - np.linalg.inv( - self.prealign( - base_fname, - sub_space, - tmpl_name, - reg_subject, - reg_template, - save=False, - ) - ), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) + return np.linalg.inv( + self.prealign(reg_subject, reg_template) + ) # TODO: test: this still needs to be inverted? -class ConformedAffineMapping(AffineMap): +class IdentityMap(AffineMapMixin, Definition): """ - Modifies AffineMap API to match DiffeomorphicMap API. - Important for SLR maps API to be indistinguishable from SYN maps API. + Does not perform any transformations from MNI to subject where + pyAFQ normally would. + + Examples + -------- + my_example_mapping = IdentityMap() + api.GroupAFQ(mapping=my_example_mapping) """ - def transform(self, *args, **kwargs): - return super().transform_inverse(*args, **kwargs) + def __init__(self): + pass - def transform_inverse(self, *args, **kwargs): - return super().transform(*args, **kwargs) + def gen_mapping( + self, + reg_subject, + reg_template, + subject_sls, + template_sls, + ): + return np.identity(4) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 1727bd0f..89b0ae24 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -155,10 +155,10 @@ def recognize( tg.to_vox() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) + bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.float32) bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.uint32 + (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 ) fiber_groups = {} @@ -205,10 +205,20 @@ def recognize( "Conflicts in bundle assignment detected. " f"{conflicts} conflicts detected in total out of " f"{n_streamlines} total streamlines. " - "Defaulting to whichever bundle appears first " + "Defaulting to whichever bundle is closest to the include ROI," + "followed by whichever appears first " "in the bundle_dict." ) ) + + # Weight by distance to ROI + valid_dists = bundle_roi_closest != -1 + dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + has_any_valid_roi = np.any(valid_dists, axis=2) + max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) + final_mask = (bundle_decisions > 0) & has_any_valid_roi + bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) + bundle_decisions = np.concatenate( (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 ) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index ef513f14..3fdc0f52 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -22,7 +22,7 @@ hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() reg_template = afd.read_mni_template() -mapping = reg.read_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) +mapping = reg.read_old_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) streamlines = file_dict["tractography_subsampled.trk"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) tg.to_vox() diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 505f8300..8258dea9 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -129,15 +129,13 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): tg.to_vox() moved_sl = [] for sl in tg.streamlines: - moved_sl.append(mapping.transform_inverse_pts(sl)) + moved_sl.append(mapping.transform_pts(sl)) else: tg.to_rasmm() if to == "template": - volume = mapping.forward + moved_sl = mapping.transform_points_inverse(tg.streamlines) else: - volume = mapping.backward - delta = dts.values_from_volume(volume, tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines([d + s for d, s in zip(delta, tg.streamlines)]) + moved_sl = mapping.transform_points(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) if save_intermediates is not None: save_tractogram( diff --git a/AFQ/registration.py b/AFQ/registration.py index f63e784b..7a61c907 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -4,11 +4,14 @@ import nibabel as nib import numpy as np -from dipy.align import syn_registration +from dipy.align.imaffine import AffineMap from dipy.align.imwarp import DiffeomorphicMap -from dipy.align.streamlinear import whole_brain_slr -__all__ = ["syn_register_dwi", "write_mapping", "read_mapping", "slr_registration"] +__all__ = [ + "read_affine_mapping", + "read_syn_mapping", + "read_old_mapping", +] def reduce_shape(shape): @@ -21,78 +24,55 @@ def reduce_shape(shape): return shape -def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs): +def read_syn_mapping(disp, codisp): """ - Register DWI data to a template. + Read a syn registration mapping from a nifti file Parameters - ----------- - dwi : nifti image or str - Image containing DWI data, or full path to a nifti file with DWI. - gtab : GradientTable - The gradients associated with the DWI data - template : nifti image or str, optional + ---------- + disp : str or Nifti1Image + If string, file must of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from subject to template - syn_kwargs : key-word arguments for :func:`syn_registration` + codisp : str or Nifti1Image + If string, file must of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from template to subject Returns ------- - DiffeomorphicMap object + A :class:`DiffeomorphicMap` object """ - if template is None: - import AFQ.data.fetch as afd - - template = afd.read_mni_template() - if isinstance(template, str): - template = nib.load(template) - - template_data = template.get_fdata() - template_affine = template.affine - - if isinstance(dwi, str): - dwi = nib.load(dwi) - - dwi_affine = dwi.affine - dwi_data = dwi.get_fdata() - mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1) - warped_b0, mapping = syn_registration( - mean_b0, - template_data, - moving_affine=dwi_affine, - static_affine=template_affine, - **syn_kwargs, + if isinstance(disp, str): + disp = nib.load(disp) + + if isinstance(codisp, str): + codisp = nib.load(codisp) + + mapping = DiffeomorphicMap( + dim=3, + disp_shape=codisp.get_fdata().shape[:3], + disp_grid2world=None, + domain_shape=disp.get_fdata().shape[:3], + domain_grid2world=None, + codomain_shape=codisp.get_fdata().shape[:3], + codomain_grid2world=None, ) - return warped_b0, mapping - - -def write_mapping(mapping, fname): - """ - Write out a syn registration mapping to file + mapping.forward = disp.get_fdata().astype(np.float32) + mapping.backward = codisp.get_fdata().astype(np.float32) - Parameters - ---------- - mapping : a DiffeomorphicMap object derived from :func:`syn_registration` - fname : str - Full path to the nifti file storing the mapping - - """ - if isinstance(mapping, DiffeomorphicMap): - mapping_imap = np.array([mapping.forward.T, mapping.backward.T]).T - nib.save(nib.Nifti1Image(mapping_imap, mapping.codomain_world2grid), fname) - else: - np.save(fname, mapping.affine) + return mapping -def read_mapping(disp, domain_img, codomain_img, prealign=None): +def read_affine_mapping(affine, domain_img, codomain_img): """ Read a syn registration mapping from a nifti file Parameters ---------- - disp : str, Nifti1Image, or ndarray - If string, file must of an image or ndarray. - If image, contains the mapping displacement field in each voxel - Shape (x, y, z, 3, 2) + affine : str or ndarray + If string, file must of an ndarray. If ndarray, contains affine transformation used for mapping domain_img : str or Nifti1Image @@ -101,13 +81,10 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): Returns ------- - A :class:`DiffeomorphicMap` object + A :class:`AffineMap` object """ - if isinstance(disp, str): - if "nii.gz" in disp: - disp = nib.load(disp) - else: - disp = np.load(disp) + if isinstance(affine, str): + affine = np.load(affine) if isinstance(domain_img, str): domain_img = nib.load(domain_img) @@ -115,79 +92,59 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): if isinstance(codomain_img, str): codomain_img = nib.load(codomain_img) - if isinstance(disp, nib.Nifti1Image): - mapping = DiffeomorphicMap( - 3, - disp.shape[:3], - disp_grid2world=np.linalg.inv(disp.affine), - domain_shape=domain_img.shape[:3], - domain_grid2world=domain_img.affine, - codomain_shape=codomain_img.shape, - codomain_grid2world=codomain_img.affine, - prealign=prealign, - ) - - disp_data = disp.get_fdata().astype(np.float32) - mapping.forward = disp_data[..., 0] - mapping.backward = disp_data[..., 1] - mapping.is_inverse = True - else: - from AFQ.definitions.mapping import ConformedAffineMapping - - mapping = ConformedAffineMapping( - disp, - domain_grid_shape=reduce_shape(domain_img.shape), - domain_grid2world=domain_img.affine, - codomain_grid_shape=reduce_shape(codomain_img.shape), - codomain_grid2world=codomain_img.affine, - ) + mapping = AffineMap( + affine, + domain_grid_shape=reduce_shape(domain_img.shape), + domain_grid2world=domain_img.affine, + codomain_grid_shape=reduce_shape(codomain_img.shape), + codomain_grid2world=codomain_img.affine, + ) return mapping -def slr_registration( - moving_data, - static_data, - moving_affine=None, - static_affine=None, - moving_shape=None, - static_shape=None, - **kwargs, -): - """Register a source image (moving) to a target image (static). +def read_old_mapping(disp, domain_img, codomain_img): + """ + Warning: This is only used for pyAFQ tests and backwards compatibility. + Read old-style registration mapping from a nifti file. Parameters ---------- - moving : ndarray - The source tractography data to be registered - moving_affine : ndarray - The affine associated with the moving (source) data. - moving_shape : ndarray - The shape of the space associated with the static (target) data. - static : ndarray - The target tractography data for registration - static_affine : ndarray - The affine associated with the static (target) data. - static_shape : ndarray - The shape of the space associated with the static (target) data. - - **kwargs: - kwargs are passed into whole_brain_slr + disp : str or Nifti1Image + If string, file must of an image or ndarray. + If image, contains the mapping displacement field in each voxel + Shape (x, y, z, 3, 2) + + domain_img : str or Nifti1Image + + codomain_img : str or Nifti1Image Returns ------- - AffineMap + A :class:`DiffeomorphicMap` object """ - from AFQ.definitions.mapping import ConformedAffineMapping + if isinstance(disp, str): + disp = nib.load(disp) - _, transform, _, _ = whole_brain_slr( - static_data, moving_data, x0="affine", verbose=False, **kwargs - ) + if isinstance(domain_img, str): + domain_img = nib.load(domain_img) + + if isinstance(codomain_img, str): + codomain_img = nib.load(codomain_img) - return ConformedAffineMapping( - transform, - codomain_grid_shape=reduce_shape(static_shape), - codomain_grid2world=static_affine, - domain_grid_shape=reduce_shape(moving_shape), - domain_grid2world=moving_affine, + mapping = DiffeomorphicMap( + 3, + disp.shape[:3], + disp_grid2world=np.linalg.inv(disp.affine), + domain_shape=domain_img.shape[:3], + domain_grid2world=domain_img.affine, + codomain_shape=codomain_img.shape, + codomain_grid2world=codomain_img.affine, ) + + disp_data = disp.get_fdata().astype(np.float32) + mapping.forward = disp_data[..., 0] + mapping.backward = disp_data[..., 1] + mapping.is_inverse = True + + return mapping diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index c593371a..724711ea 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -30,7 +30,7 @@ def export_registered_b0(base_fname, data_imap, mapping): ) if not op.exists(warped_b0_fname): mean_b0 = nib.load(data_imap["b0"]).get_fdata() - warped_b0 = mapping.transform(mean_b0) + warped_b0 = mapping.transform_inverse(mean_b0) warped_b0 = nib.Nifti1Image(warped_b0, data_imap["reg_template"].affine) logger.info(f"Saving {warped_b0_fname}") nib.save(warped_b0, warped_b0_fname) @@ -54,9 +54,7 @@ def template_xform(base_fname, dwi_data_file, data_imap, mapping): base_fname, f"_space-{subject_space}_desc-template_anat.nii.gz" ) if not op.exists(template_xform_fname): - template_xform = mapping.transform_inverse( - data_imap["reg_template"].get_fdata() - ) + template_xform = mapping.transform(data_imap["reg_template"].get_fdata()) template_xform = nib.Nifti1Image(template_xform, data_imap["dwi_affine"]) logger.info(f"Saving {template_xform_fname}") nib.save(template_xform, template_xform_fname) diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 145566c6..0fae71a2 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -21,7 +21,7 @@ logger = logging.getLogger("AFQ") -def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): +def _viz_prepare_vol(vol, scalar_dict, ref): if vol in scalar_dict.keys(): vol = scalar_dict[vol] @@ -31,8 +31,6 @@ def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): vol = resample(vol, ref) vol = vol.get_fdata() - if xform: - vol = mapping.transform_inverse(vol) vol[np.isnan(vol)] = 0 return vol @@ -81,15 +79,12 @@ def viz_bundles( """ if sbv_lims_bundles is None: sbv_lims_bundles = [None, None] - mapping = mapping_imap["mapping"] scalar_dict = segmentation_imap["scalar_dict"] profiles_file = segmentation_imap["profiles"] t1_img = nib.load(structural_imap["t1_masked"]) shade_by_volume = get_tp(best_scalar, structural_imap, data_imap, tissue_imap) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, t1_img - ) - volume = _viz_prepare_vol(t1_img, False, mapping, scalar_dict, t1_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, t1_img) + volume = _viz_prepare_vol(t1_img, scalar_dict, t1_img) flip_axes = [False, False, False] for i in range(3): @@ -183,7 +178,6 @@ def viz_indivBundle( """ if sbv_lims_indiv is None: sbv_lims_indiv = [None, None] - mapping = mapping_imap["mapping"] bundle_dict = data_imap["bundle_dict"] scalar_dict = segmentation_imap["scalar_dict"] volume_img = nib.load(structural_imap["t1_masked"]) @@ -191,10 +185,8 @@ def viz_indivBundle( profiles = pd.read_csv(segmentation_imap["profiles"]) start_time = time() - volume = _viz_prepare_vol(volume_img, False, mapping, scalar_dict, volume_img) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, volume_img - ) + volume = _viz_prepare_vol(volume_img, scalar_dict, volume_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, volume_img) flip_axes = [False, False, False] for i in range(3): diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index c5bfdc8c..379b7788 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -789,11 +789,6 @@ def test_AFQ_data_waypoint(): "sub-01_ses-01_desc-mapping_from-subject_to-mni_xform.nii.gz", ) nib.save(mapping, mapping_file) - reg_prealign_file = op.join( - myafq.export("output_dir"), - "sub-01_ses-01_desc-prealign_from-subject_to-mni_xform.npy", - ) - np.save(reg_prealign_file, np.eye(4)) # Test ROI exporting: myafq.export("rois") diff --git a/AFQ/tests/test_registration.py b/AFQ/tests/test_registration.py index c1b394c1..d1b3144c 100644 --- a/AFQ/tests/test_registration.py +++ b/AFQ/tests/test_registration.py @@ -5,16 +5,12 @@ import nibabel.tmpdirs as nbtmp import numpy as np import numpy.testing as npt -from dipy.align.imwarp import DiffeomorphicMap +from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr from dipy.io.streamline import load_tractogram import AFQ.data.fetch as afd -from AFQ.registration import ( - read_mapping, - slr_registration, - syn_register_dwi, - write_mapping, -) +from AFQ.registration import read_affine_mapping, reduce_shape MNI_T2 = afd.read_mni_template() hardi_img, gtab = dpd.read_stanford_hardi() @@ -50,27 +46,34 @@ def test_slr_registration(): hcp_atlas = load_tractogram(atlas_fname, "same", bbox_valid_check=False) with nbtmp.InTemporaryDirectory() as tmpdir: - mapping = slr_registration( + _, transform, _, _ = whole_brain_slr( streamlines, hcp_atlas.streamlines, - moving_affine=subset_b0_img.affine, - static_affine=subset_t2_img.affine, - moving_shape=subset_b0_img.shape, - static_shape=subset_t2_img.shape, + x0="affine", + verbose=False, progressive=False, greater_than=10, rm_small_clusters=1, rng=np.random.RandomState(seed=8), ) - warped_moving = mapping.transform(subset_b0) + + mapping = AffineMap( + transform, + domain_grid_shape=reduce_shape(subset_b0_img.shape), + domain_grid2world=subset_b0_img.affine, + codomain_grid_shape=reduce_shape(subset_t2_img.shape), + codomain_grid2world=subset_t2_img.affine, + ) + + warped_moving = mapping.transform_inverse(subset_b0) npt.assert_equal(warped_moving.shape, subset_t2.shape) mapping_fname = op.join(tmpdir, "mapping.npy") - write_mapping(mapping, mapping_fname) - file_mapping = read_mapping(mapping_fname, subset_b0_img, subset_t2_img) + np.save(mapping_fname, transform) + file_mapping = read_affine_mapping(mapping_fname, subset_b0_img, subset_t2_img) # Test that it has the same effect on the data: - warped_from_file = file_mapping.transform(subset_b0) + warped_from_file = file_mapping.transform_inverse(subset_b0) npt.assert_equal(warped_from_file, warped_moving) # Test that it is, attribute by attribute, identical: @@ -78,11 +81,3 @@ def test_slr_registration(): assert np.all( mapping.__getattribute__(k) == file_mapping.__getattribute__(k) ) - - -def test_syn_register_dwi(): - warped_b0, mapping = syn_register_dwi( - subset_dwi_data, gtab, template=subset_t2_img, radius=1 - ) - npt.assert_equal(isinstance(mapping, DiffeomorphicMap), True) - npt.assert_equal(warped_b0.shape, subset_t2_img.shape) diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index cea03655..782ac391 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -12,7 +12,7 @@ logger = logging.getLogger("AFQ") -def transform_inverse_roi(roi, mapping, bundle_name="ROI"): +def transform_roi(roi, mapping, bundle_name="ROI"): """ After being non-linearly transformed, ROIs tend to have holes in them. We perform a couple of computational geometry operations on the ROI to @@ -40,12 +40,12 @@ def transform_inverse_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform_inverse(roi, interpolation="linear") + _roi = mapping.transform(roi.astype(float), interpolation="linear") if np.sum(_roi) == 0: logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") _roi = binary_dilation(roi) - _roi = mapping.transform_inverse(_roi, interpolation="linear") + _roi = mapping.transform(_roi.astype(float), interpolation="linear") _roi = patch_up_roi(_roi > 0, bundle_name=bundle_name).astype(np.int32) diff --git a/AFQ/viz/fury_backend.py b/AFQ/viz/fury_backend.py index 5040fbc9..4e73f22e 100644 --- a/AFQ/viz/fury_backend.py +++ b/AFQ/viz/fury_backend.py @@ -194,11 +194,7 @@ def create_gif( def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, color=None, @@ -215,22 +211,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -266,9 +248,7 @@ def visualize_roi( """ if color is None: color = np.array([1, 0, 0]) - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) for i, flip in enumerate(flip_axes): if flip: roi = np.flip(roi, axis=i) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 1af6acb4..07d467e0 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -535,11 +535,7 @@ def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, flip_axes=None, @@ -556,22 +552,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -612,9 +594,7 @@ def visualize_roi( color = np.array([0.9999, 0, 0]) if flip_axes is None: flip_axes = [False, False, False] - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) if figure is None: figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]]) diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 4ffe316a..1b82a306 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -13,9 +13,7 @@ from dipy.tracking.streamline import transform_streamlines from PIL import Image, ImageChops -import AFQ.registration as reg import AFQ.utils.streamlines as aus -import AFQ.utils.volume as auv __all__ = ["Viz"] @@ -591,9 +589,7 @@ def gif_from_pngs(tdir, gif_fname, n_frames, png_fname="tgif", add_zeros=False): io.mimsave(gif_fname, angles) -def prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template -): +def prepare_roi(roi, resample_to=None): """ Load the ROI Possibly perform a transformation on an ROI @@ -605,60 +601,27 @@ def prepare_roi( The ROI information. If str, ROI will be loaded using the str as a path. - affine_or_mapping : ndarray, Nifti1Image, or str - An affine transformation or mapping to apply to the ROI before - visualization. Default: no transform. - - static_img: str or Nifti1Image - Template to resample roi to. - - roi_affine: ndarray - - static_affine: ndarray - - reg_template: str or Nifti1Image - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Returns ------- ndarray """ viz_logger.info("Preparing ROI...") + if isinstance(roi, str): + roi = nib.load(roi) + + if resample_to is not None: + if not isinstance(roi, nib.Nifti1Image): + raise ValueError( + ("If resampling, roi must be a Nifti1Image or a path to one.") + ) + roi = resample(roi, resample_to) + if not isinstance(roi, np.ndarray): - if isinstance(roi, str): - roi = nib.load(roi).get_fdata() - else: - roi = roi.get_fdata() - - if affine_or_mapping is not None: - if isinstance(affine_or_mapping, np.ndarray): - # This is an affine: - if static_img is None or roi_affine is None or static_affine is None: - raise ValueError( - "If using an affine to transform an ROI, " - "need to also specify all of the following", - "inputs: `static_img`, `roi_affine`, ", - "`static_affine`", - ) - roi = resample( - roi, static_img, moving_affine=roi_affine, static_affine=static_affine - ).get_fdata() - else: - # Assume it is a mapping: - if isinstance(affine_or_mapping, str) or isinstance( - affine_or_mapping, nib.Nifti1Image - ): - if reg_template is None or static_img is None: - raise ValueError( - "If using a mapping to transform an ROI, need to ", - "also specify all of the following inputs: ", - "`reg_template`, `static_img`", - ) - affine_or_mapping = reg.read_mapping( - affine_or_mapping, static_img, reg_template - ) + roi = roi.get_fdata() - roi = auv.transform_inverse_roi(roi, affine_or_mapping).astype(bool) return roi diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 22fc7542..0b8f0980 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -269,9 +269,9 @@ "NDARAA948VFH"]["HBNsiteRU"], index_col=[0]) for ind in bundle_counts.index: if ind == "Total Recognized": - threshold = 1000 - elif "Fronto-occipital" in ind or "Orbital" in ind: - threshold = 5 + threshold = 3000 + elif "Fronto-occipital" in ind: + threshold = 10 else: threshold = 15 if bundle_counts["n_streamlines"][ind] < threshold: From 4632c3760189c80a457243afe787b166601a3843 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 12:42:16 +0900 Subject: [PATCH 33/86] the transform points direction is opposite for some reason --- AFQ/api/group.py | 4 ++-- AFQ/recognition/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index d1f3bef8..fe8e0394 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -583,7 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sl = mapping.transform_points(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1020,7 +1020,7 @@ def combine_bundle(self, bundle_name): mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni = mapping.tranform_points(sls) + sls_mni = mapping.transform_points(sls) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 8258dea9..57887d26 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -133,9 +133,9 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): else: tg.to_rasmm() if to == "template": - moved_sl = mapping.transform_points_inverse(tg.streamlines) - else: moved_sl = mapping.transform_points(tg.streamlines) + else: + moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) if save_intermediates is not None: save_tractogram( From a5d2fcfda3d21d144a5ee599c8d419b3ee6b5fc7 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 13:22:53 +0900 Subject: [PATCH 34/86] verfified --- AFQ/definitions/mapping.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index 17171faf..5cf326e4 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -444,9 +444,7 @@ def gen_mapping( subject_sls, template_sls, ): - return np.linalg.inv( - self.prealign(reg_subject, reg_template) - ) # TODO: test: this still needs to be inverted? + return np.linalg.inv(self.prealign(reg_subject, reg_template)) class IdentityMap(AffineMapMixin, Definition): From 15a2d0647e78f27c0ccf17fc7f1f6626126a2f6f Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 21:06:53 +0900 Subject: [PATCH 35/86] this should be not the inverse --- AFQ/definitions/mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index 5cf326e4..f5c420e1 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -343,7 +343,7 @@ def get_for_subses( reg_template.get_fdata(), moving_affine=reg_subject.affine, static_affine=reg_template.affine, - prealign=np.linalg.inv(reg_prealign), + prealign=reg_prealign, **self.syn_kwargs, ) mapping = get_simplified_transform(mapping) From 926e16ba473457d14556972dc309dbcc2944aa5e Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 13:42:43 +0900 Subject: [PATCH 36/86] add ROI transformation fix --- AFQ/registration.py | 3 ++- AFQ/utils/volume.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/AFQ/registration.py b/AFQ/registration.py index 7a61c907..708492ca 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -103,7 +103,7 @@ def read_affine_mapping(affine, domain_img, codomain_img): return mapping -def read_old_mapping(disp, domain_img, codomain_img): +def read_old_mapping(disp, domain_img, codomain_img, prealign=None): """ Warning: This is only used for pyAFQ tests and backwards compatibility. Read old-style registration mapping from a nifti file. @@ -140,6 +140,7 @@ def read_old_mapping(disp, domain_img, codomain_img): domain_grid2world=domain_img.affine, codomain_shape=codomain_img.shape, codomain_grid2world=codomain_img.affine, + prealign=prealign, ) disp_data = disp.get_fdata().astype(np.float32) diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index 782ac391..a8269811 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -40,7 +40,15 @@ def transform_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform(roi.astype(float), interpolation="linear") + # dilate binary images to avoid losing small ROIs + if np.unique(roi).size < 3: + scale_factor = max( + np.asarray(mapping.codomain_shape) / np.asarray(mapping.domain_shape) + ) + for _ in range(max(np.ceil(scale_factor) - 1, 0).astype(int)): + roi = binary_dilation(roi) + + _roi = mapping.transform((roi.astype(float)), interpolation="linear") if np.sum(_roi) == 0: logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") From 9b7357121a54850b2af784f4434c206913dad8fe Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 14:25:46 +0900 Subject: [PATCH 37/86] cleanup some minor bugs from tests --- AFQ/recognition/recognize.py | 9 +++++---- AFQ/recognition/roi.py | 8 ++++++++ AFQ/recognition/tests/test_recognition.py | 2 +- AFQ/recognition/tests/test_rois.py | 10 ++++++---- AFQ/registration.py | 2 +- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 89b0ae24..f9a79e63 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -213,11 +213,12 @@ def recognize( # Weight by distance to ROI valid_dists = bundle_roi_closest != -1 - dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) has_any_valid_roi = np.any(valid_dists, axis=2) - max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) - final_mask = (bundle_decisions > 0) & has_any_valid_roi - bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) + if np.any(has_any_valid_roi): + dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) + final_mask = (bundle_decisions > 0) & has_any_valid_roi + bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) bundle_decisions = np.concatenate( (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 83e1cf53..474bc360 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -67,6 +67,14 @@ def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): ------- boolean array of streamlines that survive cleaning. """ + if not isinstance(fgarray, np.ndarray): + raise ValueError( + ( + "fgarray must be a numpy ndarray, you can resample " + "your streamlines using resample_tg in AFQ.recognition.utils" + ) + ) + n_sls, n_nodes, _ = fgarray.shape # handle target_idx negative values as wrapping around diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 3fdc0f52..c03f72b7 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -152,7 +152,7 @@ def test_segment_reco(): # This condition should still hold npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_(len(fiber_groups["CST_L"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index f5140660..0f560115 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -4,6 +4,7 @@ from scipy.ndimage import distance_transform_edt import AFQ.recognition.roi as abr +import AFQ.recognition.utils as abu from AFQ.recognition.roi import check_sl_with_exclusion, check_sls_with_inclusion shape = (15, 15, 15) @@ -17,6 +18,7 @@ np.array([[1, 1, 1], [2, 1, 1], [3, 1, 1]]), np.array([[1, 1, 1], [2, 1, 1]]), ] +fgarray = np.array(abu.resample_tg(streamlines, 20)) roi1 = np.ones(shape, dtype=np.float32) roi1[1, 2, 3] = 0 @@ -43,15 +45,15 @@ def test_clean_by_endpoints(): - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0]) ) # If tol=1, the third streamline also gets included - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0, tol=1)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1, tol=1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0, tol=1)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1, tol=1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0]) ) diff --git a/AFQ/registration.py b/AFQ/registration.py index 708492ca..28b72799 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -146,6 +146,6 @@ def read_old_mapping(disp, domain_img, codomain_img, prealign=None): disp_data = disp.get_fdata().astype(np.float32) mapping.forward = disp_data[..., 0] mapping.backward = disp_data[..., 1] - mapping.is_inverse = True + mapping.is_inverse = False return mapping From 8ca8fa45700ea32af2cfcdd37433e8d6b1da97bf Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 15:10:41 +0900 Subject: [PATCH 38/86] add documentation and tests for mixed bundle definitions --- AFQ/recognition/tests/test_recognition.py | 42 +++++++++++++++++++++++ docs/source/reference/bundledict.rst | 38 ++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index c03f72b7..95736aa2 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -9,6 +9,7 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.stats.analysis import afq_profile +import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd import AFQ.recognition.cleaning as abc import AFQ.registration as reg @@ -83,6 +84,47 @@ def test_segment(): npt.assert_equal(len(clean_sl), len(CST_R_sl)) +def test_segment_mixed_roi(): + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + bundle_info = { + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed", + } + } + + with pytest.raises( + ValueError, + match=( + "When using mixed ROI bundle definitions, and subject space ROIs, " + "resample_subject_to cannot be False." + ), + ): + fiber_groups, _ = recognize( + tg, nib.load(hardi_fdata), mapping, bundle_info, reg_template, 2 + ) + + bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundle_info, + reg_template, + 2, + dist_to_atlas=10, + ) + + # We asked for 2 fiber groups: + npt.assert_equal(len(fiber_groups), 1) + OR_LV1_sl = fiber_groups["OR LV1"] + npt.assert_(len(OR_LV1_sl) == 2) + + @pytest.mark.nightly def test_segment_no_prob(): # What if you don't have probability maps? diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 5927e818..43d73d19 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -86,6 +86,44 @@ relation to the Left Arcuate and Inferior Longitudinal fasciculi: 'Left Inferior Longitudinal': {'core': 'Left'}, } + +Mixed space ROIs +================ +Everywhere in the bundle dictionary where an ROI is specified as a path, +be it start, end, include, exclude, or probability map, you can in fact input +a dictionary instead. This dictionary should have two keys: +- 'roi' : path to the ROI Nifti file +- 'space' : either 'template' or 'subject', describing the space the ROI + +Then, for the whole bundle, set "space" to 'mixed'. This allows you to +specify some ROIs in template space and some in subject space for the same +bundle. For example: + +.. code-block:: python + import os.path as op + import AFQ.api.bundle_dict as abd + import AFQ.data.fetch as afd + + # First, organize the data + afd.organize_stanford_data() + bids_path = op.join(op.expanduser('~'), 'AFQ_data', 'stanford_hardi') + sub_path = op.join(bids_path, 'derivatives', 'vistasoft', 'sub-01', 'ses-01') + dwi_path = op.join(sub_path, 'dwi', 'sub-01_ses-01_dwi.nii.gz') + + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + # Then, prepare the bundle dictionary + bundle_info = abd.BundleDict({ + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed" + } + }, resample_subject_to=dwi_path) + + Filtering Order =============== When doing bundle recognition, streamlines are filtered out from the whole From 838bb3488f74d0be35a38f7b8bb308870b53d1a5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:10:03 +0900 Subject: [PATCH 39/86] actually transform points / transform inverse points makes sense --- AFQ/api/group.py | 7 +++---- AFQ/recognition/other_bundles.py | 12 +++++++++--- AFQ/recognition/utils.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index fe8e0394..617fc6ac 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -583,7 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points(tg.streamlines) + moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1015,14 +1015,13 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - seg_sft.sft.to_vox() sls = seg_sft.get_bundle(bundle_name).streamlines mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni = mapping.transform_points(sls) + sls_mni.extend(mapping.transform_points_inverse(sls)) - moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) + moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) save_path = op.abspath( op.join(self.afq_path, f"bundle-{bundle_name}_subjects-all_MNI.trk") diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 3b04cee0..618daf45 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -76,9 +76,15 @@ def clean_by_overlap( ) if remove: - other_bundle_density_map = ( - other_bundle_density_map / other_bundle_density_map.max() - ) > other_bundle_min_density + max_val = other_bundle_density_map.max() + if max_val > 0: + other_bundle_density_map = ( + other_bundle_density_map / max_val + ) > other_bundle_min_density + else: + other_bundle_density_map = np.zeros_like( + other_bundle_density_map, dtype=bool + ) if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 57887d26..8258dea9 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -133,9 +133,9 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): else: tg.to_rasmm() if to == "template": - moved_sl = mapping.transform_points(tg.streamlines) - else: moved_sl = mapping.transform_points_inverse(tg.streamlines) + else: + moved_sl = mapping.transform_points(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) if save_intermediates is not None: save_tractogram( From fbc09e840817619699ca8e6c782ff3e48095a356 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:21:51 +0900 Subject: [PATCH 40/86] BFs --- AFQ/api/participant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 8b34030b..e647f834 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -403,7 +403,6 @@ def _save_file(curr_img): x_pos = jj % size[0] _ii = jj // size[0] y_pos = _ii % size[1] - _ii = _ii // size[1] this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) From 627646de07e3bfa95af7059efdb3b8d7324ebd18 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:22:47 +0900 Subject: [PATCH 41/86] BFs --- AFQ/definitions/mapping.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index f5c420e1..ce9ee6ec 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -240,16 +240,16 @@ def get_for_subses( if not op.exists(mapping_file): start_time = time() - self.gen_mapping( - reg_template, + affine = self.gen_mapping( reg_subject, + reg_template, subject_sls, template_sls, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - np.save(mapping_file, mapping.affine) + np.save(mapping_file, affine) meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" @@ -399,8 +399,8 @@ def __init__(self, slr_kwargs=None): def gen_mapping( self, - reg_template, reg_subject, + reg_template, subject_sls, template_sls, ): From 2cf377b1a99726dbe684a90cd77148e4a91e1452 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:44:34 +0900 Subject: [PATCH 42/86] fix roi dist priority --- AFQ/recognition/criteria.py | 11 ++++++++++- AFQ/recognition/recognize.py | 8 ++++++-- AFQ/recognition/roi.py | 6 ++++-- AFQ/recognition/tests/test_rois.py | 9 +++++++-- AFQ/recognition/utils.py | 2 ++ 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index eaa224b3..8e58df9a 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -176,14 +176,16 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): ) roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) + roi_dists = -np.ones((max_includes, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): - sl_accepted, sl_closest = inc_result + sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: if len(sl_closest) > 1: roi_closest[: len(sl_closest), sl_idx] = sl_closest + roi_dists[: len(sl_dists), sl_idx] = sl_dists # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second ROI @@ -192,11 +194,13 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) + roi_dists[: len(sl_dists), sl_idx] = np.flip(sl_dists) accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 b_sls.roi_closest = roi_closest.T + b_sls.roi_dists = roi_dists.T if flip_using_include: b_sls.reorient(to_flip) b_sls.select(accept_idx, "include") @@ -390,6 +394,7 @@ def run_bundle_rec_plan( bundle_idx, bundle_to_flip, bundle_roi_closest, + bundle_roi_dists, bundle_decisions, **segmentation_params, ): @@ -500,3 +505,7 @@ def check_space(roi): bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( b_sls.roi_closest.copy() ) + if hasattr(b_sls, "roi_dists"): + bundle_roi_dists[b_sls.selected_fiber_idxs, bundle_idx, :] = ( + b_sls.roi_dists.copy() + ) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index f9a79e63..0d2c9711 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -160,6 +160,9 @@ def recognize( bundle_roi_closest = -np.ones( (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 ) + bundle_roi_dists = -np.ones( + (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.float32 + ) fiber_groups = {} meta = {} @@ -180,6 +183,7 @@ def recognize( bundle_idx, bundle_to_flip, bundle_roi_closest, + bundle_roi_dists, bundle_decisions, clip_edges=clip_edges, n_cpus=n_cpus, @@ -212,10 +216,10 @@ def recognize( ) # Weight by distance to ROI - valid_dists = bundle_roi_closest != -1 + valid_dists = bundle_roi_dists > 0 has_any_valid_roi = np.any(valid_dists, axis=2) if np.any(has_any_valid_roi): - dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) final_mask = (bundle_decisions > 0) & has_any_valid_roi bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 474bc360..4ae02310 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -11,21 +11,23 @@ def check_sls_with_inclusion(sls, include_rois, include_roi_tols): include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): closest = np.zeros(len(include_rois), dtype=np.int32) + dists = np.zeros(len(include_rois), dtype=np.float32) sl = np.asarray(sl) valid = True for ii, roi in enumerate(include_rois): dist = interpolate_scalar_3d(roi, sl)[0] closest[ii] = np.argmin(dist) + dists[ii] = dist[closest[ii]] if dist[closest[ii]] > include_roi_tols[ii]: # Too far from one of them: - inc_results[jj] = (False, []) + inc_results[jj] = (False, [], []) valid = False break # Checked all the ROIs and it was close to all of them if valid: - inc_results[jj] = (True, closest) + inc_results[jj] = (True, closest, dists) return inc_results diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index 0f560115..67ef384e 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -65,22 +65,27 @@ def test_check_sls_with_inclusion(): assert result[0][0] is True assert np.allclose(result[0][1][0], 0) assert np.allclose(result[0][1][1], 2) + assert np.allclose(result[0][2][0], 0) + assert np.allclose(result[0][2][1], 0) assert result[1][0] is False def test_check_sl_with_inclusion_pass(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline1], include_rois, include_roi_tols )[0] assert result is True assert len(dists) == 2 + assert np.allclose(dist_idxs[0], 0) + assert np.allclose(dist_idxs[1], 2) def test_check_sl_with_inclusion_fail(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline2], include_rois, include_roi_tols )[0] assert result is False + assert dist_idxs == [] assert dists == [] diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 8258dea9..5c751d1b 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -184,6 +184,8 @@ def select(self, idx, clean_name, cut=False): self.sls_flipped = self.sls_flipped[idx] if hasattr(self, "roi_closest"): self.roi_closest = self.roi_closest[idx] + if hasattr(self, "roi_dists"): + self.roi_dists = self.roi_dists[idx] time_taken = time() - self.start_time self.logger.info( f"After filtering by {clean_name} (time: {time_taken}s), " From 4e8cf4d5a9dd3e80db460a006dfe5b7619d64d92 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:46:29 +0900 Subject: [PATCH 43/86] tweaks --- AFQ/recognition/criteria.py | 2 +- docs/source/reference/bundledict.rst | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 8e58df9a..7b6f8134 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -111,7 +111,7 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): def length(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("length") + b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 43d73d19..581e662b 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -94,12 +94,14 @@ be it start, end, include, exclude, or probability map, you can in fact input a dictionary instead. This dictionary should have two keys: - 'roi' : path to the ROI Nifti file - 'space' : either 'template' or 'subject', describing the space the ROI + is currently in. Then, for the whole bundle, set "space" to 'mixed'. This allows you to specify some ROIs in template space and some in subject space for the same bundle. For example: .. code-block:: python + import os.path as op import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd From 16b214555ae31756112df221ef87f8e504f812be Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:55:22 +0900 Subject: [PATCH 44/86] small BF --- AFQ/recognition/recognize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 0d2c9711..8fb02fec 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -216,7 +216,7 @@ def recognize( ) # Weight by distance to ROI - valid_dists = bundle_roi_dists > 0 + valid_dists = bundle_roi_dists >= -0.5 # i.e., not -1 has_any_valid_roi = np.any(valid_dists, axis=2) if np.any(has_any_valid_roi): dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) From 3b8b33a4fee61a3e77901f7e0dc5a07a5a731646 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 23:20:15 +0900 Subject: [PATCH 45/86] put this back --- AFQ/recognition/tests/test_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 95736aa2..66c3133a 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -194,7 +194,7 @@ def test_segment_reco(): # This condition should still hold npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_L"]) > 0) + npt.assert_(len(fiber_groups["CST_R"]) > 0) def test_exclusion_ROI(): From 00ee0382196153411ca1dc739a26c67cc4a4685c Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 12:53:46 +0900 Subject: [PATCH 46/86] implement moving streamlines with new mapping system --- AFQ/api/group.py | 23 ++++++++++----- AFQ/recognition/criteria.py | 5 ++-- AFQ/recognition/utils.py | 41 -------------------------- AFQ/utils/streamlines.py | 58 ++++++++++++++++++++++++++++++++++++- 4 files changed, 76 insertions(+), 51 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 617fc6ac..57d5f08a 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -554,9 +554,12 @@ def load_next_subject(): this_bundles_file = self.export("bundles", collapse=False)[sub][ses] this_mapping = self.export("mapping", collapse=False)[sub][ses] this_img = self.export("dwi", collapse=False)[sub][ses] + this_reg_template = self.export("reg_template", collapse=False)[sub][ + ses + ] seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img) seg_sft.sft.to_rasmm() - subses_info.append((seg_sft, this_mapping)) + subses_info.append((seg_sft, this_mapping, this_img, this_reg_template)) bundle_dict = self.export("bundle_dict", collapse=False)[ self.valid_sub_list[0] @@ -566,7 +569,7 @@ def load_next_subject(): load_next_subject() # load first subject for b in bundle_dict.bundle_names: for i in range(len(self.valid_sub_list)): - seg_sft, mapping = subses_info[i] + seg_sft, mapping, img, reg_template = subses_info[i] idx = seg_sft.bundle_idxs[b] # use the first subses that works # otherwise try each successive subses @@ -582,9 +585,11 @@ def load_next_subject(): idx = np.random.choice(idx, size=100, replace=False) these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) - tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points_inverse(tg.streamlines) - moved_sl = np.asarray(moved_sl) + tg = StatefulTractogram(these_sls, img, Space.RASMM) + moved_sl = aus.move_streamlines( + tg, "template", mapping, reg_template + ) + moved_sl = np.asarray(moved_sl.streamlines) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} for ii, sl_idx in enumerate(idx): @@ -1015,11 +1020,15 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - sls = seg_sft.get_bundle(bundle_name).streamlines + sls = seg_sft.get_bundle(bundle_name) mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni.extend(mapping.transform_points_inverse(sls)) + sls_mni.extend( + aus.move_streamlines( + sls, "template", mapping, reg_template + ).streamlines + ) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 7b6f8134..c931dcbf 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -20,6 +20,7 @@ import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict from AFQ.utils.stats import chunk_indices +from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ "prob_map", @@ -218,7 +219,7 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): ref_sl = load_tractogram( bundle_def["curvature"]["path"], "same", bbox_valid_check=False ) - moved_ref_sl = abu.move_streamlines( + moved_ref_sl = move_streamlines( ref_sl, "subject", mapping, img, save_intermediates=save_intermediates ) moved_ref_sl.to_vox() @@ -264,7 +265,7 @@ def recobundles( **kwargs, ): b_sls.initiate_selection("Recobundles") - moved_sl = abu.move_streamlines( + moved_sl = move_streamlines( StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), "template", mapping, diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 5c751d1b..0a60552d 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -9,8 +9,6 @@ from dipy.io.streamline import save_tractogram from dipy.tracking.distances import bundles_distances_mdf -from AFQ.definitions.mapping import ConformedFnirtMapping - logger = logging.getLogger("AFQ") @@ -108,45 +106,6 @@ def orient_by_streamline(sls, template_sl): return DM[:, 0] > DM[:, 1] -def move_streamlines(tg, to, mapping, img, save_intermediates=None): - """Move streamlines to or from template space. - - to : str - Either "template" or "subject". - mapping : ConformedMapping - Mapping to use to move streamlines. - img : Nifti1Image - Space to move streamlines to. - """ - tg_og_space = tg.space - if isinstance(mapping, ConformedFnirtMapping): - if to != "subject": - raise ValueError( - "Attempted to transform streamlines to template using " - "unsupported mapping. " - "Use something other than Fnirt." - ) - tg.to_vox() - moved_sl = [] - for sl in tg.streamlines: - moved_sl.append(mapping.transform_pts(sl)) - else: - tg.to_rasmm() - if to == "template": - moved_sl = mapping.transform_points_inverse(tg.streamlines) - else: - moved_sl = mapping.transform_points(tg.streamlines) - moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) - if save_intermediates is not None: - save_tractogram( - moved_sft, - op.join(save_intermediates, f"sls_in_{to}.trk"), - bbox_valid_check=False, - ) - tg.to_space(tg_og_space) - return moved_sft - - def resample_tg(tg, n_points): # reformat for dipy's set_number_of_points if isinstance(tg, np.ndarray): diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 571534a1..52c7051c 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -2,7 +2,7 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram -from dipy.io.streamline import load_tractogram +from dipy.io.streamline import load_tractogram, save_tractogram try: from trx.io import load as load_trx @@ -11,6 +11,7 @@ except ModuleNotFoundError: has_trx = False +from AFQ.definitions.mapping import ConformedFnirtMapping from AFQ.utils.path import drop_extension, read_json @@ -137,3 +138,58 @@ def split_streamline(streamlines, sl_to_split, split_idx): ) return streamlines + + +def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=None): + """Move streamlines to or from template space. + + to : str + Either "template" or "subject". This determines + whether we will use the forward or backwards displacement field. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + img : Nifti1Image + Image defining reference for where the streamlines move to. + to_space : Space or None + If not None, space to move streamlines to after moving them to the + template or subject space. If None, streamlines will be moved back to + their original space. + Default: None. + save_intermediates : str or None + If not None, path to save intermediate tractogram after moving to template + or subject space. + Default: None. + """ + tg_og_space = tg.space + if isinstance(mapping, ConformedFnirtMapping): + if to != "subject": + raise ValueError( + "Attempted to transform streamlines to template using " + "unsupported mapping. " + "Use something other than Fnirt." + ) + tg.to_vox() + moved_sl = [] + for sl in tg.streamlines: + moved_sl.append(mapping.transform_pts(sl)) + moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) + else: + tg.to_vox() + if to == "template": + moved_sl = mapping.transform_points(tg.streamlines) + else: + moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) + moved_sft.to_rasmm() + + if save_intermediates is not None: + save_tractogram( + moved_sft, + op.join(save_intermediates, f"sls_in_{to}.trk"), + bbox_valid_check=False, + ) + if to_space is None: + tg.to_space(tg_og_space) + else: + tg.to_space(to_space) + return moved_sft From 3e385d9200c8fc95a2b1027bf04992d0a3301c06 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 23:41:16 +0900 Subject: [PATCH 47/86] add ORG VOF subclusters --- .codespellrc | 2 +- AFQ/api/bundle_dict.py | 211 ++++++++++++++++-- AFQ/definitions/mapping.py | 3 + AFQ/recognition/cleaning.py | 3 + AFQ/recognition/clustering.py | 195 ++++++++++++++++ AFQ/recognition/criteria.py | 135 +++++++---- AFQ/recognition/preprocess.py | 12 +- AFQ/recognition/recognize.py | 98 ++++---- AFQ/recognition/sparse_decisions.py | 116 ++++++++++ AFQ/recognition/utils.py | 25 +++ AFQ/tasks/viz.py | 3 + docs/source/references.bib | 10 + .../plot_001_group_afq_api.py | 4 +- 13 files changed, 695 insertions(+), 122 deletions(-) create mode 100644 AFQ/recognition/clustering.py create mode 100644 AFQ/recognition/sparse_decisions.py diff --git a/.codespellrc b/.codespellrc index 7035fe17..fdf6f1d6 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] skip = [setup.cfg] -ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus \ No newline at end of file +ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo \ No newline at end of file diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 2c7497bb..51f4a0c6 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,12 +297,75 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - # "Left Inferior Longitudinal": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, - "isolation_forest": {}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Left Vertical Occipital I": { + "cluster_ID": 89, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital II": { + "cluster_ID": 82, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital III": { + "cluster_ID": 83, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital IV": { + "cluster_ID": 21, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital V": { + "cluster_ID": 454, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + }, + remove_cluster_IDs=[ + 27, + 100, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, "Right Vertical Occipital": { "cross_midline": False, @@ -314,15 +377,84 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - # "Right Inferior Longitudinal": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, - "isolation_forest": {}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Right Vertical Occipital I": { + "cluster_ID": 89, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital II": { + "cluster_ID": 82, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital III": { + "cluster_ID": 83, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital IV": { + "cluster_ID": 21, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital V": { + "cluster_ID": 454, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + }, + remove_cluster_IDs=[ + 27, + 100, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, }, - citations={"Yeatman2012", "takemura2017occipital", "Tzourio-Mazoyer2002"}, + citations={ + "Yeatman2012", + "takemura2017occipital", + "Tzourio-Mazoyer2002", + "zhang2018anatomically", + "Hua2008", + }, ) @@ -334,31 +466,49 @@ def slf_bd(): "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, }, citations={"Sagi2024"}, @@ -1043,7 +1193,6 @@ def __init__( self.resample_to = resample_to self.resample_subject_to = resample_subject_to self.keep_in_memory = keep_in_memory - self.max_includes = 3 self.citations = citations if self.citations is None: self.citations = set() @@ -1097,10 +1246,6 @@ def __init__( def __print__(self): print(self._dict) - def update_max_includes(self, new_max): - if new_max > self.max_includes: - self.max_includes = new_max - def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") @@ -1203,8 +1348,6 @@ def __getitem__(self, key): def __setitem__(self, key, item): self._dict[key] = item - if hasattr(item, "get"): - self.update_max_includes(len(item.get("include", []))) if key not in self.bundle_names: self.bundle_names.append(key) @@ -1464,6 +1607,46 @@ def __add__(self, other): ) +class SpectralSubbundleDict(BundleDict): + """ + A BundleDict where each bundle is defined as a spectral subbundle of a + larger bundle. See `Defining Custom Bundle Dictionaries` in the documentation + for details. + """ + + def __init__( + self, + bundle_info, + resample_to=None, + resample_subject_to=False, + keep_in_memory=False, + citations=None, + remove_cluster_IDs=None, + ): + super().__init__( + bundle_info, resample_to, resample_subject_to, keep_in_memory, citations + ) + if remove_cluster_IDs is None: + remove_cluster_IDs = [] + self.remove_cluster_IDs = remove_cluster_IDs + self.cluster_IDs = [] + self.id_to_name = {} + for b_name, b_info in bundle_info.items(): + if "cluster_ID" not in b_info: + raise ValueError( + ( + f"Bundle {b_name} does not have a cluster_ID. " + "All bundles in a SpectralSubbundleDict must have a cluster_ID." + ) + ) + self.cluster_IDs.append(b_info["cluster_ID"]) + self.id_to_name[b_info["cluster_ID"]] = b_name + self.all_cluster_IDs = self.remove_cluster_IDs + self.cluster_IDs + + def get_subbundle_name(self, cluster_id): + return self.id_to_name.get(cluster_id, None) + + def apply_to_roi_dict( dict_, func, diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index ce9ee6ec..ed5ca043 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -212,6 +212,7 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): return mapping_file, meta_fname def prealign(self, reg_subject, reg_template): + logger.info("Calculating affine pre-alignment...") _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) return aff @@ -338,6 +339,8 @@ def get_for_subses( reg_prealign = self.prealign(reg_subject, reg_template) else: reg_prealign = None + + logger.info("Calculating SyN registration...") _, mapping = syn_registration( reg_subject.get_fdata(), reg_template.get_fdata(), diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index db741fe0..59d3687f 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -101,6 +101,7 @@ def clean_by_orientation_mahalanobis( if np.sum(idx_dist) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -231,6 +232,7 @@ def clean_bundle( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -360,6 +362,7 @@ def clean_by_isolation_forest( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(-sl_outliers)[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py new file mode 100644 index 00000000..2d33b6c3 --- /dev/null +++ b/AFQ/recognition/clustering.py @@ -0,0 +1,195 @@ +# Original source: github.com/SlicerDMRI/whitematteranalysis +# Copyright 2026 BWH and 3D Slicer contributors +# Licensed under 3D Slicer license (BSD style; https://github.com/SlicerDMRI/whitematteranalysis/blob/master/License.txt) # noqa +# Modified by John Kruper for pyAFQ +# Modifications: +# 1. Only mean distance included, and mean distance replaced with numba version. +# 2. Uses atlas data from dictionary and numpy files rather than pickled files, +# to avoid additional dependencies. +# 3. Added function to move template streamlines +# to subject space to calculate distances. + +import numpy as np +import scipy +from dipy.io.stateful_tractogram import Space +from numba import njit, prange + +import AFQ.data.fetch as afd +import AFQ.recognition.utils as abu +import AFQ.utils.streamlines as aus + + +@njit(parallel=True) +def _compute_mean_euclidean_matrix(group_n, group_m): + len_n = group_n.shape[0] + len_m = group_m.shape[0] + num_points = group_n.shape[1] + + dist_matrix = np.empty((len_n, len_m), dtype=np.float64) + + for i in prange(len_n): + for j in range(len_m): + sum_dist = 0.0 + sum_dist_ref = 0.0 + + for k in range(num_points): + dx = group_n[i, k, 0] - group_m[j, k, 0] + dx_ref = group_n[i, k, 0] + group_m[j, k, 0] + dy = group_n[i, k, 1] - group_m[j, k, 1] + dz = group_n[i, k, 2] - group_m[j, k, 2] + + sum_dist += np.sqrt(dx * dx + dy * dy + dz * dz) + sum_dist_ref += np.sqrt(dx_ref * dx_ref + dy * dy + dz * dz) + + mean_d = sum_dist / num_points + mean_d_ref = sum_dist_ref / num_points + + final_d = min(mean_d, mean_d_ref) + dist_matrix[i, j] = final_d * final_d + + return dist_matrix.T + + +def _distance_to_similarity(distance, sigmasq): + similarities = np.exp(-distance / (sigmasq)) + + return similarities + + +def _rectangular_similarity_matrix(fgarray_sub, fgarray_atlas, sigma): + distances = _compute_mean_euclidean_matrix(fgarray_sub, fgarray_atlas) + + sigmasq = sigma * sigma + similarity_matrix = _distance_to_similarity(distances, sigmasq) + + return similarity_matrix + + +def spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=None, + sigma_multiplier=1.0, + cluster_indices=None, +): + """ + Use an existing atlas to label a new streamlines. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group to be labeled. + atlas_fgarray : ndarray + Resampled atlas to use for labelling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + sigma_multiplier : float, optional + Multiplier for the sigma value used in computing the similarity + matrix. Default is 1.0. + cluster_indices : list of int, optional + If provided, only these cluster indices from the atlas will be used + for labeling. Default is None, which uses all clusters. + + Returns + ------- + tuple of (ndarray, ndarray) + Cluster indices for all the fibers and their embedding + """ + if atlas_data is None: + atlas_data = afd.read_org800_templates(load_trx=False) + + number_fibers = sub_fgarray.shape[0] + sz = atlas_fgarray.shape[0] + + # Compute fiber similarities. + B = _rectangular_similarity_matrix( + sub_fgarray, atlas_fgarray, sigma=atlas_data["sigma"] * sigma_multiplier + ) + + # Do Normalized Cuts transform of similarity matrix. + # row sum estimate for current B part of the matrix + row_sum_2 = np.sum(B, axis=0) + np.dot(atlas_data["row_sum_matrix"], B) + + # This happens plenty in our cases. Why? + # Maybe a probabilistic vs UKF thing? + # In practice, this is not an issue since we just set to a small value. + if any(row_sum_2 <= 0): + row_sum_2[row_sum_2 < 0] = 1e-4 + + # Normalized cuts normalization + row_sum = np.concatenate((atlas_data["row_sum_1"], row_sum_2)) + dhat = np.sqrt(np.divide(1, row_sum)) + B = np.multiply(B, np.outer(dhat[0:sz], dhat[sz:].T)) + + # Compute embedding using eigenvectors + V = np.dot( + np.dot(B.T, atlas_data["e_vec"]), np.diag(np.divide(1.0, atlas_data["e_val"])) + ) + V = np.divide(V, atlas_data["e_vec_norm"]) + n_eigen = int(atlas_data["number_of_eigenvectors"]) + embed = np.zeros((number_fibers, n_eigen)) + for i in range(0, n_eigen): + embed[:, i] = np.divide(V[:, -(i + 2)], V[:, -1]) + + # Label streamlines using centroids from atlas + if cluster_indices is not None: + centroids = atlas_data["centroids"][cluster_indices, :] + cluster_idx, _ = scipy.cluster.vq.vq(embed, centroids) + cluster_idx = np.array([cluster_indices[i] for i in cluster_idx]) + else: + cluster_idx, _ = scipy.cluster.vq.vq(embed, atlas_data["centroids"]) + + return cluster_idx, embed + + +def subcluster_by_atlas( + sub_trk, mapping, dwi_ref, cluster_indices, atlas_data=None, n_points=20 +): + """ + Use an existing atlas to label a new set of streamlines, and return the + cluster indices for each streamline. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group in VOX to be labeled. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + dwi_ref : Nifti1Image + Image defining reference for where the atlas streamlines move to. + cluster_indices : list of int + Cluster indices from the atlas to use for labeling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + n_points : int, optional + Number of points to resample streamlines to for labeling. Default is 20. + """ + + if atlas_data is None: + atlas_data = afd.read_org800_templates() + atlas_sft = atlas_data["tracks_reoriented"] + + moved_atlas_sft = aus.move_streamlines( + atlas_sft, "subject", mapping, dwi_ref, to_space=Space.RASMM + ) + atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) + + # Note: if we need more efficiency, + # we could modify the code to consider: + # voxel size, midline axis, and midline location + # then we should be able to do these calculations in + # voxel space without having to move the subject streamlines + # to rasmm (but this is not a bottleneck right now) + sub_trk.to_rasmm() + sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) + + cluster_idxs, _ = spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + + return cluster_idxs diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c931dcbf..3f454dcc 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -19,6 +19,7 @@ import AFQ.recognition.roi as abr import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict +from AFQ.recognition.clustering import subcluster_by_atlas from AFQ.utils.stats import chunk_indices from AFQ.utils.streamlines import move_streamlines @@ -45,6 +46,8 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", + "ORG_spectral_subbundles", + "cluster_ID", ] @@ -138,7 +141,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): +def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet @@ -176,17 +179,18 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols ) - roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) - roi_dists = -np.ones((max_includes, len(b_sls)), dtype=np.float32) + n_inc = len(bundle_def["include"]) + roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32) + roi_dists = np.zeros((n_inc, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: + roi_closest[:, sl_idx] = sl_closest + roi_dists[:, sl_idx] = sl_dists if len(sl_closest) > 1: - roi_closest[: len(sl_closest), sl_idx] = sl_closest - roi_dists[: len(sl_dists), sl_idx] = sl_dists # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second ROI @@ -194,8 +198,8 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): if flip_using_include: to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: - roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) - roi_dists[: len(sl_dists), sl_idx] = np.flip(sl_dists) + roi_closest[:, sl_idx] = np.flip(sl_closest) + roi_dists[:, sl_idx] = np.flip(sl_dists) accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 @@ -301,14 +305,15 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): def clean_by_other_bundle( - b_sls, bundle_def, img, preproc_imap, other_bundle_name, other_bundle_sls, **kwargs + b_sls, bundle_def, img, other_bundle_name, other_bundle_sls, **kwargs ): cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 + flipped_sls = b_sls.get_selected_sls(flip=True) if "overlap" in bundle_def[other_bundle_name]: cleaned_idx_overlap = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, @@ -319,7 +324,7 @@ def clean_by_other_bundle( if "node_thresh" in bundle_def[other_bundle_name]: cleaned_idx_node_thresh = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, @@ -331,8 +336,7 @@ def clean_by_other_bundle( if "core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - # the extra specificity of 100 points is needed + np.array(abu.resample_tg(flipped_sls, 100)), np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, @@ -342,8 +346,8 @@ def clean_by_other_bundle( if "entire_core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) @@ -392,11 +396,7 @@ def run_bundle_rec_plan( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_roi_dists, - bundle_decisions, + recognized_bundles_dict, **segmentation_params, ): # Warp ROIs @@ -430,20 +430,25 @@ def check_space(roi): ) logger.info(f"Time to prep ROIs: {time() - start_time}s") - b_sls = abu.SlsBeingRecognized( - tg.streamlines, - logger, - segmentation_params["save_intermediates"], - bundle_name, - img, - len(bundle_def.get("include", [])), - ) + if isinstance(tg, abu.SlsBeingRecognized): + # This only occurs when your inside a subbundle, + # in which case we want to keep the same SlsBeingRecognized object so that + # we can keep track of the same streamlines and their orientations + b_sls = tg + else: + b_sls = abu.SlsBeingRecognized( + tg.streamlines, + logger, + segmentation_params["save_intermediates"], + bundle_name, + img, + len(bundle_def.get("include", [])), + ) inputs = {} inputs["b_sls"] = b_sls inputs["preproc_imap"] = preproc_imap inputs["bundle_def"] = bundle_def - inputs["max_includes"] = bundle_dict.max_includes inputs["mapping"] = mapping inputs["img"] = img inputs["reg_template"] = reg_template @@ -454,7 +459,7 @@ def check_space(roi): if ( (potential_criterion not in criteria_order_post_other_bundles) and (potential_criterion not in criteria_order_pre_other_bundles) - and (potential_criterion not in bundle_dict.bundle_names) + and (potential_criterion not in recognized_bundles_dict.keys()) and (potential_criterion not in valid_noncriterion) ): raise ValueError( @@ -464,7 +469,7 @@ def check_space(roi): "Valid criteria are:\n" f"{criteria_order_pre_other_bundles}\n" f"{criteria_order_post_other_bundles}\n" - f"{bundle_dict.bundle_names}\n" + f"{recognized_bundles_dict.keys()}\n" f"{valid_noncriterion}\n" ) ) @@ -473,13 +478,14 @@ def check_space(roi): if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) if b_sls: - for ii, bundle_name in enumerate(bundle_dict.bundle_names): - if bundle_name in bundle_def.keys(): - idx = np.where(bundle_decisions[:, ii])[0] + for o_bundle_name in recognized_bundles_dict.keys(): + if o_bundle_name in bundle_def.keys(): clean_by_other_bundle( **inputs, - other_bundle_name=bundle_name, - other_bundle_sls=tg.streamlines[idx], + other_bundle_name=o_bundle_name, + other_bundle_sls=recognized_bundles_dict[ + o_bundle_name + ].get_selected_sls(flip=True), ) for criterion in criteria_order_post_other_bundles: if b_sls and criterion in bundle_def: @@ -490,6 +496,22 @@ def check_space(roi): ): mahalanobis(**inputs) + # If you don't cross the midline, we remove streamliens + # entirely on the wrong side of the midline here after filtering + if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: + b_sls.initiate_selection("Wrong side of mid.") + zero_coord = preproc_imap["zero_coord"] + lr_axis = preproc_imap["lr_axis"] + avg_side = np.sign( + np.mean( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, lr_axis] + - zero_coord, + axis=1, + ) + ) + majority_side = np.sign(np.sum(avg_side)) + b_sls.select(avg_side == majority_side, "Wrong side of mid.") + if b_sls and not b_sls.oriented_yet: raise ValueError( "pyAFQ was unable to consistently orient streamlines " @@ -500,13 +522,40 @@ def check_space(roi): ) if b_sls: - bundle_to_flip[b_sls.selected_fiber_idxs, bundle_idx] = b_sls.sls_flipped.copy() - bundle_decisions[b_sls.selected_fiber_idxs, bundle_idx] = 1 - if hasattr(b_sls, "roi_closest"): - bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_closest.copy() + if "ORG_spectral_subbundles" in bundle_def: + subdict = bundle_def["ORG_spectral_subbundles"] + c_ids = subdict.cluster_IDs + b_sls.initiate_selection( + (f"ORG spectral clustering, {len(c_ids)} subbundles being recognized") ) - if hasattr(b_sls, "roi_dists"): - bundle_roi_dists[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_dists.copy() + + sub_sft = StatefulTractogram( + b_sls.get_selected_sls(flip=True), img, Space.VOX + ) + cluster_labels = subcluster_by_atlas( + sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 ) + clusters_being_recognized = [] + for c_id in c_ids: + bundle_name = subdict.get_subbundle_name(c_id) + n_roi = len(subdict[bundle_name].get("include", [])) + cluster_b_sls = b_sls.copy(bundle_name, n_roi) + cluster_b_sls.select(cluster_labels == c_id, f"Cluster {c_id}") + clusters_being_recognized.append(cluster_b_sls) + + for ii, c_id in enumerate(c_ids): + bundle_name = subdict.get_subbundle_name(c_id) + run_bundle_rec_plan( + bundle_def["ORG_spectral_subbundles"], + clusters_being_recognized[ii], + mapping, + img, + reg_template, + preproc_imap, + bundle_name, + recognized_bundles_dict, + **segmentation_params, + ) + else: + b_sls.bundle_def = bundle_def + recognized_bundles_dict[bundle_name] = b_sls diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8b3ce657..52531045 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -27,7 +27,7 @@ def fgarray(tg): return fg_array -@immlib.calc("crosses") +@immlib.calc("crosses", "lr_axis", "zero_coord") def crosses(fgarray, img): """ Classify the streamlines by whether they cross the midline. @@ -45,9 +45,13 @@ def crosses(fgarray, img): lr_axis = idx break - return np.logical_and( - np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), - np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + return ( + np.logical_and( + np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), + np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + ), + lr_axis, + zero_coord[lr_axis], ) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 8fb02fec..73602909 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -6,10 +6,12 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram +import AFQ.recognition.sparse_decisions as ars import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict from AFQ.recognition.criteria import run_bundle_rec_plan from AFQ.recognition.preprocess import get_preproc_plan +from AFQ.utils.path import write_json logger = logging.getLogger("AFQ") @@ -155,14 +157,7 @@ def recognize( tg.to_vox() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.float32) - bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 - ) - bundle_roi_dists = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.float32 - ) + recognized_bundles_dict = {} fiber_groups = {} meta = {} @@ -170,7 +165,7 @@ def recognize( preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) logger.info("Assigning Streamlines to Bundles") - for bundle_idx, bundle_name in enumerate(bundle_dict.bundle_names): + for bundle_name in bundle_dict.bundle_names: logger.info(f"Finding Streamlines for {bundle_name}") run_bundle_rec_plan( bundle_dict, @@ -180,11 +175,7 @@ def recognize( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_roi_dists, - bundle_decisions, + recognized_bundles_dict, clip_edges=clip_edges, n_cpus=n_cpus, rb_recognize_params=rb_recognize_params, @@ -199,10 +190,18 @@ def recognize( if save_intermediates is not None: os.makedirs(save_intermediates, exist_ok=True) - bc_path = op.join(save_intermediates, "sls_bundle_decisions.npy") - np.save(bc_path, bundle_decisions) + bc_path = op.join(save_intermediates, "sls_bundle_decisions.json") + write_json( + bc_path, + { + b_name: b_sls.selected_fiber_idxs.tolist() + for b_name, b_sls in recognized_bundles_dict.items() + }, + ) + + sparse_dists = ars.compute_sparse_decisions(recognized_bundles_dict, n_streamlines) - conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1) + conflicts = ars.get_conflict_count(sparse_dists) if conflicts > 0: logger.info( ( @@ -215,63 +214,38 @@ def recognize( ) ) - # Weight by distance to ROI - valid_dists = bundle_roi_dists >= -0.5 # i.e., not -1 - has_any_valid_roi = np.any(valid_dists, axis=2) - if np.any(has_any_valid_roi): - dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) - max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) - final_mask = (bundle_decisions > 0) & has_any_valid_roi - bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) - - bundle_decisions = np.concatenate( - (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 - ) - bundle_decisions = np.argmax(bundle_decisions, -1) + ars.remove_conflicts(sparse_dists, recognized_bundles_dict) # We do another round through, so that we can: # 1. Clip streamlines according to ROIs # 2. Re-orient streamlines logger.info("Re-orienting streamlines to consistent directions") - for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): - logger.info(f"Processing {bundle}") + for b_name, r_bd in recognized_bundles_dict.items(): + logger.info(f"Processing {b_name}") - select_idx = np.where(bundle_decisions == bundle_idx)[0] - - if len(select_idx) == 0: + if len(r_bd.selected_fiber_idxs) == 0: # There's nothing here, set and move to the next bundle: - if "bundlesection" in bundle_dict.get_b_info(bundle): - for sb_name in bundle_dict.get_b_info(bundle)["bundlesection"]: + if "bundlesection" in bundle_dict.get_b_info(b_name): + for sb_name in bundle_dict.get_b_info(b_name)["bundlesection"]: _return_empty(sb_name, return_idx, fiber_groups, img) else: - _return_empty(bundle, return_idx, fiber_groups, img) + _return_empty(b_name, return_idx, fiber_groups, img) continue - # Use a list here, because ArraySequence doesn't support item - # assignment: - select_sl = list(tg.streamlines[select_idx]) - roi_closest = bundle_roi_closest[select_idx, bundle_idx, :] - n_includes = len(bundle_dict.get_b_info(bundle).get("include", [])) - if clip_edges and n_includes > 1: - logger.info("Clipping Streamlines by ROI") - select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, (0, n_includes - 1), in_place=True - ) - - to_flip = bundle_to_flip[select_idx, bundle_idx] - b_def = dict(bundle_dict.get_b_info(bundle_name)) + b_def = r_bd.bundle_def if "bundlesection" in b_def: - for sb_name, sb_include_cuts in bundle_dict.get_b_info(bundle)[ - "bundlesection" - ].items(): + for sb_name, sb_include_cuts in b_def["bundlesection"].items(): bundlesection_select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, sb_include_cuts, in_place=False + r_bd.get_selected_sls(), + r_bd.roi_closest, + sb_include_cuts, + in_place=False, ) _add_bundle_to_fiber_group( sb_name, bundlesection_select_sl, - select_idx, - to_flip, + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, return_idx, fiber_groups, img, @@ -279,9 +253,15 @@ def recognize( _add_bundle_to_meta(sb_name, b_def, meta) else: _add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip, return_idx, fiber_groups, img + b_name, + r_bd.get_selected_sls(cut=clip_edges), + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, + return_idx, + fiber_groups, + img, ) - _add_bundle_to_meta(bundle, b_def, meta) + _add_bundle_to_meta(b_name, b_def, meta) return fiber_groups, meta diff --git a/AFQ/recognition/sparse_decisions.py b/AFQ/recognition/sparse_decisions.py new file mode 100644 index 00000000..114cc8ee --- /dev/null +++ b/AFQ/recognition/sparse_decisions.py @@ -0,0 +1,116 @@ +import numpy as np +from scipy.sparse import csr_matrix + + +def compute_sparse_decisions(bundles_being_recognized, n_streamlines): + """ + Compute a sparse matrix of distances to ROIs for the streamlines that are + currently being recognized. This can be used to weight decisions by distance + to ROIs, without having to create a dense matrix of distances for all + streamlines and all bundles. + + Parameters + ---------- + bundles_being_recognized : dict + A dictionary of SlsBeingRecognized objects, keyed by bundle name. + n_streamlines : int + The total number of streamlines in the original tractogram. + + Returns + ------- + csr_matrix + A sparse matrix of shape (number of bundles being recognized, n_streamlines), + where the entry (i, j) is a score: + bundles with ROIs result in weights [2.0 to 3.0] with higher scores + for streamlines closer to ROIs + Non-ROI bundles result in weight 1.0 + Everything else is 0.0 (implicit in sparse matrices) + """ + rows, cols, data = [], [], [] + epsilon = 1e-6 + + global_max_dist = 0.0 + for b in bundles_being_recognized.values(): + if hasattr(b, "roi_dists"): + global_max_dist = max(global_max_dist, np.sum(b.roi_dists, axis=-1).max()) + + norm_factor = global_max_dist + 1.0 + + for b_idx, name in enumerate(bundles_being_recognized.keys()): + bundle = bundles_being_recognized[name] + indices = bundle.selected_fiber_idxs + + if hasattr(bundle, "roi_dists"): + dists = np.sum(bundle.roi_dists, axis=-1) + dists = np.maximum(dists, epsilon) + bundle_weights = dists / norm_factor + else: + bundle_weights = np.full(len(indices), 2.0, dtype=np.float32) + + rows.extend([b_idx] * len(indices)) + cols.extend(indices) + data.extend(bundle_weights) + + sparse_scores = csr_matrix( + (data, (rows, cols)), shape=(len(bundles_being_recognized), n_streamlines) + ) + + # Final Decision: 3.0 - Score + # ROI bundles result in weights [2.0 to 3.0] + # No-ROI bundles result in weight 1.0 + sparse_scores.data = 3.0 - sparse_scores.data + + return sparse_scores + + +def get_conflict_count(sparse_scores): + """ + Count how many streamlines are being considered for more than one bundle + """ + sorted_indices = np.sort(sparse_scores.indices) + is_duplicate = np.diff(sorted_indices) == 0 + num_conflicts = np.sum(is_duplicate) + + return num_conflicts + + +def remove_conflicts(sparse_scores, bundles_being_recognized): + """ + Returns a dictionary of {bundle_name: np.array(accepted_indices)} + """ + coo = sparse_scores.tocoo() + + order = np.lexsort((-coo.data, coo.col)) + + mask = np.concatenate(([True], np.diff(coo.col[order]) != 0)) + winner_rows = coo.row[order][mask] + winner_cols = coo.col[order][mask] + + row_sort = np.argsort(winner_rows) + winner_rows = winner_rows[row_sort] + winner_cols = winner_cols[row_sort] + + num_bundles = len(bundles_being_recognized) + split_indices = np.searchsorted(winner_rows, np.arange(num_bundles + 1)) + + for i, b_name in enumerate(bundles_being_recognized.keys()): + b_sls = bundles_being_recognized[b_name] + if np.any(b_sls.selected_fiber_idxs[:-1] > b_sls.selected_fiber_idxs[1:]): + raise NotImplementedError( + f"Bundle '{b_name}' has unsorted selected_fiber_idxs. " + "The searchsorted optimization requires sorted indices." + "This is a bug in the implementation of the bundle " + "recognition procedure, please report it to the developers." + ) + + accept_idx = b_sls.initiate_selection(f"{b_name} conflicts") + start, end = split_indices[i], split_indices[i + 1] + bundle_winners = winner_cols[start:end] + + if len(bundle_winners) > 0: + local_positions = np.searchsorted(b_sls.selected_fiber_idxs, bundle_winners) + accept_idx[local_positions] = True + b_sls.select(local_positions, "conflicts") + else: + b_sls.select(accept_idx, "conflicts") + bundles_being_recognized.pop(b_name) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 0a60552d..678c4eed 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -1,3 +1,4 @@ +import copy import logging import os.path as op from time import time @@ -188,3 +189,27 @@ def __bool__(self): def __len__(self): return len(self.selected_fiber_idxs) + + def copy(self, new_name, n_roi): + new_copy = copy.copy(self) + new_copy.b_name = new_name + if n_roi > 0: + if self.n_roi > 0: + raise NotImplementedError( + ( + "You cannot have includes in the original bundle and" + " subbundles; only one or the other." + ) + ) + else: + new_copy.n_roi = n_roi + + new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy() + new_copy.sls_flipped = self.sls_flipped.copy() + + if hasattr(self, "roi_closest"): + new_copy.roi_closest = self.roi_closest.copy() + if hasattr(self, "roi_dists"): + new_copy.roi_dists = self.roi_dists.copy() + + return new_copy diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 0fae71a2..cf1a8909 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -203,6 +203,9 @@ def viz_indivBundle( if "bundlesection" in b_info: for sb_name in b_info["bundlesection"]: segmented_bname_to_roi_bname[sb_name] = b_name + elif "ORG_spectral_subbundles" in b_info: + for sb_name in b_info["ORG_spectral_subbundles"]: + segmented_bname_to_roi_bname[sb_name] = b_name else: segmented_bname_to_roi_bname[b_name] = b_name diff --git a/docs/source/references.bib b/docs/source/references.bib index ec7b1eb7..0588b036 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -538,6 +538,16 @@ @InProceedings{Kruper2023 isbn="978-3-031-47292-3" } +@article{zhang2018anatomically, + title={An anatomically curated fiber clustering white matter atlas for consistent white matter tract parcellation across the lifespan}, + author={Zhang, Fan and Wu, Ye and Norton, Isaiah and Rigolo, Laura and Rathi, Yogesh and Makris, Nikos and O'Donnell, Lauren J}, + journal={Neuroimage}, + volume={179}, + pages={429--447}, + year={2018}, + publisher={Elsevier} +} + @ARTICLE{Hua2008, title = "Tract probability maps in stereotaxic spaces: analyses of white matter anatomy and tract-specific quantification", diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 0b8f0980..749e470d 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="all")[1] + clear_previous_afq="recog")[1] ########################################################################## # Set tractography parameters (optional) @@ -272,6 +272,8 @@ threshold = 3000 elif "Fronto-occipital" in ind: threshold = 10 + elif "Vertical Occipital" in ind: + threshold = 5 else: threshold = 15 if bundle_counts["n_streamlines"][ind] < threshold: From 184bcad4a02f7b51166596af0850363efa8b992f Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 23:55:58 +0900 Subject: [PATCH 48/86] put this back --- examples/tutorial_examples/plot_001_group_afq_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 749e470d..3f2dbd00 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="recog")[1] + clear_previous_afq="all")[1] ########################################################################## # Set tractography parameters (optional) From 4aea96044c95cf4e0ba047195a9d23882787b13b Mon Sep 17 00:00:00 2001 From: 36000 Date: Sun, 15 Feb 2026 11:26:17 +0900 Subject: [PATCH 49/86] fix up reco --- AFQ/recognition/criteria.py | 1 + AFQ/recognition/tests/test_recognition.py | 6 +++--- AFQ/utils/streamlines.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 3f454dcc..8f227e85 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -274,6 +274,7 @@ def recobundles( "template", mapping, reg_template, + to_space=Space.RASMM, save_intermediates=save_intermediates, ).streamlines moved_sl_resampled = abu.resample_tg(moved_sl, 100) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 66c3133a..6e4812de 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -176,7 +176,7 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ["CST_R", "CST_L"] + bundle_names = ["MCP"] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) @@ -193,8 +193,8 @@ def test_segment_reco(): ) # This condition should still hold - npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_equal(len(fiber_groups), 1) + npt.assert_(len(fiber_groups["MCP"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 52c7051c..6d64003f 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -180,7 +180,6 @@ def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=Non else: moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) - moved_sft.to_rasmm() if save_intermediates is not None: save_tractogram( @@ -189,7 +188,7 @@ def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=Non bbox_valid_check=False, ) if to_space is None: - tg.to_space(tg_og_space) + moved_sft.to_space(tg_og_space) else: - tg.to_space(to_space) + moved_sft.to_space(to_space) return moved_sft From 3b884c852dd75cae36f3624407b7f2adec8b24fb Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 16 Feb 2026 11:01:35 +0900 Subject: [PATCH 50/86] color new bundles --- AFQ/viz/plotly_backend.py | 19 +++++++------ AFQ/viz/utils.py | 60 ++++++++++++++++++++++++++++++++++----- 2 files changed, 64 insertions(+), 15 deletions(-) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 07d467e0..8b2f8745 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -40,8 +40,6 @@ def _inline_interact(figure, show, show_inline): def _to_color_range(num): - if num < 0: - num = 0 if num >= 0.999: num = 0.999 if num <= 0.001: @@ -232,9 +230,10 @@ def _draw_streamlines( def _plot_profiles(profiles, bundle_name, color, fig, scalar): if isinstance(profiles, pd.DataFrame): - sc_max = np.max(profiles[scalar].to_numpy()) - sc_90 = np.percentile(profiles[scalar].to_numpy(), 10) - sc_1 = np.percentile(profiles[scalar].to_numpy(), 99) + all_tp = profiles[scalar].to_numpy() + all_tp = np.max(all_tp) - all_tp + lim_0 = np.percentile(all_tp, 1) + lim_1 = np.percentile(all_tp, 90) profiles = profiles[profiles.tractID == bundle_name] x = profiles["nodeID"] @@ -242,10 +241,14 @@ def _plot_profiles(profiles, bundle_name, color, fig, scalar): line_color = [] for scalar_val in profiles[scalar].to_numpy(): - xformed_scalar = np.minimum( - (sc_max - scalar_val) / (sc_1 - sc_90) + sc_90 + 0.1, 0.999 + brightness = np.minimum( + np.maximum( + scalar_val - lim_0, + 0, + ), + lim_1, ) - line_color.append(_color_arr2str(xformed_scalar * color)) + line_color.append(_color_arr2str(brightness * color)) else: x = np.arange(len(profiles)) y = profiles diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 1b82a306..7b4bd2d6 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -1,3 +1,4 @@ +import colorsys import logging import os.path as op from collections import OrderedDict @@ -17,6 +18,23 @@ __all__ = ["Viz"] + +def get_distinct_shades(base_rgb, n_steps, hue_shift): + """ + Creates distinct shades by shifting Hue + """ + hh, ll, ss = colorsys.rgb_to_hls(*base_rgb) + shades = [] + + for i in range(n_steps): + offset = i - (n_steps - 1) / 2 + + new_h = (hh + (offset * hue_shift)) % 1.0 + + shades.append(colorsys.hls_to_rgb(new_h, ll, ss)) + return shades + + viz_logger = logging.getLogger("AFQ") tableau_20 = [ (0.12156862745098039, 0.4666666666666667, 0.7058823529411765), @@ -51,6 +69,18 @@ small_font = 20 marker_size = 200 +slf_l_base = tableau_extension[0] +slf_r_base = tableau_extension[1] + +vof_l_base = tableau_20[6] +vof_r_base = tableau_20[7] + +slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) +slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) + +vof_l_shades = get_distinct_shades(vof_l_base, 5, hue_shift=0.12) +vof_r_shades = get_distinct_shades(vof_r_base, 5, hue_shift=0.12) + COLOR_DICT = OrderedDict( { "Left Anterior Thalamic": tableau_20[0], @@ -75,8 +105,14 @@ "F_L": tableau_20[12], "Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13], - "Left Superior Longitudinal": tableau_20[14], - "Right Superior Longitudinal": tableau_20[15], + "Left Superior Longitudinal": slf_l_base, + "Right Superior Longitudinal": slf_r_base, + "Left Superior Longitudinal I": slf_l_shades[0], + "Left Superior Longitudinal II": slf_l_shades[1], + "Left Superior Longitudinal III": slf_l_shades[2], + "Right Superior Longitudinal I": slf_r_shades[0], + "Right Superior Longitudinal II": slf_r_shades[1], + "Right Superior Longitudinal III": slf_r_shades[2], "Left Uncinate": tableau_20[16], "UF_L": tableau_20[16], "Right Uncinate": tableau_20[17], @@ -85,10 +121,20 @@ "AF_L": tableau_20[18], "Right Arcuate": tableau_20[19], "AF_R": tableau_20[19], - "Left Posterior Arcuate": tableau_20[6], - "Right Posterior Arcuate": tableau_20[7], - "Left Vertical Occipital": tableau_extension[0], - "Right Vertical Occipital": tableau_extension[1], + "Left Posterior Arcuate": tableau_20[14], + "Right Posterior Arcuate": tableau_20[15], + "Left Vertical Occipital": vof_l_base, + "Right Vertical Occipital": vof_r_base, + "Left Vertical Occipital I": vof_l_shades[0], + "Left Vertical Occipital II": vof_l_shades[1], + "Left Vertical Occipital III": vof_l_shades[2], + "Left Vertical Occipital IV": vof_l_shades[3], + "Left Vertical Occipital V": vof_l_shades[4], + "Right Vertical Occipital I": vof_r_shades[0], + "Right Vertical Occipital II": vof_r_shades[1], + "Right Vertical Occipital III": vof_r_shades[2], + "Right Vertical Occipital IV": vof_r_shades[3], + "Right Vertical Occipital V": vof_r_shades[4], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), @@ -510,7 +556,7 @@ def tract_generator( else: if bundle is None: # No selection: visualize all of them: - for bundle_name in seg_sft.bundle_names: + for bundle_name in sorted(seg_sft.bundle_names): idx = seg_sft.bundle_idxs[bundle_name] if len(idx) == 0: continue From 95feb984f640656027597e29004aaa3406f70f47 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 18 Feb 2026 21:34:36 +0900 Subject: [PATCH 51/86] VOF refinements --- AFQ/api/bundle_dict.py | 159 ++++++++++++++++++++----------- AFQ/data/fetch.py | 6 ++ AFQ/recognition/cleaning.py | 33 +++++-- AFQ/recognition/criteria.py | 18 +++- AFQ/tractography/tractography.py | 4 +- AFQ/viz/utils.py | 8 +- 6 files changed, 148 insertions(+), 80 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 51f4a0c6..ae7b8d59 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -268,7 +268,7 @@ def default_bd(): "Left Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 50, }, "Right Posterior Arcuate": { "cross_midline": False, @@ -285,68 +285,77 @@ def default_bd(): "Right Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 50, }, "Left Vertical Occipital": { "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], + "exclude": [ + templates["Cerebellar_Hemi_L"], + ], "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, - "length": {"min_len": 25, "max_len": 60}, + "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_ID": 89, - "isolation_forest": {}, + "cluster_ID": 82, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, "Left Vertical Occipital II": { - "cluster_ID": 82, - "isolation_forest": {}, + "cluster_ID": 75, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Left Vertical Occipital III": { - "cluster_ID": 83, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, - "Left Vertical Occipital IV": { + "Left Vertical Occipital III": { "cluster_ID": 21, - "isolation_forest": {}, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Left Vertical Occipital V": { - "cluster_ID": 454, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ + 89, + 93, 27, 100, + 102, + 454, + 27, + 555, + 118, 4, 6, 13, @@ -371,62 +380,71 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], + "exclude": [ + templates["Cerebellar_Hemi_R"], + ], "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, - "length": {"min_len": 25, "max_len": 60}, + "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_ID": 89, - "isolation_forest": {}, + "cluster_ID": 82, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, "Right Vertical Occipital II": { - "cluster_ID": 82, - "isolation_forest": {}, + "cluster_ID": 75, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Right Vertical Occipital III": { - "cluster_ID": 83, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, - "Right Vertical Occipital IV": { + "Right Vertical Occipital III": { "cluster_ID": 21, - "isolation_forest": {}, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Right Vertical Occipital V": { - "cluster_ID": 454, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ + 89, + 93, 27, 100, + 102, + 454, + 27, + 555, + 118, 4, 6, 13, @@ -1310,6 +1328,31 @@ def _cond_load(self, roi_or_sl, resample_to): def get_b_info(self, b_name): return self._dict[b_name] + def relax_cleaning(self, delta_distance=1, delta_length=1): + """ + This can be useful for PTT + """ + cleaner_keys = ["mahal", "isolation_forest", "orient_mahal"] + + for b_name in self.bundle_names: + bundle_data = self._dict[b_name] + + for key in cleaner_keys: + if key in bundle_data: + target = bundle_data[key] + if ( + "distance_threshold" in target + and target["distance_threshold"] != 0 + ): + target["distance_threshold"] += delta_distance + if "length_threshold" in target and target["length_threshold"] != 0: + target["length_threshold"] += delta_length + + if "ORG_spectral_subbundles" in bundle_data: + bundle_data["ORG_spectral_subbundles"].relax_cleaning( + delta_distance, delta_length + ) + def __getitem__(self, key): if isinstance(key, tuple) or isinstance(key, list): # Generates a copy of this BundleDict with only the bundle names diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index b40f34ad..9cb70d23 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -761,6 +761,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_L_start.nii.gz", "pARC_xroi1_L.nii.gz", "pARC_xroi1_R.nii.gz", + "Cerebellar_Hemi_L.nii.gz", + "Cerebellar_Hemi_R.nii.gz", ] @@ -865,6 +867,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944080", "61737616", "61737619", + "61970155", + "61970158", ] @@ -970,6 +974,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "1c0b570bb2d622718b01ee2c429a5d15", "51c8a6b5fbb0834b03986093b9ee4fa3", "7cf5800a4efa6bac7e70d84095bc259b", + "f65b3f9133820921d023517a68d4ea41", + "4476935f5aadfcdd633b9a23779625ef", ] fetch_templates = _make_reusable_fetcher( diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 59d3687f..eaaf0cf2 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -50,7 +50,7 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): along_accepted_idx = orientation_along == primary_axis if tol is not None: percentage_primary = ( - 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) + 100 * endpoint_diff[:, primary_axis] / np.sum(endpoint_diff, axis=1) ) logger.debug( (f"Maximum primary percentage found: {np.max(percentage_primary)}") @@ -73,33 +73,42 @@ def clean_by_orientation_mahalanobis( core_only=0.6, min_sl=20, distance_threshold=3, + length_threshold=4, clean_rounds=5, ): + if length_threshold == 0: + length_threshold = np.inf fgarray = np.array(abu.resample_tg(streamlines, n_points)) if core_only != 0: crop_edge = (1.0 - core_only) / 2 fgarray = fgarray[ :, int(n_points * crop_edge) : int(n_points * (1 - crop_edge)), : - ] # Crop to middle 60% + ] fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + lengths = np.array([sl.shape[0] for sl in streamlines]) idx = np.arange(len(fgarray)) rounds_elapsed = 0 while rounds_elapsed < clean_rounds: - # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( fgarray_dists, return_mahalnobis=True, n_points=None, stat=np.mean ) + length_z = zscore(lengths) + logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}") logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}")) - if not (np.any(m_dist >= distance_threshold)): + if not ( + np.any(m_dist >= distance_threshold) or np.any(length_z >= length_threshold) + ): break + idx_dist = np.all(m_dist < distance_threshold, axis=-1) + idx_len = length_z < length_threshold + idx_belong = np.logical_and(idx_dist, idx_len) - if np.sum(idx_dist) < min_sl: - # need to sort and return exactly min_sl: + if np.sum(idx_belong) < min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] idx = np.sort(idx) logger.debug( @@ -107,9 +116,9 @@ def clean_by_orientation_mahalanobis( ) break else: - # Update by selection: - idx = idx[idx_dist] - fgarray_dists = fgarray_dists[idx_dist] + idx = idx[idx_belong] + fgarray_dists = fgarray_dists[idx_belong] + lengths = lengths[idx_belong] rounds_elapsed += 1 logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}")) logger.debug(f"Kept indices: {idx}") @@ -190,6 +199,9 @@ def clean_bundle( else: return tg + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) if core_only != 0: @@ -320,6 +332,9 @@ def clean_by_isolation_forest( ) return np.ones(len(streamlines), dtype=bool) + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) fgarray_dists = np.zeros_like(fgarray) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 8f227e85..89bffa7f 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -1,10 +1,10 @@ import logging from time import time -import dipy.tracking.streamline as dts import nibabel as nib import numpy as np import ray +from dipy.core.interpolation import interpolate_scalar_3d from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import load_tractogram from dipy.segment.bundles import RecoBundles @@ -56,10 +56,9 @@ def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): b_sls.initiate_selection("Prob. Map") - # using entire fgarray here only because it is the first step - fiber_probabilities = dts.values_from_volume( - bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"], np.eye(4) - ) + fiber_probabilities = interpolate_scalar_3d( + bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"].reshape(-1, 3) + )[0].reshape(-1, 20) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. Map") @@ -154,6 +153,15 @@ def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): else: include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) + # For now I am turning ray parallelization here off. + # It is never worthwhile considering other changes we + # have made to speed up this step, + # so spinning up ray and transferring data back + # and forth is not worth it. + # In the future, I think we should redo this with numba and + # use multithreading + n_cpus = 1 + # with parallel segmentation, the first for loop will # only collect streamlines and does not need tqdm if n_cpus > 1: diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 1624986f..de99c63d 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -30,7 +30,7 @@ def track( seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=2000000, + n_seeds=5000000, random_seeds=True, rng_seed=None, step_size=0.5, @@ -75,7 +75,7 @@ def track( voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 2000000 + to generate within the mask. Default: 5000000 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. Default: True diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 7b4bd2d6..e3d2c6c9 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -78,8 +78,8 @@ def get_distinct_shades(base_rgb, n_steps, hue_shift): slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) -vof_l_shades = get_distinct_shades(vof_l_base, 5, hue_shift=0.12) -vof_r_shades = get_distinct_shades(vof_r_base, 5, hue_shift=0.12) +vof_l_shades = get_distinct_shades(vof_l_base, 3, hue_shift=0.15) +vof_r_shades = get_distinct_shades(vof_r_base, 3, hue_shift=0.15) COLOR_DICT = OrderedDict( { @@ -128,13 +128,9 @@ def get_distinct_shades(base_rgb, n_steps, hue_shift): "Left Vertical Occipital I": vof_l_shades[0], "Left Vertical Occipital II": vof_l_shades[1], "Left Vertical Occipital III": vof_l_shades[2], - "Left Vertical Occipital IV": vof_l_shades[3], - "Left Vertical Occipital V": vof_l_shades[4], "Right Vertical Occipital I": vof_r_shades[0], "Right Vertical Occipital II": vof_r_shades[1], "Right Vertical Occipital III": vof_r_shades[2], - "Right Vertical Occipital IV": vof_r_shades[3], - "Right Vertical Occipital V": vof_r_shades[4], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), From a110d04de52b0508845863fe45c8adca7e58ab99 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 19 Feb 2026 11:45:53 +0900 Subject: [PATCH 52/86] push model centroids closer together --- AFQ/api/bundle_dict.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index ae7b8d59..093fd996 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -298,7 +298,6 @@ def default_bd(): "Left Posterior Arcuate": { "node_thresh": 20, "project": "L/R", - "entire_core": "Anterior", }, "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, @@ -307,7 +306,7 @@ def default_bd(): "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_ID": 82, + "cluster_ID": 61, "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -333,7 +332,7 @@ def default_bd(): }, }, "Left Vertical Occipital III": { - "cluster_ID": 21, + "cluster_ID": 25, "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -387,7 +386,6 @@ def default_bd(): "Right Posterior Arcuate": { "node_thresh": 20, "project": "L/R", - "entire_core": "Anterior", }, "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, @@ -396,7 +394,7 @@ def default_bd(): "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_ID": 82, + "cluster_ID": 61, "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -422,7 +420,7 @@ def default_bd(): }, }, "Right Vertical Occipital III": { - "cluster_ID": 21, + "cluster_ID": 25, "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, From e21f9207023141ff39c75febb33e2b77bdeb0424 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 19 Feb 2026 14:19:01 +0900 Subject: [PATCH 53/86] make the VOF 3 relative to IFOF again --- AFQ/api/bundle_dict.py | 2 ++ AFQ/recognition/utils.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 093fd996..3a90dbe7 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -332,6 +332,7 @@ def default_bd(): }, }, "Left Vertical Occipital III": { + "Left Inferior Fronto-occipital": {"core": "Right"}, "cluster_ID": 25, "orient_mahal": { "distance_threshold": 2, @@ -420,6 +421,7 @@ def default_bd(): }, }, "Right Vertical Occipital III": { + "Right Inferior Fronto-occipital": {"core": "Left"}, "cluster_ID": 25, "orient_mahal": { "distance_threshold": 2, diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 678c4eed..4acd2a5e 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -151,7 +151,11 @@ def select(self, idx, clean_name, cut=False): f"After filtering by {clean_name} (time: {time_taken}s), " f"{len(self)} streamlines remain." ) - if self.save_intermediates is not None: + + # Only save intermediates after the 90% of the + # streamlines have been filtered out, + # otherwise its impractical + if self.save_intermediates is not None and len(self) < 0.1 * len(self.ref_sls): save_tractogram( StatefulTractogram(self.get_selected_sls(cut=cut), self.ref, Space.VOX), op.join( From aa47b55825a45e20f42ddfea07c11d889bc2294d Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 19 Feb 2026 15:05:39 +0900 Subject: [PATCH 54/86] binary dilation deprecated --- AFQ/nn/brainchop.py | 6 ++++-- AFQ/utils/volume.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/AFQ/nn/brainchop.py b/AFQ/nn/brainchop.py index ea2a6e15..1cb1b8f0 100644 --- a/AFQ/nn/brainchop.py +++ b/AFQ/nn/brainchop.py @@ -4,7 +4,8 @@ import nibabel as nib import nibabel.processing as nbp import numpy as np -from scipy.ndimage import binary_dilation, gaussian_filter +from scipy.ndimage import gaussian_filter +from skimage.morphology import dilation from skimage.segmentation import find_boundaries from AFQ.data.fetch import afq_home, fetch_brainchop_models @@ -66,7 +67,8 @@ def run_brainchop(ort, t1_img, model_name): # Mindgrab can be tight sometimes, # better to include a bit more, # than to miss some - output = binary_dilation(output, iterations=2) + for _ in range(2): + output = dilation(output) output_img = nbp.resample_from_to( nib.Nifti1Image(output.astype(np.uint8), t1_img_conformed.affine), t1_img diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index a8269811..572ce1f1 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -7,7 +7,7 @@ from dipy.io.utils import create_nifti_header, get_reference_info from dipy.tracking.streamline import select_random_set_of_streamlines from scipy.spatial.distance import dice -from skimage.morphology import binary_dilation +from skimage.morphology import dilation logger = logging.getLogger("AFQ") @@ -46,13 +46,13 @@ def transform_roi(roi, mapping, bundle_name="ROI"): np.asarray(mapping.codomain_shape) / np.asarray(mapping.domain_shape) ) for _ in range(max(np.ceil(scale_factor) - 1, 0).astype(int)): - roi = binary_dilation(roi) + roi = dilation(roi) _roi = mapping.transform((roi.astype(float)), interpolation="linear") if np.sum(_roi) == 0: - logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") - _roi = binary_dilation(roi) + logger.warning(f"Lost ROI {bundle_name}, performing automatic dilation") + _roi = dilation(roi) _roi = mapping.transform(_roi.astype(float), interpolation="linear") _roi = patch_up_roi(_roi > 0, bundle_name=bundle_name).astype(np.int32) From 05a51eb734dc6beb1fe63a2efe6b7e90e5da21a9 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 20 Feb 2026 18:03:47 +0900 Subject: [PATCH 55/86] further modifications to VOF subbundling, formalization --- AFQ/api/bundle_dict.py | 55 ++++++++++++++++++------------------- AFQ/data/fetch.py | 24 ++++++++-------- AFQ/recognition/criteria.py | 26 ++++++++++-------- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 3a90dbe7..a1d5d528 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -185,7 +185,7 @@ def default_bd(): "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], - "length": {"min_len": 100, "max_len": 250}, + "length": {"min_len": 100}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, @@ -195,7 +195,7 @@ def default_bd(): "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], - "length": {"min_len": 100, "max_len": 250}, + "length": {"min_len": 100}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -223,7 +223,7 @@ def default_bd(): "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], - "length": {"min_len": 50, "max_len": 250}, + "length": {"min_len": 50}, }, "Right Arcuate": { "cross_midline": False, @@ -233,7 +233,7 @@ def default_bd(): "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], - "length": {"min_len": 50, "max_len": 250}, + "length": {"min_len": 50}, }, "Left Uncinate": { "cross_midline": False, @@ -266,7 +266,7 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, - "length": {"min_len": 30, "max_len": 120}, + "length": {"min_len": 30}, "primary_axis": "I/S", "primary_axis_percentage": 50, }, @@ -283,13 +283,14 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, - "length": {"min_len": 30, "max_len": 120}, + "length": {"min_len": 30}, "primary_axis": "I/S", "primary_axis_percentage": 50, }, "Left Vertical Occipital": { "cross_midline": False, "space": "template", + "prob_map": templates["VOF_L_prob_map"], "end": templates["VOF_L_end"], "exclude": [ templates["Cerebellar_Hemi_L"], @@ -298,15 +299,16 @@ def default_bd(): "Left Posterior Arcuate": { "node_thresh": 20, "project": "L/R", + "core": "Anterior", }, - "length": {"min_len": 30, "max_len": 70}, + "length": {"min_len": 30}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_ID": 61, + "cluster_IDs": [74, 92], "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -319,21 +321,21 @@ def default_bd(): }, }, "Left Vertical Occipital II": { - "cluster_ID": 75, + "cluster_IDs": [1, 72, 81], "orient_mahal": { - "distance_threshold": 2, + "distance_threshold": 3, "length_threshold": 5, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, + "distance_threshold": 4, "length_threshold": 0, "clean_rounds": 5, }, }, "Left Vertical Occipital III": { "Left Inferior Fronto-occipital": {"core": "Right"}, - "cluster_ID": 25, + "cluster_IDs": [2, 7, 18, 25], "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -349,7 +351,6 @@ def default_bd(): remove_cluster_IDs=[ 89, 93, - 27, 100, 102, 454, @@ -379,6 +380,7 @@ def default_bd(): "Right Vertical Occipital": { "cross_midline": False, "space": "template", + "prob_map": templates["VOF_R_prob_map"], "end": templates["VOF_R_end"], "exclude": [ templates["Cerebellar_Hemi_R"], @@ -387,15 +389,16 @@ def default_bd(): "Right Posterior Arcuate": { "node_thresh": 20, "project": "L/R", + "core": "Anterior", }, - "length": {"min_len": 30, "max_len": 70}, + "length": {"min_len": 30}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_ID": 61, + "cluster_IDs": [74, 92], "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -408,21 +411,21 @@ def default_bd(): }, }, "Right Vertical Occipital II": { - "cluster_ID": 75, + "cluster_IDs": [1, 72, 81], "orient_mahal": { - "distance_threshold": 2, + "distance_threshold": 3, "length_threshold": 5, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, + "distance_threshold": 4, "length_threshold": 0, "clean_rounds": 5, }, }, "Right Vertical Occipital III": { "Right Inferior Fronto-occipital": {"core": "Left"}, - "cluster_ID": 25, + "cluster_IDs": [2, 7, 18, 25], "orient_mahal": { "distance_threshold": 2, "length_threshold": 5, @@ -438,7 +441,6 @@ def default_bd(): remove_cluster_IDs=[ 89, 93, - 27, 100, 102, 454, @@ -1673,22 +1675,17 @@ def __init__( remove_cluster_IDs = [] self.remove_cluster_IDs = remove_cluster_IDs self.cluster_IDs = [] - self.id_to_name = {} for b_name, b_info in bundle_info.items(): - if "cluster_ID" not in b_info: + if "cluster_IDs" not in b_info: raise ValueError( ( - f"Bundle {b_name} does not have a cluster_ID. " - "All bundles in a SpectralSubbundleDict must have a cluster_ID." + f"Bundle {b_name} does not have cluster_IDs. " + "All bundles in a SpectralSubbundleDict must have cluster_IDs." ) ) - self.cluster_IDs.append(b_info["cluster_ID"]) - self.id_to_name[b_info["cluster_ID"]] = b_name + self.cluster_IDs.extend(b_info["cluster_IDs"]) self.all_cluster_IDs = self.remove_cluster_IDs + self.cluster_IDs - def get_subbundle_name(self, cluster_id): - return self.id_to_name.get(cluster_id, None) - def apply_to_roi_dict( dict_, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9cb70d23..61c89d71 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -719,10 +719,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "UNC_R_prob_map.nii.gz", "ARC_L_prob_map.nii.gz", "ARC_R_prob_map.nii.gz", - "VOF_R_end.nii.gz", - "VOF_R_start.nii.gz", - "VOF_L_end.nii.gz", - "VOF_L_start.nii.gz", "pARC_R_start.nii.gz", "pARC_L_start.nii.gz", "ARC_R_end.nii.gz", @@ -763,6 +759,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "pARC_xroi1_R.nii.gz", "Cerebellar_Hemi_L.nii.gz", "Cerebellar_Hemi_R.nii.gz", + "VOF_L_end.nii.gz", + "VOF_R_end.nii.gz", + "VOF_L_prob_map.nii.gz", + "VOF_R_prob_map.nii.gz", ] @@ -825,10 +825,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "11458229", "11458232", "11458235", - "40943957", - "40943960", - "40943966", - "40943969", "40943972", "40943975", "40943978", @@ -869,6 +865,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "61737619", "61970155", "61970158", + "62031448", + "62031439", + "62031442", + "62031445", ] @@ -932,10 +932,6 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "19590c712f1776da1fdba64d4eb7f1f6", "04d5af0feb2c1b5b52a87ccbbf148e4b", "53c277be990d00f7de04f2ea35e74d73", - "d37d815fd1bdaaf3a9d2dcfc3ccb1345", - "95ed3189d8ac152945e6be1eb24381a3", - "a9007e6f2d6ae13ef182f65057c06573", - "c6eb9ee33b7caf691749e266f89e8ec4", "a06b2e2e52c09a601f683dc39859a7f1", "bee876a34fdb03e69a418b791f90975a", "680749c9e4565bc02492019d57d8e7d7", @@ -976,6 +972,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "7cf5800a4efa6bac7e70d84095bc259b", "f65b3f9133820921d023517a68d4ea41", "4476935f5aadfcdd633b9a23779625ef", + "318ea89a04caf8d6f6afa34c8d173142", + "27fe6a73aec3a0d90dae07327c93393e", + "db5bd2d1e810e366f5ef67a9cce205c2", + "6891cfc038ce7db21e0cc307ae2b1b37", ] fetch_templates = _make_reusable_fetcher( diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 89bffa7f..1511a3bd 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -47,7 +47,7 @@ "inc_addtol", "exc_addtol", "ORG_spectral_subbundles", - "cluster_ID", + "cluster_IDs", ] @@ -533,9 +533,11 @@ def check_space(roi): if b_sls: if "ORG_spectral_subbundles" in bundle_def: subdict = bundle_def["ORG_spectral_subbundles"] - c_ids = subdict.cluster_IDs b_sls.initiate_selection( - (f"ORG spectral clustering, {len(c_ids)} subbundles being recognized") + ( + f"ORG spectral clustering, {len(subdict.bundle_names)} " + "subbundles being recognized" + ) ) sub_sft = StatefulTractogram( @@ -545,15 +547,17 @@ def check_space(roi): sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 ) clusters_being_recognized = [] - for c_id in c_ids: - bundle_name = subdict.get_subbundle_name(c_id) - n_roi = len(subdict[bundle_name].get("include", [])) - cluster_b_sls = b_sls.copy(bundle_name, n_roi) - cluster_b_sls.select(cluster_labels == c_id, f"Cluster {c_id}") + for sub_b_name in subdict.bundle_names: + c_ids = subdict._dict[sub_b_name]["cluster_IDs"] + n_roi = len(subdict._dict[sub_b_name].get("include", [])) + cluster_b_sls = b_sls.copy(sub_b_name, n_roi) + selected = np.zeros(len(b_sls), dtype=bool) + for c_id in c_ids: + selected = np.logical_or(selected, cluster_labels == c_id) + cluster_b_sls.select(selected, f"Clusters {c_ids}") clusters_being_recognized.append(cluster_b_sls) - for ii, c_id in enumerate(c_ids): - bundle_name = subdict.get_subbundle_name(c_id) + for ii, sub_b_name in enumerate(subdict.bundle_names): run_bundle_rec_plan( bundle_def["ORG_spectral_subbundles"], clusters_being_recognized[ii], @@ -561,7 +565,7 @@ def check_space(roi): img, reg_template, preproc_imap, - bundle_name, + sub_b_name, recognized_bundles_dict, **segmentation_params, ) From fcf5d59a2658876847094188d49bcb29c505dcc9 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 23 Feb 2026 16:21:57 +0900 Subject: [PATCH 56/86] Refine VOF definition --- AFQ/api/bundle_dict.py | 116 ++++++++++++++++++++++-------------- AFQ/recognition/cleaning.py | 2 +- 2 files changed, 73 insertions(+), 45 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index a1d5d528..34b4f5bc 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -266,9 +266,9 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, + "Left Inferior Fronto-occipital": {"core": "Right"}, "length": {"min_len": 30}, "primary_axis": "I/S", - "primary_axis_percentage": 50, }, "Right Posterior Arcuate": { "cross_midline": False, @@ -283,9 +283,9 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, + "Right Inferior Fronto-occipital": {"core": "Left"}, "length": {"min_len": 30}, "primary_axis": "I/S", - "primary_axis_percentage": 50, }, "Left Vertical Occipital": { "cross_midline": False, @@ -304,76 +304,90 @@ def default_bd(): "length": {"min_len": 30}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_IDs": [74, 92], + "cluster_IDs": [63], "orient_mahal": { "distance_threshold": 2, - "length_threshold": 5, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, - "length_threshold": 0, + "distance_threshold": 4, + "length_threshold": 4, "clean_rounds": 5, }, }, "Left Vertical Occipital II": { - "cluster_IDs": [1, 72, 81], + "cluster_IDs": [1, 81], "orient_mahal": { - "distance_threshold": 3, - "length_threshold": 5, + "distance_threshold": 2, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { "distance_threshold": 4, - "length_threshold": 0, + "length_threshold": 4, "clean_rounds": 5, }, }, "Left Vertical Occipital III": { "Left Inferior Fronto-occipital": {"core": "Right"}, - "cluster_IDs": [2, 7, 18, 25], + "cluster_IDs": [2, 7, 18], + "exclude": [templates["pARC_xroi1_L"]], "orient_mahal": { "distance_threshold": 2, - "length_threshold": 5, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, - "length_threshold": 0, + "distance_threshold": 4, + "length_threshold": 4, "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ - 89, - 93, - 100, - 102, - 454, - 27, - 555, - 118, 4, 6, + 9, + 10, 13, 17, 22, 23, + 27, + 28, + 29, + 30, + 34, + 35, 38, 48, 50, 53, + 54, 64, 65, 66, 84, 87, 88, + 89, + 92, + 93, + 94, + 97, 98, + 100, + 102, + 118, + 129, + 422, + 439, + 454, + 555, ], ), }, @@ -394,76 +408,90 @@ def default_bd(): "length": {"min_len": 30}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_IDs": [74, 92], + "cluster_IDs": [63], "orient_mahal": { "distance_threshold": 2, - "length_threshold": 5, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, - "length_threshold": 0, + "distance_threshold": 4, + "length_threshold": 4, "clean_rounds": 5, }, }, "Right Vertical Occipital II": { - "cluster_IDs": [1, 72, 81], + "cluster_IDs": [1, 81], "orient_mahal": { - "distance_threshold": 3, - "length_threshold": 5, + "distance_threshold": 2, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { "distance_threshold": 4, - "length_threshold": 0, + "length_threshold": 4, "clean_rounds": 5, }, }, "Right Vertical Occipital III": { "Right Inferior Fronto-occipital": {"core": "Left"}, - "cluster_IDs": [2, 7, 18, 25], + "cluster_IDs": [2, 7, 18], + "exclude": [templates["pARC_xroi1_R"]], "orient_mahal": { "distance_threshold": 2, - "length_threshold": 5, + "length_threshold": 0, "clean_rounds": 1, }, "mahal": { - "distance_threshold": 3, - "length_threshold": 0, + "distance_threshold": 4, + "length_threshold": 4, "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ - 89, - 93, - 100, - 102, - 454, - 27, - 555, - 118, 4, 6, + 9, + 10, 13, 17, 22, 23, + 27, + 28, + 29, + 30, + 34, + 35, 38, 48, 50, 53, + 54, 64, 65, 66, 84, 87, 88, + 89, + 92, + 93, + 94, + 97, 98, + 100, + 102, + 118, + 129, + 422, + 439, + 454, + 555, ], ), }, diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index eaaf0cf2..fded1040 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -50,7 +50,7 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): along_accepted_idx = orientation_along == primary_axis if tol is not None: percentage_primary = ( - 100 * endpoint_diff[:, primary_axis] / np.sum(endpoint_diff, axis=1) + 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) ) logger.debug( (f"Maximum primary percentage found: {np.max(percentage_primary)}") From f02cf7be771ca6d8bc5bc58a21e72d16e096a696 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 23 Feb 2026 17:08:47 +0900 Subject: [PATCH 57/86] bundle montage updates --- AFQ/api/group.py | 4 ++-- AFQ/api/participant.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 57d5f08a..ea517031 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -787,7 +787,7 @@ def cmd_outputs( clobber = cmd_outputs # alias for default of cmd_outputs - def make_all_participant_montages(self, images_per_row=2): + def make_all_participant_montages(self, images_per_row=3): """ Generate montage of all bundles for a all subjects. @@ -795,7 +795,7 @@ def make_all_participant_montages(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. - Default: 2 + Default: 3 Returns ------- diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index e647f834..8ad8d89c 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -2,14 +2,15 @@ import math import os.path as op import tempfile +from math import radians from time import time import nibabel as nib -from math import radians from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm +import AFQ.utils.streamlines as aus from AFQ.api.utils import ( AFQclass_doc, check_attribute, @@ -284,9 +285,9 @@ def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None) tdir = tempfile.gettempdir() all_fnames = [] + seg_sft = aus.SegmentedSFT.fromfile(self.export("bundles")) if bundle_names is None: - bundle_dict = self.export("bundle_dict") - bundle_names = list(bundle_dict.keys()) + bundle_names = list(seg_sft.bundle_names) self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") t1 = nib.load(self.export("t1_masked")) @@ -305,7 +306,7 @@ def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None) else: figure = None figure = viz_backend.visualize_bundles( - self.export("bundles"), + seg_sft, img=t1, shade_by_volume=best_scalar.get_fdata(), color_by_direction=True, @@ -336,6 +337,10 @@ def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None) showlegend=False, ) figure.write_image(this_fname, scale=4) + # temporary fix for memory leak + import plotly.io as pio + + pio.kaleido.scope._shutdown_kaleido() else: from fury import window @@ -344,9 +349,9 @@ def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None) ) window.update_camera(show_m.screens[0].camera, None, figure) if view == "Coronal": - show_m.screens[0].controller.rotate((0, radians(-eye["y"] * 90)), None) + show_m.screens[0].controller.rotate((0, radians(-90)), None) elif view == "Axial": - show_m.screens[0].controller.rotate((radians(eye["z"] * 90), 0, 0), None) + show_m.screens[0].controller.rotate((radians(90), 0), None) elif view == "Sagittal": pass show_m.render() From 070df1dc33fbff34d26a933eea2b9bc2adeb0bc0 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 23 Feb 2026 21:26:04 +0900 Subject: [PATCH 58/86] further refined VOF --- AFQ/api/bundle_dict.py | 16 ++++++++++------ AFQ/data/fetch.py | 8 ++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 34b4f5bc..bedad155 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -307,7 +307,7 @@ def default_bd(): "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_IDs": [63], + "cluster_IDs": [61, 63, 77, 82], "orient_mahal": { "distance_threshold": 2, "length_threshold": 0, @@ -320,7 +320,7 @@ def default_bd(): }, }, "Left Vertical Occipital II": { - "cluster_IDs": [1, 81], + "cluster_IDs": [1, 72, 75, 81, 83], "orient_mahal": { "distance_threshold": 2, "length_threshold": 0, @@ -334,7 +334,7 @@ def default_bd(): }, "Left Vertical Occipital III": { "Left Inferior Fronto-occipital": {"core": "Right"}, - "cluster_IDs": [2, 7, 18], + "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_L"]], "orient_mahal": { "distance_threshold": 2, @@ -371,6 +371,8 @@ def default_bd(): 64, 65, 66, + 74, + 78, 84, 87, 88, @@ -411,7 +413,7 @@ def default_bd(): "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_IDs": [63], + "cluster_IDs": [61, 63, 77, 82], "orient_mahal": { "distance_threshold": 2, "length_threshold": 0, @@ -424,7 +426,7 @@ def default_bd(): }, }, "Right Vertical Occipital II": { - "cluster_IDs": [1, 81], + "cluster_IDs": [1, 72, 75, 81, 83], "orient_mahal": { "distance_threshold": 2, "length_threshold": 0, @@ -438,7 +440,7 @@ def default_bd(): }, "Right Vertical Occipital III": { "Right Inferior Fronto-occipital": {"core": "Left"}, - "cluster_IDs": [2, 7, 18], + "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_R"]], "orient_mahal": { "distance_threshold": 2, @@ -475,6 +477,8 @@ def default_bd(): 64, 65, 66, + 74, + 78, 84, 87, 88, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 61c89d71..b1a2eb84 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -865,8 +865,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "61737619", "61970155", "61970158", - "62031448", - "62031439", + "62084713", + "62084716", "62031442", "62031445", ] @@ -972,8 +972,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "7cf5800a4efa6bac7e70d84095bc259b", "f65b3f9133820921d023517a68d4ea41", "4476935f5aadfcdd633b9a23779625ef", - "318ea89a04caf8d6f6afa34c8d173142", - "27fe6a73aec3a0d90dae07327c93393e", + "bac80a77df083c12a01982c0386f94be", + "dddd1923e87fb2880091615bc7f8a9a4", "db5bd2d1e810e366f5ef67a9cce205c2", "6891cfc038ce7db21e0cc307ae2b1b37", ] From 21cbb9e6fa74d7699143922c63cc6e0053c343d3 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 23 Feb 2026 22:58:25 +0900 Subject: [PATCH 59/86] more tweaks --- AFQ/api/bundle_dict.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index bedad155..b66d6b83 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -314,7 +314,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -327,7 +327,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -342,7 +342,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -373,6 +373,7 @@ def default_bd(): 66, 74, 78, + 80, 84, 87, 88, @@ -380,6 +381,7 @@ def default_bd(): 92, 93, 94, + 95, 97, 98, 100, @@ -420,7 +422,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -433,7 +435,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -448,7 +450,7 @@ def default_bd(): "clean_rounds": 1, }, "mahal": { - "distance_threshold": 4, + "distance_threshold": 3, "length_threshold": 4, "clean_rounds": 5, }, @@ -479,6 +481,7 @@ def default_bd(): 66, 74, 78, + 80, 84, 87, 88, @@ -486,6 +489,7 @@ def default_bd(): 92, 93, 94, + 95, 97, 98, 100, From ee5e0522ba389678f80ce7170347ca7dd51f8665 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 24 Feb 2026 11:25:04 +0900 Subject: [PATCH 60/86] reduce default chunk size on GPU --- AFQ/tasks/tractography.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 58e43619..7c3d418f 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -253,7 +253,7 @@ def gpu_tractography( seed, tissue_imap, tractography_ngpus=0, - chunk_size=100000, + chunk_size=25000, ): """ full path to the complete, unsegmented tractography file @@ -269,7 +269,7 @@ def gpu_tractography( Default: 0 chunk_size : int, optional Chunk size for GPU tracking. - Default: 100000 + Default: 25000 """ start_time = time() if tracking_params["directions"] == "boot": From 99c130d9514c792248de8c13bebc609794c2ec81 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 24 Feb 2026 15:45:02 +0900 Subject: [PATCH 61/86] better gif maker --- AFQ/_fixes.py | 52 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index ab610cbf..3c9a5b6b 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -314,7 +314,7 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean return weights / np.sum(weights, 0) -def make_gif(show_m, out_path, n_frames=36, az_ang=-10): +def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150): """ Make a video from a Fury Show Manager. @@ -334,6 +334,10 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10): The angle to rotate the camera around the z-axis for each frame, in degrees. Default: -10 + + duration : int + The duration of each frame in the output GIF, in milliseconds. + Default: 150 """ video = [] @@ -341,15 +345,49 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10): show_m.window.draw() with tempfile.TemporaryDirectory() as tmp_dir: - for ii in tqdm(range(n_frames), desc="Generating GIF"): + for ii in tqdm(range(n_frames), desc="Generating GIF", leave=False): frame_fname = f"{tmp_dir}/{ii}.png" show_m.screens[0].controller.rotate((radians(az_ang), 0), None) show_m.render() show_m.window.draw() show_m.snapshot(frame_fname) - video.append(frame_fname) - - video = [Image.open(frame) for frame in video] - video[0].save( - out_path, save_all=True, append_images=video[1:], duration=300, loop=1 + video.append(Image.open(frame_fname).convert("RGB")) + + all_left, all_upper = float("inf"), float("inf") + all_right, all_lower = 0, 0 + + for img in video: + arr = np.array(img) + bg_color = arr[0, 0] + + mask = np.any(arr != bg_color, axis=-1) + + if np.any(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + + all_left = min(all_left, xmin) + all_upper = min(all_upper, ymin) + all_right = max(all_right, xmax) + all_lower = max(all_lower, ymax) + + if all_left < all_right: + crop_box = ( + max(0, all_left), + max(0, all_upper), + min(video[0].width, all_right), + min(video[0].height, all_lower), + ) + cropped_video = [img.crop(crop_box) for img in video] + else: + cropped_video = video + + cropped_video[0].save( + out_path, + save_all=True, + append_images=cropped_video[1:], + duration=duration, + loop=1, ) From 5706c0ee55e9df2e91c18dd05f67b555a5bb1d7b Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 24 Feb 2026 17:31:20 +0900 Subject: [PATCH 62/86] recog in rasmm --- AFQ/recognition/cleaning.py | 8 ++--- AFQ/recognition/clustering.py | 10 ++---- AFQ/recognition/criteria.py | 43 +++++++++++------------ AFQ/recognition/other_bundles.py | 34 +++++------------- AFQ/recognition/preprocess.py | 41 ++++++++++----------- AFQ/recognition/recognize.py | 8 ++--- AFQ/recognition/roi.py | 22 +++++++----- AFQ/recognition/tests/test_recognition.py | 1 - AFQ/recognition/tests/test_utils.py | 1 - AFQ/recognition/utils.py | 17 ++++++++- AFQ/tasks/mapping.py | 2 +- AFQ/tasks/segmentation.py | 12 +++---- AFQ/tasks/tractography.py | 2 +- AFQ/tractography/gputractography.py | 2 +- 14 files changed, 95 insertions(+), 108 deletions(-) diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index fded1040..0d3fab86 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -1,7 +1,6 @@ import logging import dipy.tracking.streamline as dts -import nibabel as nib import numpy as np from dipy.io.stateful_tractogram import StatefulTractogram from scipy.stats import zscore @@ -32,11 +31,8 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): raise ValueError( f"Primary axis must be one of {axes_names}, got {primary_axis}" ) - orientation = nib.orientations.aff2axcodes(affine) - for idx, axis_label in enumerate(orientation): - if axis_label in primary_axis: - primary_axis = idx - break + + primary_axis = abu.axes_dict[primary_axis] axis_diff = np.zeros((len(streamlines), 3)) endpoint_diff = np.zeros((len(streamlines), 3)) diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py index 2d33b6c3..335f68b4 100644 --- a/AFQ/recognition/clustering.py +++ b/AFQ/recognition/clustering.py @@ -152,8 +152,8 @@ def subcluster_by_atlas( Parameters ---------- - sub_fgarray : ndarray - Resampled fiber group in VOX to be labeled. + sub_trk : StatefulTractogram + streamlines to be labeled. mapping : DIPY or pyAFQ mapping Mapping to use to move streamlines. dwi_ref : Nifti1Image @@ -176,12 +176,6 @@ def subcluster_by_atlas( ) atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) - # Note: if we need more efficiency, - # we could modify the code to consider: - # voxel size, midline axis, and midline location - # then we should be able to do these calculations in - # voxel space without having to move the subject streamlines - # to rasmm (but this is not a bottleneck right now) sub_trk.to_rasmm() sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 1511a3bd..80034031 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -1,10 +1,10 @@ import logging from time import time +import dipy.tracking.streamline as dts import nibabel as nib import numpy as np import ray -from dipy.core.interpolation import interpolate_scalar_3d from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import load_tractogram from dipy.segment.bundles import RecoBundles @@ -24,9 +24,9 @@ from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ - "prob_map", - "cross_midline", "length", + "cross_midline", + "prob_map", "start", "end", "primary_axis", @@ -48,17 +48,21 @@ "exc_addtol", "ORG_spectral_subbundles", "cluster_IDs", + "startpoint_location", + "endpoint_location", ] logger = logging.getLogger("AFQ") -def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): +def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, img, **kwargs): b_sls.initiate_selection("Prob. Map") - fiber_probabilities = interpolate_scalar_3d( - bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"].reshape(-1, 3) - )[0].reshape(-1, 20) + fiber_probabilities = dts.values_from_volume( + bundle_def["prob_map"].get_fdata(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + img.affine, + ) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. Map") @@ -115,15 +119,12 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): def length(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("length") - min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] - max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] + min_len = bundle_def["length"].get("min_len", 0) + max_len = bundle_def["length"].get("max_len", np.inf) - # Using resampled fgarray biases lengths to be lower. However, - # this is not meant to be a precise selection requirement, and - # is more meant for efficiency. - segments = np.diff(preproc_imap["fgarray"][b_sls.selected_fiber_idxs], axis=1) - segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) - sl_lens = np.sum(segment_lengths, axis=1) + # No need to use b_sls.selected_fiber_idxs + # because this is first step + sl_lens = preproc_imap["lengths"] accept_idx = (sl_lens >= min_len) & (sl_lens <= max_len) b_sls.select(accept_idx, "length") @@ -234,7 +235,6 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): moved_ref_sl = move_streamlines( ref_sl, "subject", mapping, img, save_intermediates=save_intermediates ) - moved_ref_sl.to_vox() moved_ref_sl = moved_ref_sl.streamlines[0] moved_ref_curve = abv.sl_curve(moved_ref_sl, len(moved_ref_sl)) ref_curve_threshold = np.radians(bundle_def["curvature"].get("thresh", 10)) @@ -278,7 +278,7 @@ def recobundles( ): b_sls.initiate_selection("Recobundles") moved_sl = move_streamlines( - StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), + StatefulTractogram(b_sls.get_selected_sls(), img, Space.RASMM), "template", mapping, reg_template, @@ -305,7 +305,7 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): b_sls.initiate_selection("qb_thresh") cut = clip_edges or ("bundlesection" in bundle_def) qbx = QuickBundles( - bundle_def["qb_thresh"] / preproc_imap["vox_dim"], + bundle_def["qb_thresh"], AveragePointwiseEuclideanMetric(ResampleFeature(nb_points=12)), ) clusters = qbx.cluster(b_sls.get_selected_sls(cut=cut, flip=True)) @@ -509,12 +509,9 @@ def check_space(roi): # entirely on the wrong side of the midline here after filtering if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: b_sls.initiate_selection("Wrong side of mid.") - zero_coord = preproc_imap["zero_coord"] - lr_axis = preproc_imap["lr_axis"] avg_side = np.sign( np.mean( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, lr_axis] - - zero_coord, + preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, 0], axis=1, ) ) @@ -541,7 +538,7 @@ def check_space(roi): ) sub_sft = StatefulTractogram( - b_sls.get_selected_sls(flip=True), img, Space.VOX + b_sls.get_selected_sls(flip=True), img, Space.RASMM ) cluster_labels = subcluster_by_atlas( sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 618daf45..bff432c2 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -6,6 +6,8 @@ import numpy as np from scipy.spatial.distance import cdist +import AFQ.recognition.utils as abu + logger = logging.getLogger("AFQ") @@ -26,6 +28,7 @@ def clean_by_overlap( ---------- this_bundle_sls : array-like A list or array of streamlines to be cleaned. + Assumed to be in RASMM space. other_bundle_sls : array-like A reference list or array of streamlines to determine overlapping regions. overlap : int @@ -72,7 +75,7 @@ def clean_by_overlap( >>> cleaned_bundle = [s for i, s in enumerate(bundle1) if clean_idx[i]] """ other_bundle_density_map = dtu.density_map( - other_bundle_sls, np.eye(4), img.shape[:3] + other_bundle_sls, img.affine, img.shape[:3] ) if remove: @@ -99,7 +102,7 @@ def clean_by_overlap( ) fiber_probabilities = dts.values_from_volume( - other_bundle_density_map, this_bundle_sls, np.eye(4) + other_bundle_density_map, this_bundle_sls, img.affine ) cleaned_idx = np.zeros(len(this_bundle_sls), dtype=np.bool_) for ii, fp in enumerate(fiber_probabilities): @@ -125,8 +128,10 @@ def clean_relative_to_other_core( retained. this_fgarray : ndarray An array of streamlines to be cleaned. + Assumed to be in RASMM space. other_fgarray : ndarray An array of reference streamlines to define the core. + Assumed to be in RASMM space. affine : ndarray The affine transformation matrix. entire : bool, optional @@ -164,17 +169,7 @@ def clean_relative_to_other_core( return np.ones(this_fgarray.shape[0], dtype=np.bool_) # find dimension of core axis - orientation = nib.orientations.aff2axcodes(affine) - core_axis = None - core_upper = core[0].upper() - axis_groups = { - "L": ("L", "R"), - "R": ("L", "R"), - "P": ("P", "A"), - "A": ("P", "A"), - "I": ("I", "S"), - "S": ("I", "S"), - } + core_axis = abu.axes_dict[core[0].upper()] direction_signs = { "L": 1, @@ -185,18 +180,7 @@ def clean_relative_to_other_core( "S": -1, } - core_axis = None - for idx, axis_label in enumerate(orientation): - if core_upper in axis_groups[axis_label]: - core_axis = idx - core_direc = direction_signs[core_upper] - break - - if affine[core_axis, core_axis] < 0: - core_direc = -core_direc - - if core_axis is None: - raise ValueError(f"Invalid core axis: {core}") + core_direc = direction_signs[core[0].upper()] core_bundle = np.median(other_fgarray, axis=0) cleaned_idx_core = np.zeros(this_fgarray.shape[0], dtype=np.bool_) diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 52531045..ec47c7a4 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -2,7 +2,6 @@ from time import time import immlib -import nibabel as nib import numpy as np import AFQ.recognition.utils as abu @@ -27,39 +26,41 @@ def fgarray(tg): return fg_array -@immlib.calc("crosses", "lr_axis", "zero_coord") -def crosses(fgarray, img): +@immlib.calc("crosses") +def crosses(fgarray): """ Classify the streamlines by whether they cross the midline. Creates a crosses attribute which is an array of booleans. Each boolean corresponds to a streamline, and is whether or not that streamline crosses the midline. """ - # What is the x,y,z coordinate of 0,0,0 in the template space? - zero_coord = np.dot(np.linalg.inv(img.affine), np.array([0, 0, 0, 1])) + return np.logical_and( + np.any(fgarray[:, :, 0] > 0, axis=1), + np.any(fgarray[:, :, 0] < 0, axis=1), + ) - orientation = nib.orientations.aff2axcodes(img.affine) - lr_axis = 0 - for idx, axis_label in enumerate(orientation): - if axis_label in ["L", "R"]: - lr_axis = idx - break - return ( - np.logical_and( - np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), - np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), - ), - lr_axis, - zero_coord[lr_axis], - ) +@immlib.calc("lengths") +def lengths(fgarray): + """ + Calculate the lengths of the streamlines. + Using resampled fgarray biases lengths to be lower. However, + this is not meant to be a precise selection requirement, and + is more meant for efficiency. + """ + segments = np.diff(fgarray, axis=1) + segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) + return np.sum(segment_lengths, axis=1) # Things that can be calculated for multiple bundles at once # (i.e., for a whole tractogram) go here def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas): preproc_plan = immlib.plan( - tolerance_mm_to_vox=tolerance_mm_to_vox, fgarray=fgarray, crosses=crosses + tolerance_mm_to_vox=tolerance_mm_to_vox, + fgarray=fgarray, + crosses=crosses, + lengths=lengths, ) return preproc_plan( img=img, diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 73602909..51861deb 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -155,7 +155,7 @@ def recognize( if not isinstance(bundle_dict, BundleDict): bundle_dict = BundleDict(bundle_dict) - tg.to_vox() + tg.to_rasmm() n_streamlines = len(tg) recognized_bundles_dict = {} @@ -273,10 +273,10 @@ def _return_empty(bundle_name, return_idx, fiber_groups, img): """ if return_idx: fiber_groups[bundle_name] = {} - fiber_groups[bundle_name]["sl"] = StatefulTractogram([], img, Space.VOX) + fiber_groups[bundle_name]["sl"] = StatefulTractogram([], img, Space.RASMM) fiber_groups[bundle_name]["idx"] = np.array([]) else: - fiber_groups[bundle_name] = StatefulTractogram([], img, Space.VOX) + fiber_groups[bundle_name] = StatefulTractogram([], img, Space.RASMM) def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip, return_idx, fiber_groups, img): @@ -285,7 +285,7 @@ def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip, return_idx, fiber_group """ sl = abu.flip_sls(sl, to_flip, in_place=False) - sl = StatefulTractogram(sl, img, Space.VOX) + sl = StatefulTractogram(sl, img, Space.RASMM) if return_idx: fiber_groups[b_name] = {} diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 4ae02310..3ca51318 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -2,20 +2,22 @@ from dipy.core.interpolation import interpolate_scalar_3d -def _interp3d(roi, sl): - return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] +def _interp_arr_with_affine(roi, arr, affine): + inv_affine = np.linalg.inv(affine) + arr_xformd = np.dot(arr, inv_affine[:3, :3]) + inv_affine[:3, 3] + + return np.array(interpolate_scalar_3d(roi, arr_xformd)[0]) def check_sls_with_inclusion(sls, include_rois, include_roi_tols): inc_results = np.zeros(len(sls), dtype=tuple) - include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): closest = np.zeros(len(include_rois), dtype=np.int32) dists = np.zeros(len(include_rois), dtype=np.float32) sl = np.asarray(sl) valid = True for ii, roi in enumerate(include_rois): - dist = interpolate_scalar_3d(roi, sl)[0] + dist = _interp_arr_with_affine(roi.get_fdata(), sl, roi.affine) closest[ii] = np.argmin(dist) dists[ii] = dist[closest[ii]] @@ -38,7 +40,10 @@ def check_sl_with_exclusion(sl, exclude_rois, exclude_roi_tols): for ii, roi in enumerate(exclude_rois): # if any part of the streamline is near any exclusion ROI, # return False - if np.any(_interp3d(roi, sl) <= exclude_roi_tols[ii]): + if np.any( + _interp_arr_with_affine(roi.get_fdata(), sl, roi.affine) + <= exclude_roi_tols[ii] + ): return False # Either there are no exclusion ROIs, or you are not close to any: return True @@ -53,6 +58,7 @@ def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): ---------- fgarray : ndarray of shape (N, M, 3) Where N is number of streamlines, M is number of nodes. + Assumed to be in RASMM space. target: Nifti1Image Nifti1Image containing a distance transform of the ROI. target_idx: int. @@ -87,8 +93,8 @@ def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): flipped_indices = n_nodes - 1 - effective_idx indices = np.where(flip_sls.astype(bool), flipped_indices, indices) - distances = interpolate_scalar_3d( - target.get_fdata(), fgarray[np.arange(n_sls), indices] - )[0] + distances = _interp_arr_with_affine( + target.get_fdata(), fgarray[np.arange(n_sls), indices], target.affine + ) return distances <= tol diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 6e4812de..6c5c4a73 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -26,7 +26,6 @@ mapping = reg.read_old_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) streamlines = file_dict["tractography_subsampled.trk"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) -tg.to_vox() streamlines = tg.streamlines templates = afd.read_templates() cst_r_curve_ref = StatefulTractogram( diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index 59930cf0..4f320de6 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -18,7 +18,6 @@ file_dict = afd.read_stanford_hardi_tractography() streamlines = file_dict["tractography_subsampled.trk"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) -tg.to_vox() streamlines = tg.streamlines diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 4acd2a5e..350adb4b 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -13,6 +13,19 @@ logger = logging.getLogger("AFQ") +axes_dict = { + "L/R": 0, + "L": 0, + "R": 0, + "P/A": 1, + "P": 1, + "A": 1, + "I/S": 2, + "I": 2, + "S": 2, +} + + def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): # We need to calculate the size of a voxel, so we can transform # from mm to voxel units: @@ -157,7 +170,9 @@ def select(self, idx, clean_name, cut=False): # otherwise its impractical if self.save_intermediates is not None and len(self) < 0.1 * len(self.ref_sls): save_tractogram( - StatefulTractogram(self.get_selected_sls(cut=cut), self.ref, Space.VOX), + StatefulTractogram( + self.get_selected_sls(cut=cut), self.ref, Space.RASMM + ), op.join( self.save_intermediates, f"sls_after_{clean_name}_for_{self.b_name}.trk", diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 724711ea..4dc62973 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -173,7 +173,7 @@ def sls_mapping( ) streamlines_file = tractography_imap["streamlines"] tg = load_tractogram( - streamlines_file, reg_subject, Space.VOX, bbox_valid_check=False + streamlines_file, reg_subject, Space.RASMM, bbox_valid_check=False ) tg.to_rasmm() diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 79cd64cb..465a5fee 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -59,9 +59,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): or streamlines.endswith(".tck") or streamlines.endswith(".vtk") ): - tg = load_tractogram( - streamlines, data_imap["dwi"], Space.VOX, bbox_valid_check=False - ) + tg = load_tractogram(streamlines, data_imap["dwi"], bbox_valid_check=False) is_trx = False elif streamlines.endswith(".trx"): is_trx = True @@ -85,9 +83,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): with open(temp_tck, "wb") as f_out: shutil.copyfileobj(f_in, f_out) # initialize stateful tractogram from tck file: - tg = load_tractogram( - temp_tck, data_imap["dwi"], Space.VOX, bbox_valid_check=False - ) + tg = load_tractogram(temp_tck, data_imap["dwi"], bbox_valid_check=False) is_trx = False if len(tg.streamlines) == 0: raise ValueError( @@ -119,7 +115,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): raise ValueError("Fatal: No bundles recognized.") if is_trx: - seg_sft.sft.dtype_dict = {"positions": np.float16, "offsets": np.uint32} + seg_sft.sft.dtype_dict = {"positions": np.float32, "offsets": np.uint32} tgram = TrxFile.from_sft(seg_sft.sft) tgram.groups = seg_sft.bundle_idxs @@ -175,7 +171,7 @@ def export_bundles(base_fname, output_dir, bundles, tracking_params): logger.info(f"Saving {fname}") if is_trx: seg_sft.sft.dtype_dict = { - "positions": np.float16, + "positions": np.float32, "offsets": np.uint32, } trxfile = TrxFile.from_sft(bundle_sft) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 7c3d418f..b666c3d8 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -88,7 +88,7 @@ def streamlines(data_imap, seed, tissue_imap, fodf, citations, tracking_params): if is_trx: start_time = time() - dtype_dict = {"positions": np.float16, "offsets": np.uint32} + dtype_dict = {"positions": np.float32, "offsets": np.uint32} if num_chunks and num_chunks > 1: if not has_ray: raise ImportError( diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index c8087d22..1824683f 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -89,7 +89,7 @@ def gpu_track( directions = directions.lower() # Roughly handle ACT/CMC for now - wm_threshold = 0.01 + wm_threshold = 0.5 pve_img = nib.load(pve_path) From b2fdb54c521d80a05cc5571762395e65e8757f25 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 14:03:58 +0900 Subject: [PATCH 63/86] more refinements to VOF, more refinements to RASMM segmenting --- AFQ/api/bundle_dict.py | 97 ++++++++++++----------------- AFQ/api/utils.py | 2 +- AFQ/data/fetch.py | 8 +-- AFQ/recognition/criteria.py | 19 +++++- AFQ/recognition/utils.py | 10 +++ AFQ/tasks/segmentation.py | 29 ++++++--- AFQ/tasks/tractography.py | 2 + AFQ/tractography/gputractography.py | 21 ++++++- AFQ/viz/utils.py | 12 ++-- 9 files changed, 118 insertions(+), 82 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index b66d6b83..0da36241 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -135,6 +135,7 @@ def default_bd(): "prob_map": templates["ATR_L_prob_map"], "start": templates["ATR_L_start"], "end": templates["ATR_L_end"], + "length": {"min_len": 30}, }, "Right Anterior Thalamic": { "cross_midline": False, @@ -144,6 +145,7 @@ def default_bd(): "prob_map": templates["ATR_R_prob_map"], "start": templates["ATR_R_start"], "end": templates["ATR_R_end"], + "length": {"min_len": 30}, }, "Left Cingulum Cingulate": { "cross_midline": False, @@ -152,6 +154,7 @@ def default_bd(): "space": "template", "prob_map": templates["CGC_L_prob_map"], "end": templates["CGC_L_start"], + "length": {"min_len": 30}, }, "Right Cingulum Cingulate": { "cross_midline": False, @@ -160,6 +163,7 @@ def default_bd(): "space": "template", "prob_map": templates["CGC_R_prob_map"], "end": templates["CGC_R_start"], + "length": {"min_len": 30}, }, "Left Corticospinal": { "cross_midline": False, @@ -168,6 +172,7 @@ def default_bd(): "space": "template", "prob_map": templates["CST_L_prob_map"], "end": templates["CST_L_start"], + "length": {"min_len": 40}, }, "Right Corticospinal": { "cross_midline": False, @@ -176,26 +181,27 @@ def default_bd(): "space": "template", "prob_map": templates["CST_R_prob_map"], "end": templates["CST_R_start"], + "length": {"min_len": 40}, }, "Left Inferior Fronto-occipital": { "cross_midline": False, "include": [templates["IFO_roi2_L"], templates["IFO_roi1_L"]], - "exclude": [templates["ARC_roi1_L"]], + "exclude": [templates["ARC_roi1_L"], templates["CGC_roi1_L"]], "space": "template", "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], - "length": {"min_len": 100}, + "length": {"min_len": 80}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, "include": [templates["IFO_roi2_R"], templates["IFO_roi1_R"]], - "exclude": [templates["ARC_roi1_R"]], + "exclude": [templates["ARC_roi1_R"], templates["CGC_roi1_R"]], "space": "template", "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], - "length": {"min_len": 100}, + "length": {"min_len": 80}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -205,6 +211,7 @@ def default_bd(): "prob_map": templates["ILF_L_prob_map"], "start": templates["ILF_L_end"], "end": templates["ILF_L_start"], + "length": {"min_len": 40}, }, "Right Inferior Longitudinal": { "cross_midline": False, @@ -214,6 +221,7 @@ def default_bd(): "prob_map": templates["ILF_R_prob_map"], "start": templates["ILF_R_end"], "end": templates["ILF_R_start"], + "length": {"min_len": 40}, }, "Left Arcuate": { "cross_midline": False, @@ -223,7 +231,7 @@ def default_bd(): "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], - "length": {"min_len": 50}, + "length": {"min_len": 40}, }, "Right Arcuate": { "cross_midline": False, @@ -233,7 +241,7 @@ def default_bd(): "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], - "length": {"min_len": 50}, + "length": {"min_len": 40}, }, "Left Uncinate": { "cross_midline": False, @@ -306,25 +314,12 @@ def default_bd(): "primary_axis": "I/S", "ORG_spectral_subbundles": SpectralSubbundleDict( { - "Left Vertical Occipital I": { + "Left V1V3": { "cluster_IDs": [61, 63, 77, 82], "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, - }, - "mahal": { "distance_threshold": 3, "length_threshold": 4, - "clean_rounds": 5, - }, - }, - "Left Vertical Occipital II": { - "cluster_IDs": [1, 72, 75, 81, 83], - "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, + "clean_rounds": 2, }, "mahal": { "distance_threshold": 3, @@ -332,20 +327,15 @@ def default_bd(): "clean_rounds": 5, }, }, - "Left Vertical Occipital III": { + "Left Posterior Vertical Occipital": { + "cluster_IDs": [1, 72, 75, 81, 83], + "isolation_forest": {}, + }, + "Left Anterior Vertical Occipital": { "Left Inferior Fronto-occipital": {"core": "Right"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_L"]], - "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, - }, - "mahal": { - "distance_threshold": 3, - "length_threshold": 4, - "clean_rounds": 5, - }, + "isolation_forest": {}, }, }, remove_cluster_IDs=[ @@ -414,25 +404,12 @@ def default_bd(): "primary_axis": "I/S", "ORG_spectral_subbundles": SpectralSubbundleDict( { - "Right Vertical Occipital I": { + "Right V1V3": { "cluster_IDs": [61, 63, 77, 82], "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, - }, - "mahal": { "distance_threshold": 3, "length_threshold": 4, - "clean_rounds": 5, - }, - }, - "Right Vertical Occipital II": { - "cluster_IDs": [1, 72, 75, 81, 83], - "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, + "clean_rounds": 2, }, "mahal": { "distance_threshold": 3, @@ -440,20 +417,15 @@ def default_bd(): "clean_rounds": 5, }, }, - "Right Vertical Occipital III": { + "Right Posterior Vertical Occipital": { + "cluster_IDs": [1, 72, 75, 81, 83], + "isolation_forest": {}, + }, + "Right Anterior Vertical Occipital": { "Right Inferior Fronto-occipital": {"core": "Left"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_R"]], - "orient_mahal": { - "distance_threshold": 2, - "length_threshold": 0, - "clean_rounds": 1, - }, - "mahal": { - "distance_threshold": 3, - "length_threshold": 4, - "clean_rounds": 5, - }, + "isolation_forest": {}, }, }, remove_cluster_IDs=[ @@ -516,11 +488,15 @@ def default_bd(): def slf_bd(): templates = afd.read_slf_templates(as_img=False) + templates_afq = afd.read_templates(as_img=False) + templates["Frontal_Lobe_L"] = templates_afq["ATR_L_start"] + templates["Frontal_Lobe_R"] = templates_afq["ATR_R_start"] return BundleDict( { "Left Superior Longitudinal I": { "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, "Left Cingulum Cingulate": { "node_thresh": 20, @@ -529,6 +505,7 @@ def slf_bd(): "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, "Left Cingulum Cingulate": { "node_thresh": 20, @@ -537,6 +514,7 @@ def slf_bd(): "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], + "start": templates["Frontal_Lobe_L"], "cross_midline": False, "Left Cingulum Cingulate": { "node_thresh": 20, @@ -545,6 +523,7 @@ def slf_bd(): "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, "Right Cingulum Cingulate": { "node_thresh": 20, @@ -553,6 +532,7 @@ def slf_bd(): "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, "Right Cingulum Cingulate": { "node_thresh": 20, @@ -561,6 +541,7 @@ def slf_bd(): "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], + "start": templates["Frontal_Lobe_R"], "cross_midline": False, "Right Cingulum Cingulate": { "node_thresh": 20, diff --git a/AFQ/api/utils.py b/AFQ/api/utils.py index 0c015b49..4dd83930 100644 --- a/AFQ/api/utils.py +++ b/AFQ/api/utils.py @@ -172,7 +172,7 @@ def export_all_helper(api_afq_object, xforms, indiv, viz): api_afq_object.export("indiv_bundles") api_afq_object.export("rois") api_afq_object.export("sl_counts") - api_afq_object.export("median_bundle_lengths") + api_afq_object.export("bundle_lengths") api_afq_object.export("profiles") if viz: diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index b1a2eb84..f4fa089f 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -865,8 +865,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "61737619", "61970155", "61970158", - "62084713", - "62084716", + "62134578", + "62134581", "62031442", "62031445", ] @@ -972,8 +972,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "7cf5800a4efa6bac7e70d84095bc259b", "f65b3f9133820921d023517a68d4ea41", "4476935f5aadfcdd633b9a23779625ef", - "bac80a77df083c12a01982c0386f94be", - "dddd1923e87fb2880091615bc7f8a9a4", + "11ba79ff1f9a01c6b064428323d01013", + "84df5abfefbed5e3e310f2db0b62fcea", "db5bd2d1e810e366f5ef67a9cce205c2", "6891cfc038ce7db21e0cc307ae2b1b37", ] diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 80034031..c2b281ff 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -26,9 +26,9 @@ criteria_order_pre_other_bundles = [ "length", "cross_midline", - "prob_map", "start", "end", + "prob_map", "primary_axis", "include", "exclude", @@ -91,8 +91,15 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): -1, tol=preproc_imap["dist_to_atlas"], ) + new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) + special_idx = np.logical_and(accept_idx, accepted_idx_flipped) + special_idx_to_flip = abu.manual_orient_sls( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] + ) + accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor(accepted_idx_flipped, accept_idx) + accept_idx = new_accept_idx + b_sls.select(accept_idx, "Startpoint") @@ -112,8 +119,14 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): 0, tol=preproc_imap["dist_to_atlas"], ) + new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) + special_idx = np.logical_and(accept_idx, accepted_idx_flipped) + special_idx_to_flip = abu.manual_orient_sls( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] + ) + accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor(accepted_idx_flipped, accept_idx) + accept_idx = new_accept_idx b_sls.select(accept_idx, "endpoint") diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 350adb4b..217ae415 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -26,6 +26,16 @@ } +def manual_orient_sls(fgarray): + """ + Helper function to manually orient streamlines by their endpoints, + according to LPI+ pyAFQ standard assuming streamlines are in RASMM + """ + endpoint_diff = fgarray[:, 0, :] - fgarray[:, -1, :] + primary_axis = np.argmax(np.abs(endpoint_diff), axis=1) + return endpoint_diff[np.arange(len(fgarray)), primary_axis] < 0 + + def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): # We need to calculate the size of a voxel, so we can transform # from mm to voxel units: diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 465a5fee..9da73635 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -32,6 +32,7 @@ from dipy.io.streamline import load_tractogram, save_tractogram from dipy.stats.analysis import afq_profile from dipy.tracking.streamline import set_number_of_points, values_from_volume +from dipy.tracking.utils import length from nibabel.affines import voxel_sizes from nibabel.orientations import aff2axcodes @@ -206,26 +207,38 @@ def export_sl_counts(bundles): return counts_df, dict(source=bundles) -@immlib.calc("median_bundle_lengths") +@immlib.calc("bundle_lengths") @as_file("_desc-medianBundleLengths_tractography.csv", subfolder="stats") def export_bundle_lengths(bundles): """ full path to a JSON file containing median bundle lengths """ - med_len_counts = [] + len_data = {} seg_sft = aus.SegmentedSFT.fromfile(bundles) for bundle in seg_sft.bundle_names: - these_lengths = seg_sft.get_bundle(bundle)._tractogram._streamlines._lengths + these_lengths = list(length(seg_sft.get_bundle(bundle).streamlines)) if len(these_lengths) > 0: - med_len_counts.append(np.median(these_lengths)) + len_data[f"{bundle} Median"] = np.median(these_lengths) + len_data[f"{bundle} Min"] = np.min(these_lengths) + len_data[f"{bundle} Max"] = np.max(these_lengths) else: - med_len_counts.append(0) - med_len_counts.append(np.median(seg_sft.sft._tractogram._streamlines._lengths)) + len_data[f"{bundle} Median"] = 0 + len_data[f"{bundle} Min"] = 0 + len_data[f"{bundle} Max"] = 0 + len_data["Total Recognized Median"] = np.median( + seg_sft.sft._tractogram._streamlines._lengths + ) + len_data["Total Recognized Min"] = np.min( + seg_sft.sft._tractogram._streamlines._lengths + ) + len_data["Total Recognized Max"] = np.max( + seg_sft.sft._tractogram._streamlines._lengths + ) counts_df = pd.DataFrame( - data=dict(median_len=med_len_counts), - index=seg_sft.bundle_names + ["Total Recognized"], + data=len_data, + index=[0], ) return counts_df, dict(source=bundles) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index b666c3d8..91502793 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -297,6 +297,8 @@ def gpu_tractography( tracking_params["thresholds_as_percentages"], tracking_params["max_angle"], tracking_params["step_size"], + tracking_params["minlen"], + tracking_params["maxlen"], tracking_params["n_seeds"], tracking_params["random_seeds"], tracking_params["rng_seed"], diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index 1824683f..f60059e5 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -30,6 +30,8 @@ def gpu_track( thresholds_as_percentages, max_angle, step_size, + minlen, + maxlen, n_seeds, random_seeds, rng_seed, @@ -70,6 +72,10 @@ def gpu_track( array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds to generate within the mask. Default: 1 + minlen: int, optional + The minimal length (mm) in a streamline + maxlen: int, optional + The minimal length (mm) in a streamline random_seeds : bool If True, n_seeds is total number of random seeds to generate. If False, n_seeds encodes the density of seeds to generate. @@ -88,6 +94,13 @@ def gpu_track( seed_img = nib.load(seed_path) directions = directions.lower() + minlen = int(minlen / step_size) + maxlen = int(maxlen / step_size) + + R = seed_img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + step_size = step_size / vox_dim + # Roughly handle ACT/CMC for now wm_threshold = 0.5 @@ -170,6 +183,10 @@ def gpu_track( chunk_size=chunk_size, ) as gpu_tracker: if use_trx: - return gpu_tracker.generate_trx(seeds, seed_img) + return gpu_tracker.generate_trx( + seeds, seed_img, minlen=minlen, maxlen=maxlen + ) else: - return gpu_tracker.generate_sft(seeds, seed_img) + return gpu_tracker.generate_sft( + seeds, seed_img, minlen=minlen, maxlen=maxlen + ) diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index e3d2c6c9..ec3bbd76 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -125,12 +125,12 @@ def get_distinct_shades(base_rgb, n_steps, hue_shift): "Right Posterior Arcuate": tableau_20[15], "Left Vertical Occipital": vof_l_base, "Right Vertical Occipital": vof_r_base, - "Left Vertical Occipital I": vof_l_shades[0], - "Left Vertical Occipital II": vof_l_shades[1], - "Left Vertical Occipital III": vof_l_shades[2], - "Right Vertical Occipital I": vof_r_shades[0], - "Right Vertical Occipital II": vof_r_shades[1], - "Right Vertical Occipital III": vof_r_shades[2], + "Left V1V3": vof_l_shades[0], + "Left Posterior Vertical Occipital": vof_l_shades[1], + "Left Anterior Vertical Occipital": vof_l_shades[2], + "Right V1V3": vof_r_shades[0], + "Right Posterior Vertical Occipital": vof_r_shades[1], + "Right Anterior Vertical Occipital": vof_r_shades[2], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), From 7cb61b6a15c73624f457c87d54655573c22a4465 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 14:22:47 +0900 Subject: [PATCH 64/86] new gpu streamlines minmaxlen api --- AFQ/tractography/gputractography.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index f60059e5..bade6345 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -178,15 +178,13 @@ def gpu_track( sphere.edges, max_angle=radians(max_angle), step_size=step_size, + min_pts=minlen, + max_pts=maxlen, ngpus=ngpus, rng_seed=rng_seed, chunk_size=chunk_size, ) as gpu_tracker: if use_trx: - return gpu_tracker.generate_trx( - seeds, seed_img, minlen=minlen, maxlen=maxlen - ) + return gpu_tracker.generate_trx(seeds, seed_img) else: - return gpu_tracker.generate_sft( - seeds, seed_img, minlen=minlen, maxlen=maxlen - ) + return gpu_tracker.generate_sft(seeds, seed_img) From 9d3fa01989f9c59537cf52aabc42d03ed4087b3e Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 15:02:50 +0900 Subject: [PATCH 65/86] BFs --- AFQ/recognition/roi.py | 4 ++- AFQ/recognition/tests/test_other_bundles.py | 3 +- AFQ/recognition/tests/test_recognition.py | 34 ++++++++------------- AFQ/tractography/gputractography.py | 2 +- AFQ/tractography/tractography.py | 2 +- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 3ca51318..4d26dbfc 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -4,7 +4,9 @@ def _interp_arr_with_affine(roi, arr, affine): inv_affine = np.linalg.inv(affine) - arr_xformd = np.dot(arr, inv_affine[:3, :3]) + inv_affine[:3, 3] + lin_T = inv_affine[:3, :3].T + offset = inv_affine[:3, 3] + arr_xformd = np.dot(arr, lin_T) + offset return np.array(interpolate_scalar_3d(roi, arr_xformd)[0]) diff --git a/AFQ/recognition/tests/test_other_bundles.py b/AFQ/recognition/tests/test_other_bundles.py index 0724bf3e..f1a83b19 100644 --- a/AFQ/recognition/tests/test_other_bundles.py +++ b/AFQ/recognition/tests/test_other_bundles.py @@ -1,3 +1,4 @@ +import nibabel as nib import numpy as np import AFQ.recognition.other_bundles as abo @@ -9,7 +10,7 @@ other_bundle_sls_sample = np.array( [[[0, 1, 2], [1, 2, 3], [2, 2, 2]], [[1, 1, 1], [2, 2, 2], [3, 3, 3]]] ) -img_sample = np.zeros((5, 5, 5)) +img_sample = nib.Nifti1Image(np.zeros((5, 5, 5)), np.eye(4)) node_thresh_sample = 1 diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 6c5c4a73..e60fd90c 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -63,9 +63,7 @@ def test_segment(): - fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 2 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 2) # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 2) @@ -75,7 +73,7 @@ def test_segment(): npt.assert_(len(CST_R_sl) > 0) # Calculate the tract profile for a volume of all-ones: tract_profile = afq_profile( - np.ones(nib.load(hardi_fdata).shape[:3]), CST_R_sl.streamlines, np.eye(4) + np.ones(hardi_img.shape[:3]), CST_R_sl.streamlines, hardi_img.affine ) npt.assert_almost_equal(tract_profile, np.ones(100)) @@ -104,13 +102,13 @@ def test_segment_mixed_roi(): ), ): fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundle_info, reg_template, 2 + tg, hardi_img, mapping, bundle_info, reg_template, 2 ) bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) fiber_groups, _ = recognize( tg, - nib.load(hardi_fdata), + hardi_img, mapping, bundle_info, reg_template, @@ -121,7 +119,7 @@ def test_segment_mixed_roi(): # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 1) OR_LV1_sl = fiber_groups["OR LV1"] - npt.assert_(len(OR_LV1_sl) == 2) + npt.assert_(len(OR_LV1_sl) == 6) @pytest.mark.nightly @@ -139,7 +137,7 @@ def test_segment_no_prob(): } fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles_no_prob, reg_template, 1 + tg, hardi_img, mapping, bundles_no_prob, reg_template, 1 ) # This condition should still hold @@ -150,7 +148,7 @@ def test_segment_no_prob(): def test_segment_return_idx(): # Test with the return_idx kwarg set to True: fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1, return_idx=True + tg, hardi_img, mapping, bundles, reg_template, 1, return_idx=True ) npt.assert_equal(len(fiber_groups), 2) @@ -166,7 +164,7 @@ def test_segment_return_idx(): def test_segment_clip_edges_api(): # Test with the clip_edges kwarg set to True: fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1, clip_edges=True + tg, hardi_img, mapping, bundles, reg_template, 1, clip_edges=True ) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups["Right Corticospinal"]) > 0) @@ -183,7 +181,7 @@ def test_segment_reco(): # Try recobundles method fiber_groups, _ = recognize( tg, - nib.load(hardi_fdata), + hardi_img, mapping, bundles_reco, reg_template, @@ -216,25 +214,19 @@ def test_exclusion_ROI(): hardi_img, Space.VOX, ) - fiber_groups, _ = recognize( - slf_tg, nib.load(hardi_fdata), mapping, slf_bundle, reg_template, 1 - ) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 2) slf_bundle["Left Superior Longitudinal"]["exclude"] = [templates["SLFt_roi2_L"]] - fiber_groups, _ = recognize( - slf_tg, nib.load(hardi_fdata), mapping, slf_bundle, reg_template, 1 - ) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) def test_segment_sampled_streamlines(): - fiber_groups, _ = recognize( - tg, nib.load(hardi_fdata), mapping, bundles, reg_template, 1 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 1) # Already using a subsampled tck # the Right Corticospinal has two streamlines and @@ -247,7 +239,7 @@ def test_segment_sampled_streamlines(): # sample and segment streamlines sampled_fiber_groups, _ = recognize( tg, - nib.load(hardi_fdata), + hardi_img, mapping, bundles, reg_template, diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index bade6345..d390eda2 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -75,7 +75,7 @@ def gpu_track( minlen: int, optional The minimal length (mm) in a streamline maxlen: int, optional - The minimal length (mm) in a streamline + The maximum length (mm) in a streamline random_seeds : bool If True, n_seeds is total number of random seeds to generate. If False, n_seeds encodes the density of seeds to generate. diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index de99c63d..cea4f4f5 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -93,7 +93,7 @@ def track( minlen: int, optional The minimal length (mm) in a streamline. Default: 20 maxlen: int, optional - The minimal length (mm) in a streamline. Default: 250 + The maximum length (mm) in a streamline. Default: 250 odf_model : str or Definition, optional Can be either a string or Definition. If a string, it must be one of {"DTI", "CSD", "DKI", "GQ", "RUMBA", "MSMT_AODF", "CSD_AODF", "MSMTCSD"}. From c53362bee1ba1549afe15ae816a5c393cbeb79c1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 15:20:15 +0900 Subject: [PATCH 66/86] loosen cleaning --- AFQ/recognition/cleaning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 0d3fab86..2b1e707a 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -126,7 +126,7 @@ def clean_bundle( tg, n_points=100, clean_rounds=5, - distance_threshold=3, + distance_threshold=4, length_threshold=4, min_sl=20, stat=np.mean, @@ -149,7 +149,7 @@ def clean_bundle( the mean of extracted bundles. Default: 5 distance_threshold : float, optional. Threshold of cleaning based on the Mahalanobis distance (the units are - standard deviations). Default: 3. + standard deviations). Default: 4. length_threshold: float, optional Threshold for cleaning based on length (in standard deviations). Length of any streamline should not be *more* than this number of stdevs from From 81b40888d63b34a726b76714693e89b8c3bc035b Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 15:39:07 +0900 Subject: [PATCH 67/86] add color fa, update docs --- AFQ/tasks/data.py | 12 ++++++ docs/source/reference/kwargs.rst | 5 +-- docs/source/reference/methods.rst | 62 +++++++++++++++++-------------- 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index 66dfbfc4..6ce839e1 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -1132,6 +1132,17 @@ def dki_lt(dki_tf, dwi_affine): return dki_lt_dict +@immlib.calc("dki_cfa") +@as_file(suffix="_model-kurtosis_param-cfa_dwimap.nii.gz", subfolder="models") +@as_fit_deriv("DKI") +def dki_cfa(dki_tf): + """ + full path to a nifti file containing + the DKI color fractional anisotropy + """ + return dki_tf.color_fa, {"Description": "Color Fractional Anisotropy"} + + @immlib.calc("dki_fa") @as_file("_model-kurtosis_param-fa_dwimap.nii.gz", subfolder="models") @as_fit_deriv("DKI") @@ -1438,6 +1449,7 @@ def get_data_plan(kwargs): dki_awf, dki_mk, dki_kfa, + dki_cfa, dki_ga, dki_rd, dti_ga, diff --git a/docs/source/reference/kwargs.rst b/docs/source/reference/kwargs.rst index 914c1297..8709543c 100644 --- a/docs/source/reference/kwargs.rst +++ b/docs/source/reference/kwargs.rst @@ -149,7 +149,7 @@ tractography_ngpus: int Number of GPUs to use in tractography. If non-0, this algorithm is used for tractography, https://github.com/dipy/GPUStreamlines PTT, Prob can be used with any SHM model. Bootstrapped can be done with CSA/OPDT. Default: 0 chunk_size: int - Chunk size for GPU tracking. Default: 100000 + Chunk size for GPU tracking. Default: 25000 ========================================================== @@ -173,9 +173,6 @@ volume_opacity_indiv: float n_points_indiv: int or None n_points to resample streamlines to before plotting. If None, no resampling is done. Default: 40 -virtual_frame_buffer: bool - Whether to use a virtual frame buffer. This is if generating GIFs in a headless environment. Default: False - viz_backend_spec: str Which visualization backend to use. See Visualization Backends page in documentation for details https://tractometry.org/pyAFQ/reference/viz_backend.html One of {"fury", "plotly", "plotly_no_gif"}. Default: "plotly_no_gif" diff --git a/docs/source/reference/methods.rst b/docs/source/reference/methods.rst index 50c94406..bde820f3 100644 --- a/docs/source/reference/methods.rst +++ b/docs/source/reference/methods.rst @@ -111,6 +111,10 @@ dti_params: full path to a nifti file containing parameters for the DTI fit +dti_s0: + s0 values of DTI fit + + fwdti_tf: Free-water DTI TensorFit object @@ -127,6 +131,10 @@ dki_params: full path to a nifti file containing parameters for the DKI fit +dki_s0: + s0 values of DKI fit + + msdki_tf: Mean Signal DKI DiffusionKurtosisFit object @@ -135,6 +143,10 @@ msdki_params: full path to a nifti file containing parameters for the Mean Signal DKI fit +msdki_s0: + s0 values of Mean Signal DKI fit + + msdki_msd: full path to a nifti file containing the MSDKI mean signal diffusivity @@ -168,47 +180,27 @@ csd_ai: gq_params: - full path to a nifti file containing parameters for the Generalized Q-Sampling shm_coeff + full path to a nifti file containing ODF for the Generalized Q-Sampling gq_iso: full path to a nifti file containing isotropic diffusion component -gq_aso: - full path to a nifti file containing anisotropic diffusion component - - -gq_pmap: - full path to a nifti file containing the anisotropic power map from GQ - - -gq_ai: - full path to a nifti file containing the anisotropic index from GQ - - -rumba_model: - fit for RUMBA-SD model as documented on dipy reconstruction options - - rumba_params: - Takes the fitted RUMBA-SD model as input and returns the spherical harmonics coefficients (SHM). - - -rumba_fit: - RUMBA FIT + ODF for the RUMBA-SD model rumba_f_csf: - full path to a nifti file containing the CSF volume fraction for each voxel. + full path to a nifti file containing the CSF volume fraction for each voxel rumba_f_gm: - full path to a nifti file containing the GM volume fraction for each voxel. + full path to a nifti file containing the GM volume fraction for each voxel rumba_f_wm: - full path to a nifti file containing the white matter volume fraction for each voxel. + full path to a nifti file containing the white matter volume fraction for each voxel opdt_params: @@ -391,6 +383,10 @@ dki_lt5: Image of sixth element in the DTI tensor from DKI +dki_cfa: + full path to a nifti file containing the DKI color fractional anisotropy + + dki_fa: full path to a nifti file containing the DKI fractional anisotropy @@ -468,7 +464,15 @@ pve_internal: msmtcsd_params: - full path to a nifti file containing parameters for the MSMT CSD fit + full path to a nifti file containing parameters for the MSMT CSD white matter fit + + +msmtcsd_gm: + full path to a nifti file containing parameters for the MSMT CSD gray matter fit + + +msmtcsd_csf: + full path to a nifti file containing parameters for the MSMT CSD cerebrospinal fluid fit msmt_apm: @@ -527,7 +531,7 @@ sl_counts: full path to a JSON file containing streamline counts -median_bundle_lengths: +bundle_lengths: full path to a JSON file containing median bundle lengths @@ -565,3 +569,7 @@ tract_profile_plots: viz_backend: An instance of the `AFQ.viz.utils.viz_backend` class. + + +citations: + Export Bibtex citation file for methods used by pyAFQ. From c44f8c05122f7d9901ded7edca52d90ef252be16 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 16:07:40 +0900 Subject: [PATCH 68/86] the vof is lateral to ifof --- AFQ/api/bundle_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 0da36241..d134ee57 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -328,6 +328,7 @@ def default_bd(): }, }, "Left Posterior Vertical Occipital": { + "Left Inferior Fronto-occipital": {"core": "Right"}, "cluster_IDs": [1, 72, 75, 81, 83], "isolation_forest": {}, }, @@ -418,6 +419,7 @@ def default_bd(): }, }, "Right Posterior Vertical Occipital": { + "Right Inferior Fronto-occipital": {"core": "Left"}, "cluster_IDs": [1, 72, 75, 81, 83], "isolation_forest": {}, }, From c0ee04e777cdffe979aa8aa2cb767dbec6d4ff74 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 25 Feb 2026 17:32:34 +0900 Subject: [PATCH 69/86] tiny tweaks --- AFQ/api/bundle_dict.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index d134ee57..3bda7334 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -382,7 +382,10 @@ def default_bd(): 422, 439, 454, + 552, 555, + 556, + 725, ], ), }, @@ -473,7 +476,10 @@ def default_bd(): 422, 439, 454, + 552, 555, + 556, + 725, ], ), }, From 62c43f106ac45ad51f9272a291c5a2e63ac69435 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 13:29:31 +0900 Subject: [PATCH 70/86] more onnx option --- AFQ/nn/brainchop.py | 4 +- AFQ/nn/multiaxial.py | 21 +++++--- AFQ/nn/synthseg.py | 4 +- AFQ/tasks/data.py | 44 ++------------- AFQ/tasks/segmentation.py | 6 ++- AFQ/tasks/structural.py | 109 ++++++++++++++++++++++++++++++++++---- AFQ/tasks/tissue.py | 17 ++++-- AFQ/tasks/tractography.py | 6 ++- AFQ/tasks/utils.py | 6 ++- setup.cfg | 1 + 10 files changed, 148 insertions(+), 70 deletions(-) diff --git a/AFQ/nn/brainchop.py b/AFQ/nn/brainchop.py index 1cb1b8f0..13bfc968 100644 --- a/AFQ/nn/brainchop.py +++ b/AFQ/nn/brainchop.py @@ -30,7 +30,7 @@ def _get_model(model_name): return model_fname -def run_brainchop(ort, t1_img, model_name): +def run_brainchop(ort, t1_img, model_name, onnx_kwargs): """ Run the Brainchop command line interface with the provided arguments. @@ -56,7 +56,7 @@ def run_brainchop(ort, t1_img, model_name): image = t1_data.astype(np.float32)[None, None, ...] logger.info(f"Running {model_name}...") - sess = ort.InferenceSession(model) + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name output_channels = sess.run([output_name], {input_name: image})[0] diff --git a/AFQ/nn/multiaxial.py b/AFQ/nn/multiaxial.py index d5f5a12a..f0d72307 100644 --- a/AFQ/nn/multiaxial.py +++ b/AFQ/nn/multiaxial.py @@ -14,7 +14,9 @@ __all__ = ["run_multiaxial", "extract_brain_mask", "multiaxial"] -def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_model): +def multiaxial( + ort, img, model_sagittal, model_axial, model_coronal, consensus_model, onnx_kwargs +): """ Perform multiaxial segmentation using three ONNX models and a consensus model [1]. @@ -31,6 +33,8 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m Path to coronal ONNX model. consensus_model : str Path to consensus ONNX model. + onnx_kwargs : dict + ONNX kwargs to use for inference. Returns ------- @@ -48,25 +52,25 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m pbar = tqdm(total=4) input_ = img[..., None] - sagittal_results = _run_onnx_model(ort, model_sagittal, input_, coords) + sagittal_results = _run_onnx_model(ort, model_sagittal, input_, coords, onnx_kwargs) pbar.update(1) input_ = np.swapaxes(img, 0, 1)[..., None] coronal_results = np.swapaxes( - _run_onnx_model(ort, model_coronal, input_, coords), 0, 1 + _run_onnx_model(ort, model_coronal, input_, coords, onnx_kwargs), 0, 1 ) pbar.update(1) input_ = np.transpose(img, (2, 0, 1))[..., None] axial_results = np.transpose( - _run_onnx_model(ort, model_axial, input_, coords), (1, 2, 0, 3) + _run_onnx_model(ort, model_axial, input_, coords, onnx_kwargs), (1, 2, 0, 3) ) pbar.update(1) X = np.concatenate( [img[..., None], sagittal_results, coronal_results, axial_results], -1 ) - sess = ort.InferenceSession(consensus_model, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession(consensus_model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name yhat = sess.run([output_name], {input_name: X[None, ...]})[0] @@ -77,8 +81,8 @@ def multiaxial(ort, img, model_sagittal, model_axial, model_coronal, consensus_m return pred -def _run_onnx_model(ort, model, input_, coords): - sess = ort.InferenceSession(model, providers=["CPUExecutionProvider"]) +def _run_onnx_model(ort, model, input_, coords, onnx_kwargs): + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name coord_name = sess.get_inputs()[1].name output_name = sess.get_outputs()[0].name @@ -131,7 +135,7 @@ def _get_multiaxial_model(): return model_dict -def run_multiaxial(ort, t1_img): +def run_multiaxial(ort, t1_img, onnx_kwargs): """ Run the multiaxial model. """ @@ -155,6 +159,7 @@ def run_multiaxial(ort, t1_img): model_dict["axial_model"], model_dict["coronal_model"], model_dict["consensus_model"], + onnx_kwargs, ) output_img = nbp.resample_from_to( diff --git a/AFQ/nn/synthseg.py b/AFQ/nn/synthseg.py index 6ea9d0c2..e67cfdf0 100644 --- a/AFQ/nn/synthseg.py +++ b/AFQ/nn/synthseg.py @@ -28,7 +28,7 @@ def _get_model(model_name): return model_fname -def run_synthseg(ort, t1_img, model_name): +def run_synthseg(ort, t1_img, model_name, onnx_kwargs): """ Run the Synthseg Model @@ -57,7 +57,7 @@ def run_synthseg(ort, t1_img, model_name): image = t1_data.astype(np.float32)[None, ..., None] logger.info(f"Running {model_name}...") - sess = ort.InferenceSession(model) + sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name output_channels = sess.run([output_name], {input_name: image})[0] diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index 6ce839e1..1f2e1e35 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -1,5 +1,4 @@ import logging -import multiprocessing import dipy.core.gradients as dpg import dipy.reconst.dki as dpy_dki @@ -16,7 +15,6 @@ from dipy.reconst.dki_micro import axonal_water_fraction from dipy.reconst.gqi import GeneralizedQSamplingModel from dipy.reconst.rumba import RumbaSDModel -from numba import get_num_threads import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd @@ -100,40 +98,6 @@ def get_data_gtab( return data, gtab, img, img.affine -@immlib.calc("n_cpus", "n_threads", "low_mem") -def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=False): - """ - Configure the number of CPUs to use for parallel processing with Ray, - the number of threads to use for Numba, - and whether to use low-memory versions of algorithms - where available - - Parameters - ---------- - ray_n_cpus : int, optional - The number of CPUs to use for parallel processing with Ray. - If None, uses the number of available CPUs minus one. - Tractography, Recognition, and MSMT use Ray. - Default: None - numba_n_threads : int, optional - The number of threads to use for Numba. - If None, uses the number of available CPUs minus one, - but with a maximum of 16. - ASYM fit uses Numba. - Default: None - low_memory : bool, optional - Whether to use low-memory versions of algorithms - where available. - Default: False - """ - if ray_n_cpus is None: - ray_n_cpus = max(multiprocessing.cpu_count() - 1, 1) - if numba_n_threads is None: - numba_n_threads = min(max(get_num_threads() - 1, 1), 16) - - return ray_n_cpus, numba_n_threads, low_memory - - @immlib.calc("b0") @as_file("_b0ref.nii.gz") @as_img @@ -519,7 +483,7 @@ def csd_params( @immlib.calc("csd_aodf_params") @as_file(suffix="_model-csd_param-aodf_dwimap.nii.gz", subfolder="models") @as_img -def csd_aodf(csd_params, n_threads, low_mem, citations): +def csd_aodf(structural_imap, csd_params, citations): """ full path to a nifti file containing SSST CSD ODFs filtered by unified filtering [1] @@ -536,7 +500,10 @@ def csd_aodf(csd_params, n_threads, low_mem, citations): logger.info("Applying unified filtering to generate asymmetric CSD ODFs...") aodf = unified_filtering( - sh_coeff, get_sphere(name="repulsion724"), n_threads=n_threads, low_mem=low_mem + sh_coeff, + get_sphere(name="repulsion724"), + n_threads=structural_imap["n_threads"], + low_mem=structural_imap["low_mem"], ) return aodf, dict( @@ -1414,7 +1381,6 @@ def get_data_plan(kwargs): b0, b0_mask, brain_mask, - configure_ncpus_nthreads, dti_fit, dki_fit, fwdti_fit, diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 9da73635..5f0c4e86 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -41,7 +41,9 @@ @immlib.calc("bundles") @as_file("_desc-bundles_tractography") -def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): +def segment( + structural_imap, data_imap, mapping_imap, tractography_imap, segmentation_params +): """ full path to a trk/trx file containing containing segmented streamlines, labeled by bundle @@ -106,7 +108,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): mapping_imap["mapping"], bundle_dict, reg_template, - data_imap["n_cpus"], + structural_imap["n_cpus"], **segmentation_params, ) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 302d3984..4230c6f4 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -1,7 +1,9 @@ import logging +import multiprocessing import immlib import nibabel as nib +from numba import get_num_threads from AFQ.definitions.utils import Definition from AFQ.nn.brainchop import run_brainchop @@ -13,9 +15,90 @@ logger = logging.getLogger("AFQ") +@immlib.calc("n_cpus", "n_threads", "low_mem") +def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=False): + """ + Configure the number of CPUs to use for parallel processing with Ray, + the number of threads to use for Numba, + and whether to use low-memory versions of algorithms + where available + + Parameters + ---------- + ray_n_cpus : int, optional + The number of CPUs to use for parallel processing with Ray. + If None, uses the number of available CPUs minus one. + Tractography, Recognition, and MSMT use Ray. + Default: None + numba_n_threads : int, optional + The number of threads to use for Numba. + If None, uses the number of available CPUs minus one, + but with a maximum of 16. + ASYM fit uses Numba. + Default: None + low_memory : bool, optional + Whether to use low-memory versions of algorithms + where available. + Default: False + """ + if ray_n_cpus is None: + ray_n_cpus = max(multiprocessing.cpu_count() - 1, 1) + if numba_n_threads is None: + numba_n_threads = min(max(get_num_threads() - 1, 1), 16) + + return ray_n_cpus, numba_n_threads, low_memory + + +@immlib.calc("onnx_kwargs") +def onnx_kwargs(low_mem, onnx_execution_provider="CPUExecutionProvider"): + """ + The execution provider to use for onnx models + + Parameters + ---------- + onnx_execution_provider : str, optional + The execution provider to use for onnx models. + By default this is set to CPUExecutionProvider + which should work on all systems. If you have a + compatible GPU and the appropriate onnxruntime installed + you can set this to "CUDAExecutionProvider" or + "OpenVINOExecutionProvider" for potentially faster + inference. + Default: "CPUExecutionProvider" + + Returns + ------- + str + The ONNX execution provider to use for onnx models. + """ + try: + import onnxruntime as ort + except ImportError: + # In this case, we can throw a more informative error + # when the user tries to run a model + # that requires onnxruntime + return onnx_execution_provider + if onnx_execution_provider not in ort.get_available_providers(): + logger.warning( + f"{onnx_execution_provider} is not available. " + f"Available providers are: {ort.get_available_providers()}. " + "Falling back to CPUExecutionProvider." + ) + onnx_execution_provider = "CPUExecutionProvider" + options = ort.SessionOptions() + if low_mem: + options.add_session_config_entry("session.use_mem_arena", "0") + options.enable_mem_pattern = False + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + onnx_kwargs = {"providers": [onnx_execution_provider], "options": options} + + return {"onnx_kwargs": onnx_kwargs} + + @immlib.calc("synthseg_model") @as_file(suffix="_model-synthseg2_probseg.nii.gz", subfolder="nn") -def synthseg_model(t1_masked, citations): +def synthseg_model(t1_masked, citations, onnx_kwargs): """ full path to the synthseg2 model segmentations @@ -36,13 +119,13 @@ def synthseg_model(t1_masked, citations): "Or, provide your own segmentations using PVEImage or PVEImages.", ) t1_img = nib.load(t1_masked) - predictions = run_synthseg(ort, t1_img, "synthseg2") + predictions = run_synthseg(ort, t1_img, "synthseg2", onnx_kwargs) return predictions, dict(T1w=t1_masked) @immlib.calc("mx_model") @as_file(suffix="_model-multiaxial_probseg.nii.gz", subfolder="nn") -def mx_model(t1_file, t1w_brain_mask, citations): +def mx_model(t1_file, t1w_brain_mask, citations, onnx_kwargs): """ full path to the multi-axial model for brain extraction outputs @@ -59,7 +142,7 @@ def mx_model(t1_file, t1w_brain_mask, citations): ) t1_img = nib.load(t1_file) t1_mask = nib.load(t1w_brain_mask) - predictions = run_multiaxial(ort, t1_img) + predictions = run_multiaxial(ort, t1_img, onnx_kwargs) predictions = nib.Nifti1Image( predictions.get_fdata() * t1_mask.get_fdata(), t1_img.affine ) @@ -68,7 +151,7 @@ def mx_model(t1_file, t1w_brain_mask, citations): @immlib.calc("t1w_brain_mask") @as_file(suffix="_desc-T1w_mask.nii.gz") -def t1w_brain_mask(t1_file, citations, brain_mask_definition=None): +def t1w_brain_mask(t1_file, citations, onnx_kwargs, brain_mask_definition=None): """ full path to a nifti file containing brain mask from T1w image @@ -98,7 +181,7 @@ def t1w_brain_mask(t1_file, citations, brain_mask_definition=None): ort = check_onnxruntime( "Mindgrab", "Or, provide your own brain mask using brain_mask_definition." ) - return run_brainchop(ort, nib.load(t1_file), "mindgrab"), dict( + return run_brainchop(ort, nib.load(t1_file), "mindgrab", onnx_kwargs), dict( T1w=t1_file, model="mindgrab" ) @@ -119,7 +202,7 @@ def t1_masked(t1_file, t1w_brain_mask): @immlib.calc("t1_subcortex") @as_file(suffix="_desc-subcortex_probseg.nii.gz", subfolder="nn") -def t1_subcortex(t1_masked, citations): +def t1_subcortex(t1_masked, citations, onnx_kwargs): """ full path to a nifti file containing segmentation of subcortical structures from T1w image using Brainchop @@ -140,7 +223,7 @@ def t1_subcortex(t1_masked, citations): t1_img_masked = nib.load(t1_masked) - subcortical_img = run_brainchop(ort, t1_img_masked, "subcortical") + subcortical_img = run_brainchop(ort, t1_img_masked, "subcortical", onnx_kwargs) meta = dict( T1w=t1_masked, @@ -172,7 +255,15 @@ def t1_subcortex(t1_masked, citations): def get_structural_plan(kwargs): structural_tasks = with_name( - [mx_model, synthseg_model, t1w_brain_mask, t1_subcortex, t1_masked] + [ + mx_model, + synthseg_model, + t1w_brain_mask, + t1_subcortex, + t1_masked, + onnx_kwargs, + configure_ncpus_nthreads, + ] ) bm_def = kwargs.get("brain_mask_definition", None) diff --git a/AFQ/tasks/tissue.py b/AFQ/tasks/tissue.py index c890fa7d..3b3b5bd3 100644 --- a/AFQ/tasks/tissue.py +++ b/AFQ/tasks/tissue.py @@ -127,7 +127,14 @@ def pve_internal(structural_imap, pve="synthseg"): subfolder="models", ) @as_img -def msmt_params(data_imap, pve_internal, citations, msmt_sh_order=8, msmt_fa_thr=0.7): +def msmt_params( + structural_imap, + data_imap, + pve_internal, + citations, + msmt_sh_order=8, + msmt_fa_thr=0.7, +): """ full path to a nifti file containing parameters for the MSMT CSD white matter fit, @@ -201,7 +208,7 @@ def msmt_params(data_imap, pve_internal, citations, msmt_sh_order=8, msmt_fa_thr mcsd_model = MultiShellDeconvModel(data_imap["gtab"], response_mcsd) logger.info("Fitting Multi-Shell CSD model...") - mcsd_fit = mcsd_model.fit(data_imap["data"], mask, n_cpus=data_imap["n_cpus"]) + mcsd_fit = mcsd_model.fit(data_imap["data"], mask, n_cpus=structural_imap["n_cpus"]) def _get_meta(desc, sh_order, response): return dict( @@ -268,7 +275,7 @@ def msmt_apm(msmtcsd_params): @immlib.calc("msmt_aodf_params") @as_file(suffix="_model-msmtcsd_param-aodf_dwimap.nii.gz", subfolder="models") @as_img -def msmt_aodf(msmtcsd_params, data_imap, citations): +def msmt_aodf(msmtcsd_params, structural_imap, citations): """ full path to a nifti file containing MSMT CSD ODFs filtered by unified filtering [1] @@ -288,8 +295,8 @@ def msmt_aodf(msmtcsd_params, data_imap, citations): aodf = unified_filtering( sh_coeff, get_sphere(name="repulsion724"), - n_threads=data_imap["n_threads"], - low_mem=data_imap["low_mem"], + n_threads=structural_imap["n_threads"], + low_mem=structural_imap["low_mem"], ) return aodf, dict( diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 91502793..3093cdc1 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -60,7 +60,9 @@ def _meta_from_tracking_params(tracking_params, start_time, n_streamlines, seed, @immlib.calc("streamlines") @as_file("_tractography", subfolder="tractography") -def streamlines(data_imap, seed, tissue_imap, fodf, citations, tracking_params): +def streamlines( + structural_imap, data_imap, seed, tissue_imap, fodf, citations, tracking_params +): """ full path to the complete, unsegmented tractography file @@ -84,7 +86,7 @@ def streamlines(data_imap, seed, tissue_imap, fodf, citations, tracking_params): is_trx = this_tracking_params.get("trx", False) - num_chunks = data_imap["n_cpus"] + num_chunks = structural_imap["n_cpus"] if is_trx: start_time = time() diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index e57eb393..3cafc325 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -60,7 +60,11 @@ def check_onnxruntime(model_name, alternative_text): f"When we tried to import onnxruntime, we got the " f"following error:\n{e}\n" f"Please install onnxruntime to use this feature, " - f"by doing `pip install onnxruntime` or `pip install pyAFQ[nn]`. " + f"by doing `pip install onnxruntime` or " + f"`pip install onnxruntime-gpu` if you have a compatible GPU or " + f"`pip install pyAFQ[nn]` or " + "`pip install pyAFQ[gpu]` if you have a compatible GPU " + "and want GPUStreamlines. " f"{alternative_text}\n" "If there are still issues, post an issue on " "https://github.com/tractometry/pyAFQ/issues" diff --git a/setup.cfg b/setup.cfg index 6d66c246..eb0dd268 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,6 +95,7 @@ nn = onnxruntime gpu = cuslines==2.0.0 + onnxruntime-gpu all = %(dev)s %(fury)s From e1a0a44a7c155821b8e89afd5fd63c2086064e32 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 14:20:05 +0900 Subject: [PATCH 71/86] bf --- AFQ/tasks/structural.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 4230c6f4..c74cdb7c 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -87,11 +87,12 @@ def onnx_kwargs(low_mem, onnx_execution_provider="CPUExecutionProvider"): onnx_execution_provider = "CPUExecutionProvider" options = ort.SessionOptions() if low_mem: - options.add_session_config_entry("session.use_mem_arena", "0") + options.enable_cpu_mem_arena = False options.enable_mem_pattern = False options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + options.inter_op_num_threads = 1 - onnx_kwargs = {"providers": [onnx_execution_provider], "options": options} + onnx_kwargs = {"providers": [onnx_execution_provider], "sess_options": options} return {"onnx_kwargs": onnx_kwargs} From 73fa4a26e7fba2e30ddde6f7b62ea87b0a195933 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 14:56:03 +0900 Subject: [PATCH 72/86] dont force dwi load --- AFQ/api/participant.py | 3 --- AFQ/tasks/decorators.py | 24 ++++++++++++++++++++---- AFQ/tasks/segmentation.py | 10 +++++++--- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 8ad8d89c..dc0fb525 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -165,9 +165,6 @@ def make_workflow(self): plan_kwargs[key] = self.kwargs[key] elif key in previous_plans: plan_kwargs[key] = previous_plans[key] - elif name not in ["data", "structural"] and key == "dwi_affine": - # simplifies syntax to access commonly used dwi_affine - plan_kwargs[key] = previous_plans["data_imap"][key] else: raise NotImplementedError( f"Missing required parameter {key} for {name} plan" diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index 3c041287..689e1426 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -163,13 +163,21 @@ def as_fit_deriv(tf_name): """ def _as_fit_deriv(func): + module_name = func.__module__ + is_data_module = "data" in module_name + dependency = "dwi_affine" if is_data_module else "data_imap" + new_signature, new_params = get_new_signature( - func, ["dwi_affine", f"{tf_name.lower()}_params"] + func, [dependency, f"{tf_name.lower()}_params"] ) @functools.wraps(func) def wrapper_as_fit_deriv(*args, **kwargs): - dwi_affine = get_param(kwargs, new_params, "dwi_affine") + if is_data_module: + dwi_affine = get_param(kwargs, new_params, "dwi_affine") + else: + data_imap = get_param(kwargs, new_params, "data_imap") + dwi_affine = data_imap["dwi_affine"] params = get_param(kwargs, new_params, f"{tf_name.lower()}_params") params_meta = read_json(drop_extension(params) + ".json") img_meta = {} @@ -202,11 +210,19 @@ def as_img(func): Decorator to convert function output (ndarray, meta) into (Nifti1Image, meta). Supports functions returning a single tuple or a list of tuples. """ - new_signature, new_params = get_new_signature(func, ["dwi_affine"]) + module_name = func.__module__ + is_data_module = "data" in module_name + dependency = "dwi_affine" if is_data_module else "data_imap" + + new_signature, new_params = get_new_signature(func, [dependency]) @functools.wraps(func) def wrapper_as_img(*args, **kwargs): - dwi_affine = get_param(kwargs, new_params, "dwi_affine") + if is_data_module: + dwi_affine = get_param(kwargs, new_params, "dwi_affine") + else: + data_imap = get_param(kwargs, new_params, "data_imap") + dwi_affine = data_imap["dwi_affine"] start_time = time() results = func(*args, **kwargs) diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 5f0c4e86..5e577082 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -269,7 +269,7 @@ def export_density_maps(bundles, data_imap): @immlib.calc("profiles") @as_file("_desc-profiles_tractography.csv") def tract_profiles( - bundles, scalar_dict, dwi_affine, profile_weights="gauss", n_points_profile=100 + bundles, scalar_dict, data_imap, profile_weights="gauss", n_points_profile=100 ): """ full path to a CSV file containing tract profiles @@ -338,7 +338,11 @@ def tract_profiles( def _median_weight(bundle): fgarray = set_number_of_points(bundle, n_points_profile) values = np.array( - values_from_volume(scalar_data, fgarray, dwi_affine) # noqa B023 + values_from_volume( + scalar_data, # noqa B023 + fgarray, + data_imap["dwi_affine"], + ) ) weights = np.zeros(values.shape) for ii, jj in enumerate( @@ -366,7 +370,7 @@ def _median_weight(bundle): this_profile[ii] = afq_profile( scalar_data, this_sl, - dwi_affine, + data_imap["dwi_affine"], weights=this_prof_weights, n_points=n_points_profile, ) From 9e78c55a35e8cdd32068c51abd4d00258e2b53b5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 15:02:20 +0900 Subject: [PATCH 73/86] odf calc is helper function not part of plan --- AFQ/tasks/tractography.py | 44 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 3093cdc1..681bf966 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -58,10 +58,26 @@ def _meta_from_tracking_params(tracking_params, start_time, n_streamlines, seed, return meta +def _fiber_odf(data_imap, tissue_imap, tracking_params): + odf_model = tracking_params["odf_model"] + if isinstance(odf_model, str): + calc_name = f"{odf_model.lower()}_params" + if calc_name in data_imap: + params_file = data_imap[calc_name] + elif calc_name in tissue_imap: + params_file = tissue_imap[calc_name] + else: + raise ValueError((f"Could not find {odf_model}")) + else: + raise TypeError(("odf_model must be a string or Definition")) + + return params_file + + @immlib.calc("streamlines") @as_file("_tractography", subfolder="tractography") def streamlines( - structural_imap, data_imap, seed, tissue_imap, fodf, citations, tracking_params + structural_imap, data_imap, seed, tissue_imap, citations, tracking_params ): """ full path to the complete, unsegmented tractography file @@ -79,6 +95,7 @@ def streamlines( citations.add("smith2012anatomically") this_tracking_params = tracking_params.copy() + fodf = _fiber_odf(data_imap, tissue_imap, tracking_params) # get masks this_tracking_params["seed_mask"] = nib.load(seed).get_fdata() @@ -208,26 +225,6 @@ def delete_lazyt(self, id): ) -@immlib.calc("fodf") -def fiber_odf(data_imap, tissue_imap, tracking_params): - """ - Nifti Image containing the fiber orientation distribution function - """ - odf_model = tracking_params["odf_model"] - if isinstance(odf_model, str): - calc_name = f"{odf_model.lower()}_params" - if calc_name in data_imap: - params_file = data_imap[calc_name] - elif calc_name in tissue_imap: - params_file = tissue_imap[calc_name] - else: - raise ValueError((f"Could not find {odf_model}")) - else: - raise TypeError(("odf_model must be a string or Definition")) - - return params_file - - @immlib.calc("streamlines") def custom_tractography(import_tract=None): """ @@ -251,7 +248,6 @@ def custom_tractography(import_tract=None): def gpu_tractography( data_imap, tracking_params, - fodf, seed, tissue_imap, tractography_ngpus=0, @@ -274,6 +270,8 @@ def gpu_tractography( Default: 25000 """ start_time = time() + fodf = _fiber_odf(data_imap, tissue_imap, tracking_params) + if tracking_params["directions"] == "boot": data = data_imap["data"] else: @@ -316,7 +314,7 @@ def get_tractography_plan(kwargs): if "tracking_params" in kwargs and not isinstance(kwargs["tracking_params"], dict): raise TypeError("tracking_params a dict") - tractography_tasks = with_name([streamlines, fiber_odf]) + tractography_tasks = with_name([streamlines]) # use GPU accelerated tractography if asked for if "tractography_ngpus" in kwargs and kwargs["tractography_ngpus"] != 0: From 7454a4e57ea60768f604dcb20bc3e0f378d0ab9c Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 15:27:36 +0900 Subject: [PATCH 74/86] make this default --- AFQ/tasks/structural.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index c74cdb7c..3a563555 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -28,7 +28,7 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F ray_n_cpus : int, optional The number of CPUs to use for parallel processing with Ray. If None, uses the number of available CPUs minus one. - Tractography, Recognition, and MSMT use Ray. + Tractography and MSMT use Ray. Default: None numba_n_threads : int, optional The number of threads to use for Numba. @@ -50,7 +50,9 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F @immlib.calc("onnx_kwargs") -def onnx_kwargs(low_mem, onnx_execution_provider="CPUExecutionProvider"): +def onnx_kwargs( + low_mem, onnx_execution_provider="CPUExecutionProvider", onnx_inter_threads=1 +): """ The execution provider to use for onnx models @@ -65,11 +67,11 @@ def onnx_kwargs(low_mem, onnx_execution_provider="CPUExecutionProvider"): "OpenVINOExecutionProvider" for potentially faster inference. Default: "CPUExecutionProvider" + onnx_inter_threads : int, optional + The number of inter threads to use for onnx models. + Increasing will increase memory usage significantly. + Default: 1 - Returns - ------- - str - The ONNX execution provider to use for onnx models. """ try: import onnxruntime as ort @@ -86,11 +88,11 @@ def onnx_kwargs(low_mem, onnx_execution_provider="CPUExecutionProvider"): ) onnx_execution_provider = "CPUExecutionProvider" options = ort.SessionOptions() + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + options.inter_op_num_threads = onnx_inter_threads if low_mem: options.enable_cpu_mem_arena = False options.enable_mem_pattern = False - options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - options.inter_op_num_threads = 1 onnx_kwargs = {"providers": [onnx_execution_provider], "sess_options": options} From 2a0799892e06b471c8224e79fb5d917f4f374c77 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 27 Feb 2026 15:35:44 +0900 Subject: [PATCH 75/86] test update --- AFQ/tests/test_nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/tests/test_nn.py b/AFQ/tests/test_nn.py index 0d999194..460782b8 100644 --- a/AFQ/tests/test_nn.py +++ b/AFQ/tests/test_nn.py @@ -30,7 +30,7 @@ def test_run_brainchop(): "sub-01/ses-01/anat/sub-01_ses-01_T1w.nii.gz" ), ) - chopped_brain = run_brainchop(ort, nib.load(t1_path), "mindgrab") + chopped_brain = run_brainchop(ort, nib.load(t1_path), "mindgrab", {}) npt.assert_(chopped_brain.get_fdata().sum() > 200000) @@ -47,6 +47,6 @@ def test_run_multiaxial(): "sub-01/ses-01/anat/sub-01_ses-01_T1w.nii.gz" ), ) - chopped_brain = run_multiaxial(ort, nib.load(t1_path)) + chopped_brain = run_multiaxial(ort, nib.load(t1_path), {}) npt.assert_(chopped_brain.get_fdata().sum() > 200000) From 6f044b3c8c999279cefae6e4179bf98fa83c6b17 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Mar 2026 15:16:45 +0900 Subject: [PATCH 76/86] refine paf, begin mdlf --- AFQ/api/bundle_dict.py | 140 +++++++++++++++++++--------- AFQ/data/fetch.py | 30 ++++++ AFQ/recognition/cleaning.py | 57 ++++++----- AFQ/recognition/criteria.py | 24 +++-- AFQ/recognition/tests/test_utils.py | 21 +++-- docs/source/references.bib | 22 +++++ 6 files changed, 214 insertions(+), 80 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 3bda7334..e8ddbf27 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -119,13 +119,45 @@ def append_l_r(bundle_list, no_lr_list): DIPY_GH = "https://github.com/dipy/dipy/blob/master/dipy/" +def OR_bd(): + or_rois = afd.read_or_templates() + + return BundleDict( + { + "Left Optic Radiation": { + "include": [or_rois["left_OR_1"], or_rois["left_OR_2"]], + "exclude": [ + or_rois["left_OP_MNI"], + or_rois["left_TP_MNI"], + or_rois["left_pos_thal_MNI"], + ], + "start": or_rois["left_thal_MNI"], + "end": or_rois["left_V1_MNI"], + "cross_midline": False, + }, + "Right Optic Radiation": { + "include": [or_rois["right_OR_1"], or_rois["right_OR_2"]], + "exclude": [ + or_rois["right_OP_MNI"], + or_rois["right_TP_MNI"], + or_rois["right_pos_thal_MNI"], + ], + "start": or_rois["right_thal_MNI"], + "end": or_rois["right_V1_MNI"], + "cross_midline": False, + }, + }, + citations={"Caffarra2021"}, + ) + + def default_bd(): templates = afd.read_templates(as_img=False) templates["ARC_roi1_L"] = templates["SLF_roi1_L"] templates["ARC_roi1_R"] = templates["SLF_roi1_R"] templates["ARC_roi2_L"] = templates["SLFt_roi2_L"] templates["ARC_roi2_R"] = templates["SLFt_roi2_R"] - return BundleDict( + return OR_bd() + BundleDict( { "Left Anterior Thalamic": { "cross_midline": False, @@ -270,11 +302,12 @@ def default_bd(): templates["pARC_xroi1_L"], ], "space": "template", - "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], - "end": templates["VOF_L_end"], + "end": templates["pARC_L_end"], "Left Arcuate": {"overlap": 30}, "Left Inferior Fronto-occipital": {"core": "Right"}, + "Left Optic Radiation": {"core": "Right"}, + "endpoints_exact": True, "length": {"min_len": 30}, "primary_axis": "I/S", }, @@ -287,11 +320,12 @@ def default_bd(): templates["pARC_xroi1_R"], ], "space": "template", - "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], - "end": templates["VOF_R_end"], + "end": templates["pARC_R_end"], "Right Arcuate": {"overlap": 30}, "Right Inferior Fronto-occipital": {"core": "Left"}, + "Right Optic Radiation": {"core": "Left"}, + "endpoints_exact": True, "length": {"min_len": 30}, "primary_axis": "I/S", }, @@ -300,6 +334,7 @@ def default_bd(): "space": "template", "prob_map": templates["VOF_L_prob_map"], "end": templates["VOF_L_end"], + "include": [templates["VOF_roi1_L"], templates["VOF_roi2_L"]], "exclude": [ templates["Cerebellar_Hemi_L"], ], @@ -329,14 +364,24 @@ def default_bd(): }, "Left Posterior Vertical Occipital": { "Left Inferior Fronto-occipital": {"core": "Right"}, + "Left Optic Radiation": {"core": "Right"}, "cluster_IDs": [1, 72, 75, 81, 83], - "isolation_forest": {}, + "mahal": { + "distance_threshold": 5, + "length_threshold": 4, + "clean_rounds": 5, + }, }, "Left Anterior Vertical Occipital": { "Left Inferior Fronto-occipital": {"core": "Right"}, + "Left Optic Radiation": {"core": "Right"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_L"]], - "isolation_forest": {}, + "mahal": { + "distance_threshold": 5, + "length_threshold": 4, + "clean_rounds": 5, + }, }, }, remove_cluster_IDs=[ @@ -394,6 +439,7 @@ def default_bd(): "space": "template", "prob_map": templates["VOF_R_prob_map"], "end": templates["VOF_R_end"], + "include": [templates["VOF_roi1_R"], templates["VOF_roi2_R"]], "exclude": [ templates["Cerebellar_Hemi_R"], ], @@ -423,14 +469,24 @@ def default_bd(): }, "Right Posterior Vertical Occipital": { "Right Inferior Fronto-occipital": {"core": "Left"}, + "Right Optic Radiation": {"core": "Left"}, "cluster_IDs": [1, 72, 75, 81, 83], - "isolation_forest": {}, + "mahal": { + "distance_threshold": 5, + "length_threshold": 4, + "clean_rounds": 5, + }, }, "Right Anterior Vertical Occipital": { "Right Inferior Fronto-occipital": {"core": "Left"}, + "Right Optic Radiation": {"core": "Left"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [templates["pARC_xroi1_R"]], - "isolation_forest": {}, + "mahal": { + "distance_threshold": 5, + "length_threshold": 4, + "clean_rounds": 5, + }, }, }, remove_cluster_IDs=[ @@ -486,7 +542,7 @@ def default_bd(): }, citations={ "Yeatman2012", - "takemura2017occipital", + "takemura2016major", "Tzourio-Mazoyer2002", "zhang2018anatomically", "Hua2008", @@ -494,6 +550,38 @@ def default_bd(): ) +def mdlf_bd(): + """ + Work in Progress. + """ + templates = afd.read_templates(as_img=False) + return default_bd() + BundleDict( + { + "Left Middle Longitudinal": { + "cross_midline": False, + "start": templates["Temporal_Sup_L"], + "end": templates["MdLF_L_end"], + "exclude": [templates["SLF_roi1_L"]], + "space": "template", + "Left Inferior Longitudinal": {"node_thresh": 20}, + "length": {"min_len": 50}, + }, + "Right Middle Longitudinal": { + "cross_midline": False, + "start": templates["Temporal_Sup_R"], + "end": templates["MdLF_R_end"], + "exclude": [templates["SLF_roi1_R"]], + "space": "template", + "Right Inferior Longitudinal": {"node_thresh": 20}, + "length": {"min_len": 50}, + }, + }, + citations={ + "wang2013rethinking", + }, + ) + + def slf_bd(): templates = afd.read_slf_templates(as_img=False) templates_afq = afd.read_templates(as_img=False) @@ -1080,38 +1168,6 @@ def cerebellar_bd(): ) -def OR_bd(): - or_rois = afd.read_or_templates() - - return BundleDict( - { - "Left Optic Radiation": { - "include": [or_rois["left_OR_1"], or_rois["left_OR_2"]], - "exclude": [ - or_rois["left_OP_MNI"], - or_rois["left_TP_MNI"], - or_rois["left_pos_thal_MNI"], - ], - "start": or_rois["left_thal_MNI"], - "end": or_rois["left_V1_MNI"], - "cross_midline": False, - }, - "Right Optic Radiation": { - "include": [or_rois["right_OR_1"], or_rois["right_OR_2"]], - "exclude": [ - or_rois["right_OP_MNI"], - or_rois["right_TP_MNI"], - or_rois["right_pos_thal_MNI"], - ], - "start": or_rois["right_thal_MNI"], - "end": or_rois["right_V1_MNI"], - "cross_midline": False, - }, - }, - citations={"Caffarra2021"}, - ) - - class _BundleEntry(Mapping): """Describes how to recognize a single bundle, immutable""" diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index f4fa089f..7fe6d433 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -763,6 +763,16 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "VOF_R_end.nii.gz", "VOF_L_prob_map.nii.gz", "VOF_R_prob_map.nii.gz", + "VOF_roi1_L.nii.gz", + "VOF_roi1_R.nii.gz", + "VOF_roi2_L.nii.gz", + "VOF_roi2_R.nii.gz", + "Temporal_Sup_L.nii.gz", + "Temporal_Sup_R.nii.gz", + "pARC_L_end.nii.gz", + "pARC_R_end.nii.gz", + "MdLF_L_end.nii.gz", + "MdLF_R_end.nii.gz", ] @@ -869,6 +879,16 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "62134581", "62031442", "62031445", + "62213933", + "62213936", + "62213939", + "62213942", + "62282968", + "62282971", + "62283235", + "62283238", + "62283226", + "62283229", ] @@ -976,6 +996,16 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "84df5abfefbed5e3e310f2db0b62fcea", "db5bd2d1e810e366f5ef67a9cce205c2", "6891cfc038ce7db21e0cc307ae2b1b37", + "ad5407fa6c058c9317a5ba51e5e188bf", + "5c52e20d74608da784ee874d23322385", + "76baf26294c8430afabcdc9a6d756b12", + "eff69ba30619bbf2ff500cd714f07894", + "088e850d38fed2d62fc768071d42a43d", + "a3d249535ca0452ce9f59c754e13695a", + "2cba992046cd80d1b9bdc1de3e101214", + "2929851e9611dba80e8c17d2bd767b1d", + "827b21f9069cd0192ede3f7153f1ca80", + "e2912622d36db6723c48a0a1de887807", ] fetch_templates = _make_reusable_fetcher( diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 2b1e707a..0f3d68f3 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -12,15 +12,24 @@ logger = logging.getLogger("AFQ") -def clean_by_orientation(streamlines, primary_axis, affine, tol=None): +def clean_by_orientation(streamlines, primary_axis, core_only=0.6): """ - Compute the cardinal orientation of each streamline + Retain streamlines whose core is oriented along the primary axis + and have endpoints that are also oriented along the primary axis + and have a majority of their steps along the primary axis. Parameters ---------- streamlines : sequence of N by 3 arrays Where N is number of nodes in the array, the collection of streamlines to filter down to. + core_only : float, optional + If non-zero, only the core of the bundle is used for cleaning. + The core is defined as the middle 60% of each streamline, + thus our default is 0.6. This means streamlines are allowed to + deviate in the starting and ending 20% of the bundle. This is useful + for allowing more diverse endpoints. + Default: 0.6 Returns ------- @@ -34,31 +43,35 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): primary_axis = abu.axes_dict[primary_axis] - axis_diff = np.zeros((len(streamlines), 3)) + core_accepted_idx = np.zeros(len(streamlines), dtype=bool) + if core_only != 0: + crop_edge = (1.0 - core_only) / 2 + for ii, sl in enumerate(streamlines): + n_points = len(sl) + along_diff = np.abs( + np.diff( + sl[int(n_points * crop_edge) : int(n_points * (1 - crop_edge))], + axis=0, + ) + ) + # The majority of steps must be in the primary axis direction: + core_accepted_idx[ii] = np.sum( + np.argmax(along_diff, axis=1) == primary_axis + ) > (len(along_diff) / 2) + endpoint_diff = np.zeros((len(streamlines), 3)) + along_diff = np.zeros((len(streamlines), 3)) for ii, sl in enumerate(streamlines): - # endpoint diff is between first and last endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :]) - # axis diff is difference between the nodes, along - axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0) - - orientation_along = np.argmax(axis_diff, axis=1) - along_accepted_idx = orientation_along == primary_axis - if tol is not None: - percentage_primary = ( - 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) - ) - logger.debug( - (f"Maximum primary percentage found: {np.max(percentage_primary)}") - ) - along_accepted_idx = np.logical_and( - along_accepted_idx, percentage_primary > tol - ) - + along_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0) orientation_end = np.argmax(endpoint_diff, axis=1) + orientation_along = np.argmax(along_diff, axis=1) end_accepted_idx = orientation_end == primary_axis + along_accepted_idx = orientation_along == primary_axis - cleaned_idx = np.logical_and(along_accepted_idx, end_accepted_idx) + cleaned_idx = np.logical_and( + along_accepted_idx, np.logical_and(end_accepted_idx, core_accepted_idx) + ) return cleaned_idx @@ -162,7 +175,7 @@ def clean_bundle( calculated. Default: `np.mean` (but can also use median, etc.) core_only : float, optional If non-zero, only the core of the bundle is used for cleaning. - The core is commonly defined as the middle 60% of each streamline, + The core is defined as the middle 60% of each streamline, thus our default is 0.6. This means streamlines are allowed to deviate in the starting and ending 20% of the bundle. This is useful for allowing more diverse endpoints. diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c2b281ff..e250f5e8 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -46,6 +46,7 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", + "endpoints_exact", "ORG_spectral_subbundles", "cluster_IDs", "startpoint_location", @@ -77,11 +78,17 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("Startpoint") + endpoints_exact = bundle_def.get("endpoints_exact", False) + if endpoints_exact: + tol = 0 + else: + tol = preproc_imap["dist_to_atlas"] + accept_idx = abr.clean_by_endpoints( preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], 0, - tol=preproc_imap["dist_to_atlas"], + tol=tol, flip_sls=b_sls.sls_flipped, ) if not b_sls.oriented_yet: @@ -89,7 +96,7 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], -1, - tol=preproc_imap["dist_to_atlas"], + tol=tol, ) new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) special_idx = np.logical_and(accept_idx, accepted_idx_flipped) @@ -105,11 +112,17 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): def end(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("endpoint") + endpoints_exact = bundle_def.get("endpoints_exact", False) + if endpoints_exact: + tol = 0 + else: + tol = preproc_imap["dist_to_atlas"] + accept_idx = abr.clean_by_endpoints( preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], -1, - tol=preproc_imap["dist_to_atlas"], + tol=tol, flip_sls=b_sls.sls_flipped, ) if not b_sls.oriented_yet: @@ -117,7 +130,7 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], 0, - tol=preproc_imap["dist_to_atlas"], + tol=tol, ) new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) special_idx = np.logical_and(accept_idx, accepted_idx_flipped) @@ -148,8 +161,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): accept_idx = abc.clean_by_orientation( b_sls.get_selected_sls(), bundle_def["primary_axis"], - img.affine, - bundle_def.get("primary_axis_percentage", None), + bundle_def.get("core_only", 0.6), ) b_sls.select(accept_idx, "orientation") diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index 4f320de6..dec1d4a0 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -82,21 +82,22 @@ def test_segment_clip_edges(): def test_segment_orientation(): cleaned_idx = abc.clean_by_orientation( - streamlines, primary_axis="P/A", affine=np.eye(4) + streamlines, + primary_axis="P/A", ) - npt.assert_equal(np.sum(cleaned_idx), 93) - cleaned_idx_tol = abc.clean_by_orientation( - streamlines, primary_axis="P/A", affine=np.eye(4), tol=50 - ) - npt.assert_(np.sum(cleaned_idx_tol) < np.sum(cleaned_idx)) + npt.assert_equal(np.sum(cleaned_idx), 46) cleaned_idx = abc.clean_by_orientation( - streamlines, primary_axis="I/S", affine=np.eye(4) + streamlines, + primary_axis="I/S", ) - cleaned_idx_tol = abc.clean_by_orientation( - streamlines, primary_axis="I/S", affine=np.eye(4), tol=33 + npt.assert_equal(np.sum(cleaned_idx), 38) + + cleaned_idx = abc.clean_by_orientation( + streamlines, + primary_axis="L/R", ) - npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) + npt.assert_equal(np.sum(cleaned_idx), 438) def test_clean_isolation_forest_basic(): diff --git a/docs/source/references.bib b/docs/source/references.bib index 0588b036..58a778bb 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -222,6 +222,28 @@ @article{takemura2017occipital publisher={Oxford University Press} } +@article{takemura2016major, + title={A major human white matter pathway between dorsal and ventral visual cortex}, + author={Takemura, Hiromasa and Rokem, Ariel and Winawer, Jonathan and Yeatman, Jason D and Wandell, Brian A and Pestilli, Franco}, + journal={Cerebral cortex}, + volume={26}, + number={5}, + pages={2205--2214}, + year={2016}, + publisher={Oxford University Press} +} + +@article{wang2013rethinking, + title={Rethinking the role of the middle longitudinal fascicle in language and auditory pathways}, + author={Wang, Yibao and Fern{\'a}ndez-Miranda, Juan C and Verstynen, Timothy and Pathak, Sudhir and Schneider, Walter and Yeh, Fang-Cheng}, + journal={Cerebral cortex}, + volume={23}, + number={10}, + pages={2347--2356}, + year={2013}, + publisher={Oxford University Press} +} + @ARTICLE{Dougherty2007, title = "Temporal-callosal pathway diffusivity predicts phonological skills in children", From 9384d52b0deb9cf5b2ae8e55ce34c3b177ba8de5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Mar 2026 17:16:42 +0900 Subject: [PATCH 77/86] update tests --- AFQ/data/fetch.py | 26 ++++++++++++++--------- AFQ/recognition/criteria.py | 1 + AFQ/recognition/sparse_decisions.py | 2 +- AFQ/recognition/tests/test_recognition.py | 9 ++++---- AFQ/recognition/tests/test_utils.py | 8 +++---- AFQ/tests/test_api.py | 23 +++++++++++++++----- AFQ/tests/test_bundle_dict.py | 2 +- AFQ/tests/test_fixes.py | 2 +- AFQ/tests/test_nn.py | 1 + AFQ/tests/test_registration.py | 2 +- AFQ/utils/tests/test_volume.py | 8 +++---- 11 files changed, 51 insertions(+), 33 deletions(-) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 7fe6d433..fa3dcdba 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -24,7 +24,9 @@ from dipy.segment.metric import AveragePointwiseEuclideanMetric from tqdm import tqdm +from AFQ._fixes import get_simplified_transform from AFQ.data.utils import aws_import_msg_error +from AFQ.registration import read_old_mapping from AFQ.utils.path import apply_cmd_to_afq_derivs, drop_extension # capture templateflow resource warning and log @@ -1667,22 +1669,26 @@ def read_stanford_hardi_tractography(): Reads a minimal tractography from the Stanford dataset. """ files, folder = fetch_stanford_hardi_tractography() - files_dict = {} - files_dict["mapping.nii.gz"] = nib.load( - op.join(afq_home, "stanford_hardi_tractography", "mapping.nii.gz") - ) # We need the original data as reference dwi_img, gtab = dpd.read_stanford_hardi() + reg_template = read_mni_template() + + files_dict = {} + files_dict["dwi"] = dwi_img + + mapping_file = op.join(afq_home, "stanford_hardi_tractography", "mapping.nii.gz") + old_mapping = read_old_mapping(mapping_file, dwi_img, reg_template) + files_dict["mapping"] = get_simplified_transform(old_mapping) - files_dict["tractography_subsampled.trk"] = load_tractogram( + files_dict["tractography_subsampled"] = load_tractogram( op.join(afq_home, "stanford_hardi_tractography", "tractography_subsampled.trk"), dwi_img, bbox_valid_check=False, trk_header_check=False, ).streamlines - files_dict["full_segmented_cleaned_tractography.trk"] = load_tractogram( + files_dict["full_segmented_cleaned_tractography"] = load_tractogram( op.join( afq_home, "stanford_hardi_tractography", @@ -1957,10 +1963,10 @@ def read_hcp_atlas(n_bundles=16, as_file=False): bundle_dict[bundle]["recobundles"]["centroid"] = centroid_file # For some reason, this file-name has a 0 in it, instead of an O: - bundle_dict["IFOF_R"] = bundle_dict["IF0F_R"] - # In the 80-bundle case, there are two files, and both have identical - # content, so this is fine: - del bundle_dict["IF0F_R"] + if "IF0F_R" in bundle_dict: + bundle_dict["IFOF_R"] = bundle_dict["IF0F_R"] + del bundle_dict["IF0F_R"] + return bundle_dict diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index e250f5e8..015a87c4 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -323,6 +323,7 @@ def recobundles( standard_sl = next(iter(bundle_def["recobundles"]["centroid"])) oriented_idx = abu.orient_by_streamline(moved_sl[rec_labels], standard_sl) b_sls.reorient(rec_labels[oriented_idx]) + rec_labels = sorted(rec_labels) b_sls.select(rec_labels, "Recobundles") diff --git a/AFQ/recognition/sparse_decisions.py b/AFQ/recognition/sparse_decisions.py index 114cc8ee..586721d5 100644 --- a/AFQ/recognition/sparse_decisions.py +++ b/AFQ/recognition/sparse_decisions.py @@ -93,7 +93,7 @@ def remove_conflicts(sparse_scores, bundles_being_recognized): num_bundles = len(bundles_being_recognized) split_indices = np.searchsorted(winner_rows, np.arange(num_bundles + 1)) - for i, b_name in enumerate(bundles_being_recognized.keys()): + for i, b_name in enumerate(list(bundles_being_recognized.keys())): b_sls = bundles_being_recognized[b_name] if np.any(b_sls.selected_fiber_idxs[:-1] > b_sls.selected_fiber_idxs[1:]): raise NotImplementedError( diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index e60fd90c..5b24d470 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -12,7 +12,6 @@ import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd import AFQ.recognition.cleaning as abc -import AFQ.registration as reg from AFQ.recognition.recognize import recognize dpd.fetch_stanford_hardi() @@ -23,8 +22,8 @@ hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() reg_template = afd.read_mni_template() -mapping = reg.read_old_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) -streamlines = file_dict["tractography_subsampled.trk"] +mapping = file_dict["mapping"] +streamlines = file_dict["tractography_subsampled"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) streamlines = tg.streamlines templates = afd.read_templates() @@ -173,7 +172,7 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ["MCP"] + bundle_names = ["CCMid"] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) @@ -191,7 +190,7 @@ def test_segment_reco(): # This condition should still hold npt.assert_equal(len(fiber_groups), 1) - npt.assert_(len(fiber_groups["MCP"]) > 0) + npt.assert_(len(fiber_groups["CCMid"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index dec1d4a0..c765ba27 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -16,7 +16,7 @@ hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz") hardi_img = nib.load(hardi_fdata) file_dict = afd.read_stanford_hardi_tractography() -streamlines = file_dict["tractography_subsampled.trk"] +streamlines = file_dict["tractography_subsampled"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) streamlines = tg.streamlines @@ -85,19 +85,19 @@ def test_segment_orientation(): streamlines, primary_axis="P/A", ) - npt.assert_equal(np.sum(cleaned_idx), 46) + npt.assert_equal(np.sum(cleaned_idx), 79) cleaned_idx = abc.clean_by_orientation( streamlines, primary_axis="I/S", ) - npt.assert_equal(np.sum(cleaned_idx), 38) + npt.assert_equal(np.sum(cleaned_idx), 58) cleaned_idx = abc.clean_by_orientation( streamlines, primary_axis="L/R", ) - npt.assert_equal(np.sum(cleaned_idx), 438) + npt.assert_equal(np.sum(cleaned_idx), 57) def test_clean_isolation_forest_basic(): diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 379b7788..8d552c24 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -697,7 +697,8 @@ def test_AFQ_data_waypoint(): tmpdir, bids_path, _ = get_temp_hardi() t1_path = op.join(tmpdir, "T1.nii.gz") t1_path_other = op.join(tmpdir, "T1-untransformed.nii.gz") - nib.save(afd.read_mni_template(mask=True, weight="T1w"), t1_path) + reg_template = afd.read_mni_template(mask=True, weight="T1w") + nib.save(reg_template, t1_path) shutil.copy(t1_path, t1_path_other) vista_folder = op.join(bids_path, "derivatives/vistasoft/sub-01/ses-01/dwi") @@ -775,8 +776,8 @@ def test_AFQ_data_waypoint(): # Replace the mapping and streamlines with precomputed: file_dict = afd.read_stanford_hardi_tractography() - mapping = file_dict["mapping.nii.gz"] - streamlines = file_dict["tractography_subsampled.trk"] + mapping = file_dict["mapping"] + streamlines = file_dict["tractography_subsampled"] dwi_affine = myafq.export("dwi_affine") streamlines = dts.Streamlines( dtu.transform_tracking_output( @@ -784,11 +785,23 @@ def test_AFQ_data_waypoint(): ) ) - mapping_file = op.join( + mapping_file_forward = op.join( myafq.export("output_dir"), "sub-01_ses-01_desc-mapping_from-subject_to-mni_xform.nii.gz", ) - nib.save(mapping, mapping_file) + nib.save( + nib.Nifti1Image(mapping.forward, dwi_affine), + mapping_file_forward, + ) + + mapping_file_backward = op.join( + myafq.export("output_dir"), + "sub-01_ses-01_desc-mapping_from-mni_to-subject_xform.nii.gz", + ) + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, + ) # Test ROI exporting: myafq.export("rois") diff --git a/AFQ/tests/test_bundle_dict.py b/AFQ/tests/test_bundle_dict.py index 5d9d57c5..acff27ce 100644 --- a/AFQ/tests/test_bundle_dict.py +++ b/AFQ/tests/test_bundle_dict.py @@ -20,7 +20,7 @@ def test_BundleDict(): # test defaults afq_bundles = abd.default_bd() - assert len(afq_bundles) == 18 + assert len(afq_bundles) == 20 # Arcuate Fasciculus afq_bundles = abd.default_bd()["Left Arcuate", "Right Arcuate"] diff --git a/AFQ/tests/test_fixes.py b/AFQ/tests/test_fixes.py index 3e9dc0d2..437cb8ef 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -34,7 +34,7 @@ def test_GQI_fix(): def test_gaussian_weights(): file_dict = afd.read_stanford_hardi_tractography() - streamlines = file_dict["tractography_subsampled.trk"] + streamlines = file_dict["tractography_subsampled"] assert not np.any(np.isnan(gaussian_weights(streamlines[76:92]))) diff --git a/AFQ/tests/test_nn.py b/AFQ/tests/test_nn.py index 460782b8..a410af51 100644 --- a/AFQ/tests/test_nn.py +++ b/AFQ/tests/test_nn.py @@ -36,6 +36,7 @@ def test_run_brainchop(): @pytest.mark.skipif(not has_onnx, reason="onnxruntime is not installed") +@pytest.mark.nightly def test_run_multiaxial(): tmpdir = tempfile.mkdtemp() afd.organize_stanford_data(path=tmpdir) diff --git a/AFQ/tests/test_registration.py b/AFQ/tests/test_registration.py index d1b3144c..b2b77d9c 100644 --- a/AFQ/tests/test_registration.py +++ b/AFQ/tests/test_registration.py @@ -32,7 +32,7 @@ def test_slr_registration(): # have to import subject sls file_dict = afd.read_stanford_hardi_tractography() - streamlines = file_dict["tractography_subsampled.trk"] + streamlines = file_dict["tractography_subsampled"] # have to import sls atlas afd.fetch_hcp_atlas_16_bundles() diff --git a/AFQ/utils/tests/test_volume.py b/AFQ/utils/tests/test_volume.py index e8be2b38..c5c437a7 100644 --- a/AFQ/utils/tests/test_volume.py +++ b/AFQ/utils/tests/test_volume.py @@ -22,12 +22,10 @@ def test_density_map(): file_dict = afd.read_stanford_hardi_tractography() # subsample even more - subsampled_tractography = file_dict["tractography_subsampled.trk"][441:444] - sft = StatefulTractogram( - subsampled_tractography, file_dict["mapping.nii.gz"], Space.VOX - ) + subsampled_tractography = file_dict["tractography_subsampled"][441:444] + sft = StatefulTractogram(subsampled_tractography, file_dict["dwi"], Space.RASMM) density_map = afv.density_map(sft) - npt.assert_equal(int(np.sum(density_map.get_fdata())), 69) + npt.assert_equal(int(np.sum(density_map.get_fdata())), 36) density_map = afv.density_map(sft, normalize=True) npt.assert_equal(density_map.get_fdata().max(), 1) From 1af58b16e95bf4956e14d365f29e58232c8ff6c0 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 11:11:34 +0900 Subject: [PATCH 78/86] tweaking definitions of paf/vof --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/data/fetch.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index e8ddbf27..f6eb6a69 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -307,7 +307,6 @@ def default_bd(): "Left Arcuate": {"overlap": 30}, "Left Inferior Fronto-occipital": {"core": "Right"}, "Left Optic Radiation": {"core": "Right"}, - "endpoints_exact": True, "length": {"min_len": 30}, "primary_axis": "I/S", }, @@ -325,7 +324,6 @@ def default_bd(): "Right Arcuate": {"overlap": 30}, "Right Inferior Fronto-occipital": {"core": "Left"}, "Right Optic Radiation": {"core": "Left"}, - "endpoints_exact": True, "length": {"min_len": 30}, "primary_axis": "I/S", }, @@ -365,6 +363,7 @@ def default_bd(): "Left Posterior Vertical Occipital": { "Left Inferior Fronto-occipital": {"core": "Right"}, "Left Optic Radiation": {"core": "Right"}, + "exclude": [templates["pVOF_xroi_1_L"]], "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, @@ -470,6 +469,7 @@ def default_bd(): "Right Posterior Vertical Occipital": { "Right Inferior Fronto-occipital": {"core": "Left"}, "Right Optic Radiation": {"core": "Left"}, + "exclude": [templates["pVOF_xroi_1_R"]], "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index fa3dcdba..033b7d89 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -775,6 +775,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "pARC_R_end.nii.gz", "MdLF_L_end.nii.gz", "MdLF_R_end.nii.gz", + "pVOF_xroi_1_L.nii.gz", + "pVOF_xroi_1_R.nii.gz", ] @@ -887,10 +889,12 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "62213942", "62282968", "62282971", - "62283235", - "62283238", + "62316400", + "62316403", "62283226", "62283229", + "62316817", + "62316820", ] @@ -1004,10 +1008,12 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "eff69ba30619bbf2ff500cd714f07894", "088e850d38fed2d62fc768071d42a43d", "a3d249535ca0452ce9f59c754e13695a", - "2cba992046cd80d1b9bdc1de3e101214", - "2929851e9611dba80e8c17d2bd767b1d", + "df8a7480c507e91976c5a82d5826d521", + "243ddae33bf84e7b24da4d1d9f90a121", "827b21f9069cd0192ede3f7153f1ca80", "e2912622d36db6723c48a0a1de887807", + "9ebb676cf213605063934981c7c8fe9e", + "cc30b09f5e4c084675380654fc691b95", ] fetch_templates = _make_reusable_fetcher( From c103695762e4cd8203f58b69df44f1ed3eca1d34 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 11:34:52 +0900 Subject: [PATCH 79/86] bf --- AFQ/api/bundle_dict.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index f6eb6a69..7af8468b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -363,7 +363,6 @@ def default_bd(): "Left Posterior Vertical Occipital": { "Left Inferior Fronto-occipital": {"core": "Right"}, "Left Optic Radiation": {"core": "Right"}, - "exclude": [templates["pVOF_xroi_1_L"]], "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, @@ -375,7 +374,10 @@ def default_bd(): "Left Inferior Fronto-occipital": {"core": "Right"}, "Left Optic Radiation": {"core": "Right"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], - "exclude": [templates["pARC_xroi1_L"]], + "exclude": [ + templates["pARC_xroi1_L"], + templates["pVOF_xroi_1_L"], + ], "mahal": { "distance_threshold": 5, "length_threshold": 4, @@ -469,7 +471,6 @@ def default_bd(): "Right Posterior Vertical Occipital": { "Right Inferior Fronto-occipital": {"core": "Left"}, "Right Optic Radiation": {"core": "Left"}, - "exclude": [templates["pVOF_xroi_1_R"]], "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, @@ -481,7 +482,10 @@ def default_bd(): "Right Inferior Fronto-occipital": {"core": "Left"}, "Right Optic Radiation": {"core": "Left"}, "cluster_IDs": [2, 7, 18, 21, 25, 51], - "exclude": [templates["pARC_xroi1_R"]], + "exclude": [ + templates["pARC_xroi1_R"], + templates["pVOF_xroi_1_R"], + ], "mahal": { "distance_threshold": 5, "length_threshold": 4, From 193bb4f060c5c10777dc04adb85f7b274bc08be1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 12:59:10 +0900 Subject: [PATCH 80/86] bf --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/data/fetch.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 7af8468b..b1259525 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -376,7 +376,7 @@ def default_bd(): "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [ templates["pARC_xroi1_L"], - templates["pVOF_xroi_1_L"], + templates["pVOF_xroi1_L"], ], "mahal": { "distance_threshold": 5, @@ -484,7 +484,7 @@ def default_bd(): "cluster_IDs": [2, 7, 18, 21, 25, 51], "exclude": [ templates["pARC_xroi1_R"], - templates["pVOF_xroi_1_R"], + templates["pVOF_xroi1_R"], ], "mahal": { "distance_threshold": 5, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 033b7d89..90f71fcc 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -775,8 +775,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "pARC_R_end.nii.gz", "MdLF_L_end.nii.gz", "MdLF_R_end.nii.gz", - "pVOF_xroi_1_L.nii.gz", - "pVOF_xroi_1_R.nii.gz", + "pVOF_xroi1_R.nii.gz", + "pVOF_xroi1_L.nii.gz", ] From 6ed7a88433205911a4d98420280a1ce615e7388d Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 14:16:06 +0900 Subject: [PATCH 81/86] for large tractograms, trx is much much better --- AFQ/tractography/tractography.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index cea4f4f5..25474ec2 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -40,7 +40,7 @@ def track( basis_type="descoteaux07", legacy=True, tracker="pft", - trx=False, + trx=True, ): """ Tractography @@ -113,7 +113,7 @@ def track( trx : bool, optional Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). - Default: False + Default: True Returns ------- From 671c380198f8a206d35515af0d7844104e6e36fd Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 14:55:43 +0900 Subject: [PATCH 82/86] better name for this --- AFQ/recognition/criteria.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 015a87c4..b054dda6 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -46,7 +46,7 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", - "endpoints_exact", + "exact_endpoints", "ORG_spectral_subbundles", "cluster_IDs", "startpoint_location", @@ -78,8 +78,8 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("Startpoint") - endpoints_exact = bundle_def.get("endpoints_exact", False) - if endpoints_exact: + exact_endpoints = bundle_def.get("exact_endpoints", False) + if exact_endpoints: tol = 0 else: tol = preproc_imap["dist_to_atlas"] @@ -112,8 +112,8 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): def end(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("endpoint") - endpoints_exact = bundle_def.get("endpoints_exact", False) - if endpoints_exact: + exact_endpoints = bundle_def.get("exact_endpoints", False) + if exact_endpoints: tol = 0 else: tol = preproc_imap["dist_to_atlas"] From c693df9897d8c453743ce5b79cf00b687410ff94 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 15:38:23 +0900 Subject: [PATCH 83/86] update test to trx --- AFQ/tests/test_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 8d552c24..e6b6577f 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -822,7 +822,7 @@ def test_AFQ_data_waypoint(): op.join( myafq.export("output_dir"), "bundles", - "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trk", + "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trx", ) ) # noqa @@ -839,7 +839,7 @@ def test_AFQ_data_waypoint(): all_sl = load_tractogram( op.join( - myafq.export("output_dir"), "tractography", "sub-01_ses-01_tractography.trk" + myafq.export("output_dir"), "tractography", "sub-01_ses-01_tractography.trx" ), reference="same", ).streamlines @@ -852,7 +852,7 @@ def test_AFQ_data_waypoint(): op.join( myafq.export("output_dir"), "bundles", - "sub-01_ses-01_desc-LeftArcuate_tractography.trk", + "sub-01_ses-01_desc-LeftArcuate_tractography.trx", ), reference="same", ).streamlines @@ -980,6 +980,6 @@ def test_AFQ_data_waypoint(): op.join( output_dir, "bundles", - "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trk", + "sub-01_ses-01_desc-RightInferiorLongitudinal_tractography.trx", ) ) # noqa From 44d564bb6352a4d711d0e0b3b36c78b3a8fc7540 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Mar 2026 16:01:17 +0900 Subject: [PATCH 84/86] settings for length threshold removal --- AFQ/api/bundle_dict.py | 18 ++++++++++++------ AFQ/recognition/cleaning.py | 20 +++++++++++++++++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index b1259525..fbacc1fb 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -356,8 +356,9 @@ def default_bd(): }, "mahal": { "distance_threshold": 3, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, "Left Posterior Vertical Occipital": { @@ -366,8 +367,9 @@ def default_bd(): "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, "Left Anterior Vertical Occipital": { @@ -380,8 +382,9 @@ def default_bd(): ], "mahal": { "distance_threshold": 5, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, }, @@ -464,8 +467,9 @@ def default_bd(): }, "mahal": { "distance_threshold": 3, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, "Right Posterior Vertical Occipital": { @@ -474,8 +478,9 @@ def default_bd(): "cluster_IDs": [1, 72, 75, 81, 83], "mahal": { "distance_threshold": 5, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, "Right Anterior Vertical Occipital": { @@ -488,8 +493,9 @@ def default_bd(): ], "mahal": { "distance_threshold": 5, - "length_threshold": 4, + "length_threshold": 2, "clean_rounds": 5, + "remove_lengths": "short", }, }, }, diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 0f3d68f3..9b375ecd 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -145,6 +145,7 @@ def clean_bundle( stat=np.mean, core_only=0.6, return_idx=False, + remove_lengths="long", ): """ Clean a segmented fiber group based on the Mahalnobis distance of @@ -183,6 +184,13 @@ def clean_bundle( return_idx : bool Whether to return indices in the original streamlines. Default: False. + remove_lengths : str + Specifies which streamlines to remove based on their length. + Options are "long" (remove long streamlines), "short" + (remove short streamlines), or "both" + (remove both long and short streamlines). + Default: "long" + Returns ------- A StatefulTractogram class instance containing only the streamlines @@ -247,7 +255,17 @@ def clean_bundle( # Select the fibers that have Mahalanobis smaller than the # threshold for all their nodes: idx_dist = np.all(m_dist < distance_threshold, axis=-1) - idx_len = length_z < length_threshold + if remove_lengths == "long": + idx_len = length_z < length_threshold + elif remove_lengths == "short": + idx_len = length_z > -length_threshold + elif remove_lengths == "both": + idx_len = np.abs(length_z) < length_threshold + else: + raise ValueError( + f"Invalid value for remove_lengths: {remove_lengths}. " + "Expected 'long', 'short', or 'both'." + ) idx_belong = np.logical_and(idx_dist, idx_len) if np.sum(idx_belong) < min_sl: From 09f962b00437cfb22ce722bb077afc478153e5ff Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Mar 2026 12:59:42 +0900 Subject: [PATCH 85/86] add webgpu compat --- AFQ/tasks/tractography.py | 6 +++ AFQ/tractography/gputractography.py | 57 ++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 681bf966..21e9c64f 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -251,6 +251,7 @@ def gpu_tractography( seed, tissue_imap, tractography_ngpus=0, + gpu_backend="auto", chunk_size=25000, ): """ @@ -265,6 +266,10 @@ def gpu_tractography( PTT, Prob can be used with any SHM model. Bootstrapped can be done with CSA/OPDT. Default: 0 + gpu_backend : str, optional + GPU backend to use for tractography. + One of {"auto", "cuda", "metal", "webgpu"}. + Default: "auto" chunk_size : int, optional Chunk size for GPU tracking. Default: 25000 @@ -305,6 +310,7 @@ def gpu_tractography( tracking_params["trx"], tractography_ngpus, chunk_size, + gpu_backend, ) return sft, _meta_from_tracking_params(tracking_params, start_time, sft, seed, pve) diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index d390eda2..2ebd2608 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -3,12 +3,6 @@ import nibabel as nib import numpy as np -from cuslines import ( - BootDirectionGetter, - GPUTracker, - ProbDirectionGetter, - PttDirectionGetter, -) from dipy.align import resample from dipy.reconst import shm @@ -38,6 +32,7 @@ def gpu_track( use_trx, ngpus, chunk_size, + gpu_backend, ): """ Perform GPU tractography on DWI data. @@ -88,9 +83,59 @@ def gpu_track( Number of GPUs to use. chunk_size : int Chunk size for GPU tracking. + gpu_backend : str, optional + GPU backend to use for tractography. + One of {"auto", "cuda", "metal", "webgpu"}. Returns ------- """ + gpu_backend = gpu_backend.lower() + if gpu_backend == "auto": + from cuslines import ( + BootDirectionGetter, + GPUTracker, + ProbDirectionGetter, + PttDirectionGetter, + ) + elif gpu_backend == "cuda": + from cuslines.cuda_python import ( + BootDirectionGetter, + GPUTracker, + ProbDirectionGetter, + PttDirectionGetter, + ) + elif gpu_backend == "metal": + from cuslines.metal import ( + MetalBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.metal import ( + MetalGPUTracker as GPUTracker, + ) + from cuslines.metal import ( + MetalProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.metal import ( + MetalPttDirectionGetter as PttDirectionGetter, + ) + elif gpu_backend == "webgpu": + from cuslines.webgpu import ( + WebGPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUPttDirectionGetter as PttDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUTracker as GPUTracker, + ) + else: + raise ValueError( + "gpu_backend must be one of 'auto', 'cuda', " + "'metal', or 'webgpu', not {gpu_backend}" + ) + seed_img = nib.load(seed_path) directions = directions.lower() From 44711fc531daddbf361c3573c92610f653af96d4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Mar 2026 16:09:36 +0900 Subject: [PATCH 86/86] buan cleaning then relatively strict vof params --- AFQ/_fixes.py | 58 +++++++++++++++++++++++-------------- AFQ/api/bundle_dict.py | 24 +++++++++++++-- AFQ/recognition/cleaning.py | 17 +++++++++-- 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 3c9a5b6b..78cd6b0e 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -223,7 +223,9 @@ def tensor_odf(evals, evecs, sphere, num_batches=100): return odf -def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean): +def gaussian_weights( + bundle, assignment_idxs=None, n_points=100, return_mahalnobis=False, stat=np.mean +): """ Calculate weights for each streamline/node in a bundle, based on a Mahalanobis distance from the core the bundle, at that node (mean, per @@ -233,6 +235,8 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean ---------- bundle : Streamlines The streamlines to weight. + assignment_idxs : array of shape (n_streamlines, n_points), optional + BUAN assignments, optional. n_points : int or None, optional The number of points to resample to. If this is None, we assume bundle is already resampled, and do not do any resampling. Default: 100. @@ -280,38 +284,48 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, stat=np.mean return weights / np.sum(weights, 0) else: weights = np.zeros((n_sls, n_nodes)) - diff = stat(sls, axis=0) - sls - for i in range(n_nodes): - # This should come back as a 3D covariance matrix with the spatial - # variance covariance of this node across the different streamlines, - # converted to a positive semi-definite matrix if necessary - cov = np.cov(sls[:, i, :].T, ddof=0) + + if assignment_idxs is None: + working_groups = np.tile(np.arange(n_nodes), (n_sls, 1)) + else: + working_groups = np.asarray(assignment_idxs) + + flat_coords = sls.reshape(-1, 3) + flat_groups = working_groups.reshape(-1) + unique_ids = np.unique(flat_groups) + + for gid in unique_ids: + mask = flat_groups == gid + group_data = flat_coords[mask] + + if len(group_data) < 15: + continue + + mu = stat(group_data, axis=0) + diff = group_data - mu + + cov = np.cov(group_data.T, ddof=0) + + # Ensure positive semi-definite if np.any(np.linalg.eigvals(cov) < 0): eigenvalues, eigenvectors = np.linalg.eigh((cov + cov.T) / 2) eigenvalues[eigenvalues < 0] = 0 cov = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T - # calculate Mahalanobis for node in every fiber if np.any(cov > 0): - weights[:, i] = np.sqrt( - np.einsum("ij,jk,ik->i", diff[:, i, :], pinvh(cov), diff[:, i, :]) + weights.ravel()[mask] = np.sqrt( + np.einsum("ij,jk,ik->i", diff, pinvh(cov), diff) ) - # In the special case where all the streamlines have the exact same - # coordinate in this node, the covariance matrix is all zeros, so - # we can't calculate the Mahalanobis distance, we will instead give - # each streamline an identical weight, equal to the number of - # streamlines: - else: - weights[:, i] = 0 if return_mahalnobis: return weights - # weighting is inverse to the distance (the further you are, the less you - # should be weighted) - weights = 1 / weights - # Normalize before returning, so that the weights in each node sum to 1: - return weights / np.sum(weights, 0) + with np.errstate(divide="ignore"): + w_inv = 1.0 / weights + w_inv[np.isinf(w_inv)] = 0 + + denom = np.sum(w_inv, axis=0) + return np.divide(w_inv, denom, out=np.zeros_like(w_inv), where=denom != 0) def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150): diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index fbacc1fb..2179c541 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -355,7 +355,7 @@ def default_bd(): "clean_rounds": 2, }, "mahal": { - "distance_threshold": 3, + "distance_threshold": 5, "length_threshold": 2, "clean_rounds": 5, "remove_lengths": "short", @@ -365,6 +365,11 @@ def default_bd(): "Left Inferior Fronto-occipital": {"core": "Right"}, "Left Optic Radiation": {"core": "Right"}, "cluster_IDs": [1, 72, 75, 81, 83], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 4, + "clean_rounds": 2, + }, "mahal": { "distance_threshold": 5, "length_threshold": 2, @@ -380,6 +385,11 @@ def default_bd(): templates["pARC_xroi1_L"], templates["pVOF_xroi1_L"], ], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 4, + "clean_rounds": 2, + }, "mahal": { "distance_threshold": 5, "length_threshold": 2, @@ -466,7 +476,7 @@ def default_bd(): "clean_rounds": 2, }, "mahal": { - "distance_threshold": 3, + "distance_threshold": 5, "length_threshold": 2, "clean_rounds": 5, "remove_lengths": "short", @@ -476,6 +486,11 @@ def default_bd(): "Right Inferior Fronto-occipital": {"core": "Left"}, "Right Optic Radiation": {"core": "Left"}, "cluster_IDs": [1, 72, 75, 81, 83], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 4, + "clean_rounds": 2, + }, "mahal": { "distance_threshold": 5, "length_threshold": 2, @@ -491,6 +506,11 @@ def default_bd(): templates["pARC_xroi1_R"], templates["pVOF_xroi1_R"], ], + "orient_mahal": { + "distance_threshold": 3, + "length_threshold": 4, + "clean_rounds": 2, + }, "mahal": { "distance_threshold": 5, "length_threshold": 2, diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 9b375ecd..206b1772 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -3,6 +3,7 @@ import dipy.tracking.streamline as dts import numpy as np from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.stats.analysis import assignment_map from scipy.stats import zscore from sklearn.ensemble import IsolationForest @@ -79,7 +80,7 @@ def clean_by_orientation(streamlines, primary_axis, core_only=0.6): def clean_by_orientation_mahalanobis( streamlines, n_points=100, - core_only=0.6, + core_only=0, min_sl=20, distance_threshold=3, length_threshold=4, @@ -87,7 +88,11 @@ def clean_by_orientation_mahalanobis( ): if length_threshold == 0: length_threshold = np.inf - fgarray = np.array(abu.resample_tg(streamlines, n_points)) + fgarray = abu.resample_tg(streamlines, n_points) + + assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, 100)) + assignment_idxs = assignment_idxs.reshape((len(fgarray), n_points)) + fgarray = np.asarray(fgarray) if core_only != 0: crop_edge = (1.0 - core_only) / 2 @@ -96,12 +101,17 @@ def clean_by_orientation_mahalanobis( ] fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + assignment_idxs = assignment_idxs[:, 1:] lengths = np.array([sl.shape[0] for sl in streamlines]) idx = np.arange(len(fgarray)) rounds_elapsed = 0 while rounds_elapsed < clean_rounds: m_dist = gaussian_weights( - fgarray_dists, return_mahalnobis=True, n_points=None, stat=np.mean + fgarray_dists, + assignment_idxs=assignment_idxs, + return_mahalnobis=True, + n_points=None, + stat=np.mean, ) length_z = zscore(lengths) @@ -128,6 +138,7 @@ def clean_by_orientation_mahalanobis( idx = idx[idx_belong] fgarray_dists = fgarray_dists[idx_belong] lengths = lengths[idx_belong] + assignment_idxs = assignment_idxs[idx_belong] rounds_elapsed += 1 logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}")) logger.debug(f"Kept indices: {idx}")