Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CorpusCallosum/data/fsaverage_data.json
Comment thread
dkuegler marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@
[
0.0,
0.0,
10000000000.0
-1.0
],
[
0.0,
-10000000000.0,
1.0,
0.0
]
],
"Pxyz_c": [
128.0,
-128.0,
128.0
0.0,
0.0,
0.0
]
}
}
163 changes: 107 additions & 56 deletions CorpusCallosum/fastsurfer_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
calc_mapping_to_standard_space,
map_softlabels_to_orig,
)
from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod
from CorpusCallosum.utils.types import SliceSelection, SubdivisionMethod
from FastSurferCNN.data_loader.conform import conform, is_conform
from FastSurferCNN.segstats import HelpFormatter
from FastSurferCNN.utils import (
Expand Down Expand Up @@ -737,10 +737,14 @@ def main(
_aseg_fut = thread_executor().submit(nib.load, sd.filename_by_attribute("aseg_name"))
orig = cast(nibabelImage, nib.load(sd.conf_name))

# check that the image is conformed, i.e. isotropic 1mm voxels, 256^3 size, LIA orientation
# check that the image is conformed, the affine should not change (under no circumstance)
_orig_affine = orig.affine
if not is_conform(orig, vox_size=None, img_size=None, orientation=None):
logger.info("Internally conforming orig to soft-LIA.")
logger.info("Robust rescaling of input intensities.")
orig = conform(orig, vox_size=None, img_size=None, orientation=None)
if not np.allclose(_orig_affine, orig.affine):
logger.error("Conforming the image should not change the affine, but it did!")
sys.exit(1)

# 5 mm around the midplane (guaranteed to be aligned RAS by as_closest_canonical)
vox_size_ras: tuple[float, float, float] = nib.as_closest_canonical(orig).header.get_zooms()
Expand Down Expand Up @@ -838,54 +842,85 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
# Create a temporary segmentation image with proper affine for enhanced postprocessing
# Process slices based on selection mode

logger.info(f"Processing slices with selection mode: {slice_selection}")
slice_results, slice_io_futures, cc_contours = recon_cc_surf_measures_multi(
segmentation=cc_fn_seg_labels,
slice_selection=slice_selection,
upright_header=fsavg_header,
fsavg2midslab_vox2vox=fsavg2midslab_vox2vox,
fsavg_vox2ras=fsavg_vox2ras,
orig2fsavg_vox2vox=orig2fsavg_vox2vox,
midslices=midslices,
ac_coords_vox=ac_coords_vox,
pc_coords_vox=pc_coords_vox,
num_thickness_points=num_thickness_points,
subdivisions=subdivisions,
subdivision_method=cast(SubdivisionMethod, subdivision_method),
contour_smoothing=contour_smoothing,
subject_dir=sd,
)
io_futures.extend(slice_io_futures)

outer_contours = [slice_result["split_contours"][0] for slice_result in slice_results]

if len(outer_contours) > 1 and not check_area_changes(outer_contours):
logger.warning(
"Large area changes detected between consecutive slices, this is likely due to a segmentation error."
)

# Get middle slice result
middle_slice_result: CCMeasuresDict = slice_results[len(slice_results) // 2]


# save segmentation labels, this
# save segmentation labels
if sd.has_attribute("cc_segmentation"):
sd.filename_by_attribute("cc_segmentation").parent.mkdir(exist_ok=True, parents=True)
_cc_seg_path = sd.filename_by_attribute("cc_segmentation")
_cc_seg_path.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Saving CC segmentation to {_cc_seg_path}")
io_futures.append(thread_executor().submit(
nib.save,
nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header),
sd.filename_by_attribute("cc_segmentation"),
_cc_seg_path,
))

logger.info(f"Processing slices with selection mode: {slice_selection}")
try:
slice_results, slice_io_futures, cc_contours, num_failed_slices = recon_cc_surf_measures_multi(
segmentation=cc_fn_seg_labels,
slice_selection=slice_selection,
upright_header=fsavg_header,
fsavg2midslab_vox2vox=fsavg2midslab_vox2vox,
fsavg_vox2ras=fsavg_vox2ras,
orig2fsavg_vox2vox=orig2fsavg_vox2vox,
midslices=midslices,
ac_coords_vox=ac_coords_vox,
pc_coords_vox=pc_coords_vox,
num_thickness_points=num_thickness_points,
subdivisions=subdivisions,
subdivision_method=cast(SubdivisionMethod, subdivision_method),
contour_smoothing=contour_smoothing,
subject_dir=sd,
)
io_futures.extend(slice_io_futures)
except Exception as e:
logger.error(f"CC morphometry analysis failed: {e}")
logger.exception(e)
# We continue to save what we have
slice_results = []
slice_io_futures = []
cc_contours = []
num_failed_slices = cc_fn_seg_labels.shape[0] if slice_selection == "all" else 1

# Filter out None results for further processing
valid_slice_results = [r for r in slice_results if r is not None]
valid_cc_contours = [c for c in cc_contours if c is not None]
num_failed_slices += len(cc_contours) - len(valid_cc_contours)
Comment thread
dkuegler marked this conversation as resolved.
outer_contours = []
cc_volume_contour = None

if not valid_slice_results:
logger.error("No valid CC morphometry results found for any slice.")
else:
if num_failed_slices > 0:
logger.warning(
f"QC flag: CC morphometry analysis failed for {num_failed_slices} of "
f"{num_failed_slices + len(valid_cc_contours)} slices. Results will only be included/saved for "
f"successful slices!"
)
outer_contours = [slice_result["split_contours"][0] for slice_result in valid_slice_results]

if len(outer_contours) > 1 and not check_area_changes(outer_contours):
logger.warning(
"Large area changes detected between consecutive slices, this is likely due to a segmentation error."
)

# Get middle slice result if available
middle_slice_idx = len(slice_results) // 2
middle_slice_result = slice_results[middle_slice_idx] if middle_slice_idx < len(slice_results) else None

# map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels)
if sd.has_attribute("cc_orig_segfile"):
if len(middle_slice_result["split_contours"]) <= 5:
if middle_slice_result is not None and len(middle_slice_result["split_contours"]) <= 5:
cc_subseg_midslice = make_subdivision_mask(
(cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]),
middle_slice_result["subdivision_lines"],
vox2ras=fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox)
)
else:
logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.")
if middle_slice_result is None:
logger.warning("No valid middle slice result found, skipping sub-division of output segmentation.")
else:
logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.")
cc_subseg_midslice = None
# if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit
executor = thread_executor() if get_num_threads() > 2 else serial_executor()
Expand All @@ -902,29 +937,43 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
metrics: tuple[CCMeasures] = get_args(CCMeasures)

