Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions MultiverSeg/SegmentEditorMultiverSegLib/ContextLogic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,12 @@ def saveNewExample(self, volume: vtkMRMLVolumeNode, view, segmentID, segmentatio
imageArray = slicer.util.arrayFromVolume(volume).copy()
maskArray = slicer.util.arrayFromSegmentBinaryLabelmap(segmentationNode, segmentID, volume)

IJKToRAS = np.zeros((3, 3))
volume.GetIJKToRASDirections(IJKToRAS)
KJIToRAS = IJKToRAS.copy()
KJIToRAS[:, 0] = IJKToRAS[:, 2]
KJIToRAS[:, 2] = IJKToRAS[:, 0]
sliceNodeID = f"vtkMRMLSliceNode{view}"
sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
axis = segLogic.computeSliceAxis(volume, sliceNode)

imageArray = segLogic.reorderAxisToRAS(imageArray, KJIToRAS)
maskArray = segLogic.reorderAxisToRAS(maskArray, KJIToRAS)

imageTensor = torch.from_numpy(segLogic.extractSlice(imageArray, k, view))
maskTensor = torch.from_numpy(segLogic.extractSlice(maskArray, k, view))
imageTensor = torch.from_numpy(segLogic.extractSlice(imageArray, k, axis))
maskTensor = torch.from_numpy(segLogic.extractSlice(maskArray, k, axis))

imageTensor = segLogic.preprocessSlice(imageTensor[None])
maskTensor = segLogic.preprocessSlice(maskTensor[None])
Expand Down
184 changes: 96 additions & 88 deletions MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import slicer
import vtkAddon

from MRMLCorePython import vtkMRMLSegmentationNode, vtkMRMLScalarVolumeNode, vtkMRMLSliceNode
from numpy.ma.core import maximum
Expand Down Expand Up @@ -37,6 +38,9 @@ def __init__(self, scriptedEffect):
self.sliceOffsetRange = (0., 0.)

def initSegments(self):
"""
Initialize the segments by creating the positive and negative segments.
"""
# Get the current segment
self.segmentationNode: vtkMRMLSegmentationNode = self.scriptedEffect.parameterSetNode().GetSegmentationNode()
segmentation: vtkSegmentation = self.segmentationNode.GetSegmentation()
Expand All @@ -58,6 +62,10 @@ def initSegments(self):
segmentation.AddSegment(self.negSegment)

def initModel(self):
"""
Verify the dependencies and initialize the model.
:return: True if the initialization was successful, False otherwise.
"""
from .InstallLogic import InstallLogic, DependenciesLogic

progress = slicer.util.createProgressDialog(maximum=10, labelText="Verifying dependencies")
Expand Down Expand Up @@ -96,6 +104,9 @@ def initModel(self):
return True

def reset(self):
"""
Remove the pos and neg segments and reset the internal state of the logic.
"""
if self.segmentationNode is None:
return

Expand All @@ -112,6 +123,9 @@ def setOffsetRange(self, min: float, max: float):
self.sliceOffsetRange = (min, max)

def predict(self):
"""
Launch a 2D prediction for the current slice and view.
"""
# Get the slice number
import torchvision.transforms.v2 as torchviz

Expand All @@ -134,18 +148,26 @@ def predict(self):
KJIToRAS[:, 0] = IJKToRAS[:, 2]
KJIToRAS[:, 2] = IJKToRAS[:, 0]

resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS)
resultSegment = self.updateSlice(resultSegment, y, k)
resultSegment = self.invertAxisReordering(resultSegment, KJIToRAS)
sliceNodeID = f"vtkMRMLSliceNode{self.workingView}"
sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
axis = self.computeSliceAxis(volumeNode, sliceNode)

resultSegment = self.updateSlice(resultSegment, y, k, axis)

slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode)

def thresholdPrediction(self, prediction: "torch.Tensor", threshold=0.5):
"""
Apply a threshold to the prediction.
"""
prediction[prediction < threshold] = 0
prediction[prediction >= threshold] = 1
return prediction

def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.Size"]:
"""
Make a prediction for a 2D slice without post-processing
"""
# return the raw prediction and the original dimension of the slice (for resizing)
import torch
# Load the context
Expand All @@ -161,12 +183,9 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S
posSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.posSegment)
negSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.negSegment)

# Create the convertion matrix needed to handle slice selection correctly
IJKToRAS = np.zeros((3, 3))
volumeNode.GetIJKToRASDirections(IJKToRAS)
KJIToRAS = IJKToRAS.copy()
KJIToRAS[:, 0] = IJKToRAS[:, 2]
KJIToRAS[:, 2] = IJKToRAS[:, 0]
sliceNodeID = f"vtkMRMLSliceNode{self.workingView}"
sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
axis = self.computeSliceAxis(volumeNode, sliceNode)

# Getting the different arrays
# Array from slicer.util are K-J-I indexed
Expand All @@ -175,17 +194,11 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S
posArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, posSegId, volumeNode)
negArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, negSegId, volumeNode)

