Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions niCHART/core/model/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DataModel(QObject):
"""This class holds the data model."""

data_changed = QtCore.pyqtSignal()
model_changed = QtCore.pyqtSignal()

def __init__(self):
QObject.__init__(self)
Expand Down Expand Up @@ -90,7 +91,8 @@ def SetData(self,d):
def SetHarmonizationModel(self,m):
"""Setter for neuroHarmonize model"""
self.harmonization_model = m
logger.info('neuroHarmonize model set.')
logger.info('neuroHarmonize model set. Signal emmitted')
self.model_changed.emit()


def SetSPAREModel(self,BrainAgeModel, ADModel):
Expand Down Expand Up @@ -118,16 +120,20 @@ def GetMinAgeOfMUSEHarmonizationModel(self):
return self.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound']


def GetNormativeRange(self,roi):
def GetNormativeRange(self,roi,sex='M'):
"""Return normative range"""
if sex == 'M':
sex_bin = 1
else:
sex_bin = 0

# Constructig the visualization of the normative range based on GAM
# model
covariates = pd.DataFrame(np.linspace(self.GetMinAgeOfMUSEHarmonizationModel(), self.GetMaxAgeOfMUSEHarmonizationModel(), 200), columns=['Age'])
# Fix ICV roughly to population average
covariates['ICV'] = 1450000 # mean ICV
# Fix Sex variable
covariates['Sex'] = 0
covariates['Sex'] = sex_bin
# No need to specify site, but column with name `SITE` must exist
covariates['SITE'] = 'None'
# Set the ROI to be predicted
Expand Down Expand Up @@ -190,6 +196,8 @@ def GetColumnHeaderNames(self):
"""Returns all header names for all columns in the dataset."""
if self.data is not None:
k = self.data.keys()
elif self.harmonization_model is not None:
k = self.harmonization_model['ROIs']
else:
k = []

Expand Down
69 changes: 64 additions & 5 deletions niCHART/plugins/agetrends/agetrends.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
from niCHART.plugins.loadsave.dataio import DataIO
from niCHART.core.plotcanvas import PlotCanvas
from niCHART.core.gui.SearchableQComboBox import SearchableQComboBox

Expand All @@ -33,13 +35,24 @@ def getUI(self):

def SetupConnections(self):
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
self.datamodel.model_changed.connect(lambda: self.OnModelChanged())
self.ui.comboBoxROI.currentIndexChanged.connect(self.UpdatePlot)
self.ui.comboBoxHue.currentIndexChanged.connect(self.UpdatePlot)

def OnDataChanged(self):
self.PopulateROI()
self.PopulateHue()

def OnModelChanged(self):
self.GetMUSEROIDict()
self.PopulateROI()

def GetMUSEROIDict(self):
dio = DataIO()
#also read MUSE dictionary
MUSEDictNAMEtoID, MUSEDictIDtoNAME, MUSEDictDataFrame = dio.ReadMUSEDictionary()
self.datamodel.SetMUSEDictionaries(MUSEDictNAMEtoID, MUSEDictIDtoNAME,MUSEDictDataFrame)

def PopulateROI(self):
#get data column header names
datakeys = self.datamodel.GetColumnHeaderNames()
Expand All @@ -58,7 +71,6 @@ def PopulateROI(self):
if invalid_ROI in roiList:
roiList.remove(invalid_ROI)


_, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
roiList = list(set(roiList).intersection(set(datakeys)))
roiList.sort()
Expand Down Expand Up @@ -117,7 +129,11 @@ def UpdatePlot(self):
plotOptions['HUE'] = currentHue

#Plot data
self.PlotAgeTrends(plotOptions)
if self.datamodel.data is not None:
self.PlotAgeTrends(plotOptions)

if (self.datamodel.harmonization_model is not None) and (self.datamodel.data is None):
self.PlotModelTrends(plotOptions)

def PlotAgeTrends(self,plotOptions):
"""Plot Age Trends"""
Expand All @@ -138,9 +154,8 @@ def PlotAgeTrends(self,plotOptions):
sns.despine(fig=self.plotCanvas.axes.get_figure(), trim=True)
self.plotCanvas.axes.get_figure().set_tight_layout(True)

# Plot normative range if according GAM model is available
if (self.datamodel.harmonization_model is not None) and (currentROI in ['H_' + x for x in self.datamodel.harmonization_model['ROIs']]):
x,y,z = self.datamodel.GetNormativeRange(currentROI[2:])
if (self.datamodel.harmonization_model is not None) and (currentROI in [x for x in self.datamodel.harmonization_model['ROIs']]):
x,y,z = self.datamodel.GetNormativeRange(currentROI)
#print('Pooled variance: %f' % (z))
# Plot three lines as expected mean and +/- 2 times standard deviation
sns.lineplot(x=x, y=y, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='k')
Expand All @@ -167,3 +182,47 @@ def PlotAgeTrends(self,plotOptions):

# refresh canvas
self.plotCanvas.draw()

def PlotModelTrends(self,plotOptions):
"""Plot Age Trends"""
currentROI = plotOptions['ROI']

# clear plot
self.plotCanvas.axes.clear()

if (self.datamodel.harmonization_model is not None) and (currentROI in [x for x in self.datamodel.harmonization_model['ROIs']]):
x,y,z = self.datamodel.GetNormativeRange(currentROI,sex='M')
u,v,w = self.datamodel.GetNormativeRange(currentROI,sex='F')
#print('Pooled variance: %f' % (z))
# Plot three lines as expected mean and +/- 2 times standard deviation
sns.lineplot(x=x, y=y, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='blue')
sns.lineplot(x=x, y=y+z, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='blue')
sns.lineplot(x=x, y=y-z, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='blue')
sns.lineplot(x=u, y=v, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='orange')
sns.lineplot(x=u, y=v+w, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='orange')
sns.lineplot(x=u, y=v-w, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='orange')
custom_lines = [Line2D([0], [0], color='orange', lw=4),
Line2D([0], [0], color='blue', lw=4)]
self.plotCanvas.axes.legend(custom_lines,['Female ','Male'],loc='upper left',title='Sex')