# Record key metrics for middle slice
output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in metrics}
if middle_slice_result is not None:
output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in metrics}
else:
logger.warning("Middle slice morphometry failed, no middle slice metrics available.")
output_metrics_middle_slice = {}

# Create enhanced output dictionary with all slice results
per_slice_output_dict = {
"slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in metrics})
for result in slice_results],
if result else None for result in slice_results],
}

########## Save outputs ##########
additional_metrics = {}

cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel(
cc_num_voxel = segmentation_postprocessing.get_cc_num_voxel(
desired_width_mm=5,
cc_mask=np.equal(cc_fn_seg_labels, CC_LABEL),
voxel_size=vox_size, # in LIA order
)
additional_metrics["cc_5mm_volume"] = cc_volume_voxel

if len(outer_contours) > 1:
logger.info(f"CC volume voxel: {cc_volume_voxel}")
cc_volume_contour = calculate_cc_volume_contour(cc_contours, width=5.0)
additional_metrics["cc_num_voxel"] = cc_num_voxel
voxel_volume = np.prod(vox_size)
additional_metrics["voxel_volume"] = voxel_volume

if len(valid_cc_contours) > 1:
logger.info(
f"CC voxel count: {cc_num_voxel} at {voxel_volume:.2f} mm^3 voxel volume => "
f"{cc_num_voxel * voxel_volume:.2f} mm^3"
)
cc_volume_contour = calculate_cc_volume_contour(valid_cc_contours, width=5.0)
Comment thread
dkuegler marked this conversation as resolved.
logger.info(f"CC volume contour: {cc_volume_contour}")
additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour
else:
cc_volume_contour = None

# surface-based volume estimate (not PV-corrected); only valid if all slices processed successfully
additional_metrics["cc_volume"] = cc_volume_contour
additional_metrics["cc_num_failed_slices"] = num_failed_slices

# get ac and pc in all spaces
ac_coords_vox_3d, pc_coords_vox_3d = [np.hstack((0, c)) for c in (ac_coords_vox, pc_coords_vox)]
Expand All @@ -933,7 +982,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
(ac_coords_vox_3d, pc_coords_vox_3d),
)
standardized2orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = (
calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox)
calc_mapping_to_standard_space(fsavg_header["dims"][:3], ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox)
)

# write output dict as csv
Expand All @@ -952,13 +1001,14 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
additional_metrics["slice_selection"] = slice_selection