# Reorder axis to be R-A-S indexed
imageArray = self.reorderAxisToRAS(imageArray, KJIToRAS)
resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS)
posArray = self.reorderAxisToRAS(posArray, KJIToRAS)
negArray = self.reorderAxisToRAS(negArray, KJIToRAS)

# Extract the slice corresponding to the current view
imageSlice = self.extractSlice(imageArray, sliceNumber)
prevPredSlice = self.extractSlice(resultSegment, sliceNumber)
posSlice = self.extractSlice(posArray, sliceNumber)
negSlice = self.extractSlice(negArray, sliceNumber)
imageSlice = self.extractSlice(imageArray, sliceNumber, axis)
prevPredSlice = self.extractSlice(resultSegment, sliceNumber, axis)
posSlice = self.extractSlice(posArray, sliceNumber, axis)
negSlice = self.extractSlice(negArray, sliceNumber, axis)

# Convertion to tensors
imageTensor = torch.from_numpy(imageSlice)
Expand All @@ -203,7 +216,6 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S

scribbles = torch.cat((posTensor, negTensor), dim=0)

# print("Starting prediction")
y = self.model.predict(imageTensor[None],
scribbles=scribbles[None],
mask_input=prevPredTensor[None],
Expand All @@ -213,6 +225,9 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S
return y, originalDim

def predict3d(self):
"""
Make a 3D prediction
"""

sliceNodeID = f"vtkMRMLSliceNode{self.workingView}"
sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
Expand All @@ -235,12 +250,7 @@ def predict3d(self):
posSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.posSegment)
negSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.negSegment)

# Create the convertion matrix needed to handle slice selection correctly
IJKToRAS = np.zeros((3, 3))
volumeNode.GetIJKToRASDirections(IJKToRAS)
KJIToRAS = IJKToRAS.copy()
KJIToRAS[:, 0] = IJKToRAS[:, 2]
KJIToRAS[:, 2] = IJKToRAS[:, 0]
axis = self.computeSliceAxis(volumeNode, sliceNode)

# Getting the different arrays
# Array from slicer.util are K-J-I indexed
Expand All @@ -249,13 +259,7 @@ def predict3d(self):
posArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, posSegId, volumeNode)
negArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, negSegId, volumeNode)

# Reorder axis to be R-A-S indexed
imageArray = self.reorderAxisToRAS(imageArray, KJIToRAS)
resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS)
posArray = self.reorderAxisToRAS(posArray, KJIToRAS)
negArray = self.reorderAxisToRAS(negArray, KJIToRAS)

imageSlice = self.extractSlice(imageArray, 0)
imageSlice = self.extractSlice(imageArray, 0, axis)
originalDim = imageSlice.shape

import torch
Expand All @@ -267,10 +271,10 @@ def predict3d(self):
negTensor = torch.from_numpy(negArray)

# Pre process
imageTensor = self.preprocessVolume(imageTensor[None])[0]
posTensor = self.preprocessVolume(posTensor[None], isSegmentation=True)[0]
negTensor = self.preprocessVolume(negTensor[None], isSegmentation=True)[0]
prevPredTensor = self.preprocessVolume(prevPredTensor[None], isSegmentation=True)[0]
imageTensor = self.preprocessVolume(imageTensor[None], axis)[0]
posTensor = self.preprocessVolume(posTensor[None], axis, isSegmentation=True)[0]
negTensor = self.preprocessVolume(negTensor[None], axis, isSegmentation=True)[0]
prevPredTensor = self.preprocessVolume(prevPredTensor[None], axis, isSegmentation=True)[0]

progressDialog = slicer.util.createProgressDialog(value=startSlice - 1,
minimum=startSlice - 1,
Expand All @@ -289,10 +293,10 @@ def predict3d(self):
sliceLogic.SetSliceOffset(sliceOffset)

# Extract the slice corresponding to the current view
imageSlice = self.extractSlice(imageTensor, sliceNumber)[None]
prevPredSlice = self.extractSlice(prevPredTensor, sliceNumber)[None]
posSlice = self.extractSlice(posTensor, sliceNumber)[None]
negSlice = self.extractSlice(negTensor, sliceNumber)[None]
imageSlice = self.extractSlice(imageTensor, sliceNumber, axis)[None]
prevPredSlice = self.extractSlice(prevPredTensor, sliceNumber, axis)[None]
posSlice = self.extractSlice(posTensor, sliceNumber, axis)[None]
negSlice = self.extractSlice(negTensor, sliceNumber, axis)[None]

scribbles = torch.cat((posSlice, negSlice), dim=0)

Expand All @@ -305,18 +309,20 @@ def predict3d(self):
y = torchviz.functional.resize(y[0], originalDim)[0]
y = self.thresholdPrediction(y)

resultSegment = self.updateSlice(resultSegment, y, sliceNumber)
resultSegment = self.updateSlice(resultSegment, y, sliceNumber, axis)
progressDialog.setValue(sliceNumber)

if progressDialog.wasCanceled:
progressDialog.close()
break
slicer.app.processEvents()

resultSegment = self.invertAxisReordering(resultSegment, KJIToRAS)
slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode)

def getCurrentSliceIndex(self, sliceColor):
"""
Get the index of the current slice for the view sliceColor based on the offset value.
"""
sliceNodeID = f"vtkMRMLSliceNode{sliceColor}"

sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
Expand All @@ -325,47 +331,55 @@ def getCurrentSliceIndex(self, sliceColor):
sliceOffset = sliceLogic.GetSliceOffset()
return sliceLogic.GetSliceIndexFromOffset(sliceOffset) - 1 # slice is 1-indexed

def reorderAxisToRAS(self, array: np.ndarray, directionMatrix: np.ndarray):
perm_order = np.argmax(np.abs(directionMatrix), axis=0)
return np.transpose(array, axes=perm_order)

def invertAxisReordering(self, permutedArray: np.ndarray, directionMatrix: np.ndarray):
perm_order = np.argmax(np.abs(directionMatrix), axis=0)
inverse_order = np.argsort(perm_order) # Compute the inverse permutation
return np.transpose(permutedArray, axes=inverse_order)

def extractSlice(self, array: np.ndarray, sliceNumber: int, sliceColor=None):
sliceNodeID = f"vtkMRMLSliceNode{self.workingView if sliceColor is None else sliceColor}"
sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)

orientation = sliceNode.GetOrientation()
if orientation == "Axial":
orientationAx = 2
elif orientation == "Sagittal":
orientationAx = 0
elif orientation == "Coronal":
orientationAx = 1
else:
raise ValueError(f"Orientation {orientation} is not supported")

return np.take(array, sliceNumber, axis=orientationAx)
def computeSliceAxis(self, volumeNode: vtkMRMLScalarVolumeNode, sliceNode: vtkMRMLSliceNode):
"""
Given the volume node and the slice node, find the axis of the volume which correspond to the stepping direction in the selected view.
"""
# Get the slice normal vector in RAS
sliceToRAS = sliceNode.GetSliceToRAS()
sliceNormal = np.zeros(4)
vtkAddon.vtkAddonMathUtilities.GetOrientationMatrixColumn(sliceToRAS, 2, sliceNormal)

def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int):
sliceNodeID = f"vtkMRMLSliceNode{self.workingView}"
sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
# Get the KIJ to RAS matrix
IJKToRAS = np.zeros((3, 3))
volumeNode.GetIJKToRASDirections(IJKToRAS)
KJIToRAS = IJKToRAS.copy()
KJIToRAS[:, 0] = IJKToRAS[:, 2]
KJIToRAS[:, 2] = IJKToRAS[:, 0]

orientation = sliceNode.GetOrientation()
if orientation == "Axial":
array[:, :, sliceNumber] = newSlice
elif orientation == "Sagittal":
res = KJIToRAS.T @ sliceNormal[:3]
res = np.abs(res)

if np.allclose(res, [1, 0, 0], atol=0.01):
return 0
if np.allclose(res, [0, 1, 0], atol=0.01):
return 1
if np.allclose(res, [0, 0, 1], atol=0.01):
return 2
raise ValueError(f"View {self.workingView} is not axis aligned with the volume geometry")

def extractSlice(self, array: np.ndarray, sliceNumber: int, axis: int):
"""Extract the slice sliceNumber from the array given an axis"""
return np.take(array, sliceNumber, axis=axis)

def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int, axis: int):
"""
Replace the slice in array by the newSlice. sliceNumber and axis are for positional information.
"""
if axis == 0:
array[sliceNumber] = newSlice
elif orientation == "Coronal":
elif axis == 1:
array[:, sliceNumber] = newSlice
elif axis == 2:
array[:, :, sliceNumber] = newSlice
else:
raise ValueError(f"Orientation {orientation} is not supported")
slicer.util.errorDisplay(f"Error during segmentation update, axis {axis} was given")
return array

def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False):
"""
Preprocess a 2d slice for the model. If isSegmentation, the resulting Tensor in of type bool
"""
# Slice of dim of shape 1*W*H
import torch
import torchvision.transforms.v2 as torchviz
Expand All @@ -384,27 +398,21 @@ def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False):

return result # 1*W*H

def preprocessVolume(self, volume: "torch.Tensor", isSegmentation=False):
def preprocessVolume(self, volume: "torch.Tensor", axis: int, isSegmentation=False):
"""
Apply the preprocessing pipeline on a full volume given an axis. The direction of the axis is not rescaled to allow stepping through each slice.
"""
# volume indexed RAS of shape 1*X*Y*Z
import torch
if isSegmentation:
targetDtype = torch.bool
else:
targetDtype = torch.float16

sliceNodeID = f"vtkMRMLSliceNode{self.workingView}"
sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID)
orientation = sliceNode.GetOrientation()
originalSize = volume.shape

if orientation == "Axial":
targetSize = [128, 128, originalSize[3]]
elif orientation == "Sagittal":
targetSize = [originalSize[1], 128, 128]
elif orientation == "Coronal":
targetSize = [128, originalSize[2], 128]
else:
raise ValueError(f"Orientation {orientation} is not supported")
targetSize = [128, 128, 128]
targetSize[axis] = originalSize[axis + 1]

# Resizing
result = torch.nn.functional.interpolate(volume[None].to(torch.float), targetSize, mode='trilinear').to(
Expand Down
Loading