# Set ROI name as y-label if applicable
_, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
ylabel = currentROI
#print(currentROI)
if ylabel.startswith('MUSE_'):
ylabel = '(MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI]))[0]

if ylabel.startswith('WMLS_'):
ylabel = '(WMLS) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('WMLS_', 'MUSE_')]))[0]

if ylabel.startswith('H_MUSE_'):
ylabel = '(Harmonized MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('H_', '')]))[0]

if ylabel.startswith('RES_MUSE_'):
ylabel = '(Residuals MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('RES_', '')]))[0]

self.plotCanvas.axes.set(ylabel=ylabel)

# refresh canvas
self.plotCanvas.draw()
11 changes: 7 additions & 4 deletions niCHART/plugins/harmonizationplugin/harmonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def DoHarmonization(self):

covars = self.datamodel.data[['SITE','Age','Sex','DLICV_baseline']].reset_index(drop=True).copy()
covars.loc[:,'Sex'] = covars['Sex'].map({'M':1,'F':0})
covars.loc[covars.Age>100, 'Age']=100
covars.loc[covars.Age>102, 'Age']=102
covars.loc[covars.Age<20, 'Age']=20

# Parameter table for plotting
gamma_ROIs = ['gamma_'+ x for x in self.datamodel.harmonization_model['ROIs']]
Expand Down Expand Up @@ -73,9 +74,11 @@ def DoHarmonization(self):
model_delta = pd.DataFrame(self.datamodel.harmonization_model['delta_star'],columns=delta_ROIs,index=[x for x in self.datamodel.harmonization_model['SITE_labels']])
parameters = pd.concat([model_gamma,model_delta],axis=1).sort_index()
else:
oos_data = self.datamodel.data[self.datamodel.data['SITE'].isin(sites_to_harmonize)].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values
oos_covars = self.datamodel.data[self.datamodel.data.SITE.isin(sites_to_harmonize)].dropna(subset=covariates)[covariates]
oos_data = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values
oos_covars = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[covariates]
oos_covars.loc[:,'Sex'] = oos_covars['Sex'].map({'M':1,'F':0})
oos_covars.loc[oos_covars.Age>102, 'Age']=102
oos_covars.loc[oos_covars.Age<20, 'Age']=20
self.model, _ = nh.harmonizationLearn(oos_data, oos_covars,
smooth_terms=['Age'],
smooth_term_bounds=(np.floor(np.min(self.datamodel.data.Age)),np.ceil(np.max(self.datamodel.data.Age))),
Expand Down Expand Up @@ -122,7 +125,7 @@ def DoHarmonization(self):
MUSEDictDataFrame= self.datamodel.GetMUSEDictDataFrame()
muse_mappings = self.datamodel.GetDerivedMUSEMap()
for ROI in MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='DERIVED']['ROI_INDEX']:
single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float)
single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float64)
single_ROIs = ['H_MUSE_Volume_%0d' % x for x in single_ROIs]
muse['H_MUSE_Volume_%d' % ROI] = muse[single_ROIs].sum(axis=1,skipna=False)
muse.drop(columns=['H_MUSE_Volume_702'], inplace=True)
Expand Down
14 changes: 7 additions & 7 deletions niCHART/plugins/harmonizationplugin/harmonizationplugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ def getUI(self):

def SetupConnections(self):
self.ui.load_harmonization_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
if self.datamodel.data is None:
self.ui.load_harmonization_model_Btn.setEnabled(False)
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
self.ui.show_data_Btn.clicked.connect(lambda: self.OnShowDataBtnClicked())
self.ui.apply_model_to_dataset_Btn.clicked.connect(lambda: self.OnApplyModelToDatasetBtnClicked())
Expand Down Expand Up @@ -98,13 +95,17 @@ def LoadHarmonizationModel(self, filename):
age_min = self.datamodel.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound']
model_text4 = ('Valid Age Range: [' + str(age_min) + ', ' + str(age_max) + ']')
model_text1 += '\n'+model_text4
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
if self.datamodel.data is None:
self.ui.apply_model_to_dataset_Btn.setEnabled(False)
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection or application.\nReturn to Load and Save Data tab to select data.')
model_text5 = 'Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.'
model_text1 += '\n'+model_text5
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
else:
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
self.ui.apply_model_to_dataset_Btn.setEnabled(True)
self.ui.apply_model_to_dataset_Btn.setStyleSheet("background-color: rgb(230,255,230); color: black")
self.datamodel.SetDataFilePath(filename)
self.datamodel.SetHarmonizationModel(self.datamodel.harmonization_model)
self.ui.stackedWidget.setCurrentIndex(0)

def OnLoadHarmonizationModelBtnClicked(self):
Expand Down Expand Up @@ -339,9 +340,8 @@ def OnDataChanged(self):
self.ui.show_data_Btn.setEnabled(False)

if self.datamodel.data is None:
self.ui.load_harmonization_model_Btn.setEnabled(False)
self.ui.apply_model_to_dataset_Btn.setEnabled(False)
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.')
else:
self.ui.load_harmonization_model_Btn.setEnabled(True)
if self.datamodel.harmonization_model is None:
Expand Down