# QC checks
if len(outer_contours) > 1:
if len(outer_contours) > 1 and cc_volume_contour is not None:
cc_volume_voxel = cc_num_voxel * voxel_volume
max_vol = max(cc_volume_voxel, cc_volume_contour)
if max_vol > 0 and abs(cc_volume_voxel - cc_volume_contour) / max_vol > 0.2:
logger.warning(
f"QC flag: CC volume estimates differ by more than 20% "
f"(voxel: {cc_volume_voxel:.2f}, contour: {cc_volume_contour:.2f})",
"this can happen if contour creation failed for some slices"
f"(segmentation: {cc_volume_voxel:.2f} mm³, contour: {cc_volume_contour:.2f} mm³); "
"this can happen if contour creation failed for some slices."
)

cc_index = output_metrics_middle_slice.get("cc_index")
Expand All @@ -974,7 +1024,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
"incorrectly detected or contour creation may have failed"
)

if sd.has_attribute("cc_mid_measures"):
if sd.has_attribute("cc_mid_measures") and middle_slice_result is not None:
sd.filename_by_attribute('cc_mid_measures').parent.mkdir(exist_ok=True, parents=True)
io_futures.append(thread_executor().submit(
save_cc_measures_json,
Expand Down Expand Up @@ -1008,16 +1058,17 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
if sd.has_attribute("cc_orient_volume_lta"):
sd.filename_by_attribute("cc_orient_volume_lta").parent.mkdir(exist_ok=True, parents=True)
# save lta to standardized space (fsaverage + nodding + ac to center)
orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine)
fsavg2standardized_ras2ras = fsavg_vox2ras @ \
np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine)
logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}")
io_futures.append(thread_executor().submit(
write_lta,
sd.filename_by_attribute("cc_orient_volume_lta"),
orig2standardized_ras2ras,
sd.conf_name,
orig.header,
fsavg2standardized_ras2ras,
sd.conf_name,
orig.header,
"standardized",
fsavg_header,
))

# this waits for all io to finish
Expand Down
26 changes: 12 additions & 14 deletions CorpusCallosum/segmentation/segmentation_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,15 @@ def connect_nearby_components(seg_arr: ArrayType, max_connection_distance: float
return connected_seg


def get_cc_volume_voxel(
def get_cc_num_voxel(
desired_width_mm: int,
cc_mask: Mask3d,
voxel_size: tuple[float, float, float],
) -> float:
"""Calculate the volume of the corpus callosum in cubic millimeters.
"""Calculate the voxel count of the corpus callosum.

This function calculates the volume of the corpus callosum (CC) in cubic millimeters.
If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as
partial volumes to achieve the desired width.
This function calculates the voxel count of the corpus callosum (CC). If the CC width is larger than
desired_width_mm, voxels on the edges only contribute partially to the count to achieve the desired width.

Parameters
----------
Expand All @@ -314,14 +313,14 @@ def get_cc_volume_voxel(
Returns
-------
float
Volume of the CC in cubic millimeters.
Voxel count of the CC (with partial edge voxels so the counted slab matches `desired_width_mm`).

Raises
------
ValueError
If CC width is smaller than desired width
If CC width is smaller than desired width.
AssertionError
If CC mask doesn't have odd number of voxels in x dimension
If CC mask doesn't have odd number of voxels in x dimension.

Notes
-----
Expand All @@ -342,14 +341,13 @@ def get_cc_volume_voxel(
assert width_vox % 2 == 1, f"CC mask must have odd number of voxels in x dimension, but has {width_vox}"

# Calculate voxel volume
voxel_volume: float = np.prod(voxel_size, dtype=float)
lateral_voxel_size: float = voxel_size[0]

# we are in LIA, so 0 is L/R resolution
width_mm = width_vox * lateral_voxel_size

if width_mm == desired_width_mm:
return np.sum(cropped_mask) * voxel_volume
return np.sum(cropped_mask)
elif width_mm > desired_width_mm:
# remainder on the left/right side of the CC mask
desired_width_vox = desired_width_mm / lateral_voxel_size
Expand All @@ -363,10 +361,10 @@ def get_cc_volume_voxel(
f"desired_width_vox: {desired_width_vox}, width_vox: {width_vox}, "
f"desired_width_mm: {desired_width_mm}, voxel size (lateral): {lateral_voxel_size} mm")

left_partial_volume = np.sum(cropped_mask[0]) * voxel_volume * fraction_of_voxel_at_edge
right_partial_volume = np.sum(cropped_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge
center_volume = np.sum(cropped_mask[1:-1]) * voxel_volume
return left_partial_volume + right_partial_volume + center_volume
left_partial_count = np.sum(cropped_mask[0]) * fraction_of_voxel_at_edge
right_partial_count = np.sum(cropped_mask[-1]) * fraction_of_voxel_at_edge
center_count = np.sum(cropped_mask[1:-1])
return left_partial_count + right_partial_count + center_count
else:
raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}")

Expand Down
Loading