From 2b00fb0879cc6ba8756c7f5e77959a406e7c4dd8 Mon Sep 17 00:00:00 2001 From: melhemr Date: Wed, 3 May 2023 11:38:44 -0400 Subject: [PATCH] Add capability to plot model trends without data --- niCHART/core/model/datamodel.py | 14 +++- niCHART/plugins/agetrends/agetrends.py | 69 +++++++++++++++++-- .../harmonizationplugin/harmonization.py | 11 +-- .../harmonizationplugin.py | 14 ++-- 4 files changed, 89 insertions(+), 19 deletions(-) diff --git a/niCHART/core/model/datamodel.py b/niCHART/core/model/datamodel.py index 295900e91..ba50b4a32 100644 --- a/niCHART/core/model/datamodel.py +++ b/niCHART/core/model/datamodel.py @@ -21,6 +21,7 @@ class DataModel(QObject): """This class holds the data model.""" data_changed = QtCore.pyqtSignal() + model_changed = QtCore.pyqtSignal() def __init__(self): QObject.__init__(self) @@ -90,7 +91,8 @@ def SetData(self,d): def SetHarmonizationModel(self,m): """Setter for neuroHarmonize model""" self.harmonization_model = m - logger.info('neuroHarmonize model set.') + logger.info('neuroHarmonize model set. Signal emmitted') + self.model_changed.emit() def SetSPAREModel(self,BrainAgeModel, ADModel): @@ -118,8 +120,12 @@ def GetMinAgeOfMUSEHarmonizationModel(self): return self.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound'] - def GetNormativeRange(self,roi): + def GetNormativeRange(self,roi,sex='M'): """Return normative range""" + if sex == 'M': + sex_bin = 1 + else: + sex_bin = 0 # Constructig the visualization of the normative range based on GAM # model @@ -127,7 +133,7 @@ def GetNormativeRange(self,roi): # Fix ICV roughly to population average covariates['ICV'] = 1450000 # mean ICV # Fix Sex variable - covariates['Sex'] = 0 + covariates['Sex'] = sex_bin # No need to specify site, but column with name `SITE` must exist covariates['SITE'] = 'None' # Set the ROI to be predicted @@ -190,6 +196,8 @@ def GetColumnHeaderNames(self): """Returns all header names for all columns in the dataset.""" if self.data is not None: k = self.data.keys() + elif self.harmonization_model is not None: + k = self.harmonization_model['ROIs'] else: k = [] diff --git a/niCHART/plugins/agetrends/agetrends.py b/niCHART/plugins/agetrends/agetrends.py index a41c0194c..75993f779 100644 --- a/niCHART/plugins/agetrends/agetrends.py +++ b/niCHART/plugins/agetrends/agetrends.py @@ -6,8 +6,10 @@ import seaborn as sns import matplotlib.pyplot as plt +from matplotlib.lines import Line2D import numpy as np import pandas as pd +from niCHART.plugins.loadsave.dataio import DataIO from niCHART.core.plotcanvas import PlotCanvas from niCHART.core.gui.SearchableQComboBox import SearchableQComboBox @@ -33,6 +35,7 @@ def getUI(self): def SetupConnections(self): self.datamodel.data_changed.connect(lambda: self.OnDataChanged()) + self.datamodel.model_changed.connect(lambda: self.OnModelChanged()) self.ui.comboBoxROI.currentIndexChanged.connect(self.UpdatePlot) self.ui.comboBoxHue.currentIndexChanged.connect(self.UpdatePlot) @@ -40,6 +43,16 @@ def OnDataChanged(self): self.PopulateROI() self.PopulateHue() + def OnModelChanged(self): + self.GetMUSEROIDict() + self.PopulateROI() + + def GetMUSEROIDict(self): + dio = DataIO() + #also read MUSE dictionary + MUSEDictNAMEtoID, MUSEDictIDtoNAME, MUSEDictDataFrame = dio.ReadMUSEDictionary() + self.datamodel.SetMUSEDictionaries(MUSEDictNAMEtoID, MUSEDictIDtoNAME,MUSEDictDataFrame) + def PopulateROI(self): #get data column header names datakeys = self.datamodel.GetColumnHeaderNames() @@ -58,7 +71,6 @@ def PopulateROI(self): if invalid_ROI in roiList: roiList.remove(invalid_ROI) - _, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries() roiList = list(set(roiList).intersection(set(datakeys))) roiList.sort() @@ -117,7 +129,11 @@ def UpdatePlot(self): plotOptions['HUE'] = currentHue #Plot data - self.PlotAgeTrends(plotOptions) + if self.datamodel.data is not None: + self.PlotAgeTrends(plotOptions) + + if (self.datamodel.harmonization_model is not None) and (self.datamodel.data is None): + self.PlotModelTrends(plotOptions) def PlotAgeTrends(self,plotOptions): """Plot Age Trends""" @@ -138,9 +154,8 @@ def PlotAgeTrends(self,plotOptions): sns.despine(fig=self.plotCanvas.axes.get_figure(), trim=True) self.plotCanvas.axes.get_figure().set_tight_layout(True) - # Plot normative range if according GAM model is available - if (self.datamodel.harmonization_model is not None) and (currentROI in ['H_' + x for x in self.datamodel.harmonization_model['ROIs']]): - x,y,z = self.datamodel.GetNormativeRange(currentROI[2:]) + if (self.datamodel.harmonization_model is not None) and (currentROI in [x for x in self.datamodel.harmonization_model['ROIs']]): + x,y,z = self.datamodel.GetNormativeRange(currentROI) #print('Pooled variance: %f' % (z)) # Plot three lines as expected mean and +/- 2 times standard deviation sns.lineplot(x=x, y=y, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='k') @@ -167,3 +182,47 @@ def PlotAgeTrends(self,plotOptions): # refresh canvas self.plotCanvas.draw() + + def PlotModelTrends(self,plotOptions): + """Plot Age Trends""" + currentROI = plotOptions['ROI'] + + # clear plot + self.plotCanvas.axes.clear() + + if (self.datamodel.harmonization_model is not None) and (currentROI in [x for x in self.datamodel.harmonization_model['ROIs']]): + x,y,z = self.datamodel.GetNormativeRange(currentROI,sex='M') + u,v,w = self.datamodel.GetNormativeRange(currentROI,sex='F') + #print('Pooled variance: %f' % (z)) + # Plot three lines as expected mean and +/- 2 times standard deviation + sns.lineplot(x=x, y=y, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='blue') + sns.lineplot(x=x, y=y+z, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='blue') + sns.lineplot(x=x, y=y-z, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='blue') + sns.lineplot(x=u, y=v, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='orange') + sns.lineplot(x=u, y=v+w, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='orange') + sns.lineplot(x=u, y=v-w, ax=self.plotCanvas.axes, linestyle=':', markers=False, color='orange') + custom_lines = [Line2D([0], [0], color='orange', lw=4), + Line2D([0], [0], color='blue', lw=4)] + self.plotCanvas.axes.legend(custom_lines,['Female ','Male'],loc='upper left',title='Sex') + + + # Set ROI name as y-label if applicable + _, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries() + ylabel = currentROI + #print(currentROI) + if ylabel.startswith('MUSE_'): + ylabel = '(MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI]))[0] + + if ylabel.startswith('WMLS_'): + ylabel = '(WMLS) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('WMLS_', 'MUSE_')]))[0] + + if ylabel.startswith('H_MUSE_'): + ylabel = '(Harmonized MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('H_', '')]))[0] + + if ylabel.startswith('RES_MUSE_'): + ylabel = '(Residuals MUSE) ' + list(map(MUSEDictIDtoNAME.get, [currentROI.replace('RES_', '')]))[0] + + self.plotCanvas.axes.set(ylabel=ylabel) + + # refresh canvas + self.plotCanvas.draw() \ No newline at end of file diff --git a/niCHART/plugins/harmonizationplugin/harmonization.py b/niCHART/plugins/harmonizationplugin/harmonization.py index f1d42926a..4df8bcb47 100644 --- a/niCHART/plugins/harmonizationplugin/harmonization.py +++ b/niCHART/plugins/harmonizationplugin/harmonization.py @@ -35,7 +35,8 @@ def DoHarmonization(self): covars = self.datamodel.data[['SITE','Age','Sex','DLICV_baseline']].reset_index(drop=True).copy() covars.loc[:,'Sex'] = covars['Sex'].map({'M':1,'F':0}) - covars.loc[covars.Age>100, 'Age']=100 + covars.loc[covars.Age>102, 'Age']=102 + covars.loc[covars.Age<20, 'Age']=20 # Parameter table for plotting gamma_ROIs = ['gamma_'+ x for x in self.datamodel.harmonization_model['ROIs']] @@ -73,9 +74,11 @@ def DoHarmonization(self): model_delta = pd.DataFrame(self.datamodel.harmonization_model['delta_star'],columns=delta_ROIs,index=[x for x in self.datamodel.harmonization_model['SITE_labels']]) parameters = pd.concat([model_gamma,model_delta],axis=1).sort_index() else: - oos_data = self.datamodel.data[self.datamodel.data['SITE'].isin(sites_to_harmonize)].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values - oos_covars = self.datamodel.data[self.datamodel.data.SITE.isin(sites_to_harmonize)].dropna(subset=covariates)[covariates] + oos_data = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values + oos_covars = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[covariates] oos_covars.loc[:,'Sex'] = oos_covars['Sex'].map({'M':1,'F':0}) + oos_covars.loc[oos_covars.Age>102, 'Age']=102 + oos_covars.loc[oos_covars.Age<20, 'Age']=20 self.model, _ = nh.harmonizationLearn(oos_data, oos_covars, smooth_terms=['Age'], smooth_term_bounds=(np.floor(np.min(self.datamodel.data.Age)),np.ceil(np.max(self.datamodel.data.Age))), @@ -122,7 +125,7 @@ def DoHarmonization(self): MUSEDictDataFrame= self.datamodel.GetMUSEDictDataFrame() muse_mappings = self.datamodel.GetDerivedMUSEMap() for ROI in MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='DERIVED']['ROI_INDEX']: - single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float) + single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float64) single_ROIs = ['H_MUSE_Volume_%0d' % x for x in single_ROIs] muse['H_MUSE_Volume_%d' % ROI] = muse[single_ROIs].sum(axis=1,skipna=False) muse.drop(columns=['H_MUSE_Volume_702'], inplace=True) diff --git a/niCHART/plugins/harmonizationplugin/harmonizationplugin.py b/niCHART/plugins/harmonizationplugin/harmonizationplugin.py index 005664d41..9b1d7df3e 100644 --- a/niCHART/plugins/harmonizationplugin/harmonizationplugin.py +++ b/niCHART/plugins/harmonizationplugin/harmonizationplugin.py @@ -42,9 +42,6 @@ def getUI(self): def SetupConnections(self): self.ui.load_harmonization_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked()) - if self.datamodel.data is None: - self.ui.load_harmonization_model_Btn.setEnabled(False) - self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.') self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked()) self.ui.show_data_Btn.clicked.connect(lambda: self.OnShowDataBtnClicked()) self.ui.apply_model_to_dataset_Btn.clicked.connect(lambda: self.OnApplyModelToDatasetBtnClicked()) @@ -98,13 +95,17 @@ def LoadHarmonizationModel(self, filename): age_min = self.datamodel.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound'] model_text4 = ('Valid Age Range: [' + str(age_min) + ', ' + str(age_max) + ']') model_text1 += '\n'+model_text4 - self.ui.Harmonized_Data_Information_Lbl.setText(model_text1) if self.datamodel.data is None: self.ui.apply_model_to_dataset_Btn.setEnabled(False) - self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection or application.\nReturn to Load and Save Data tab to select data.') + model_text5 = 'Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.' + model_text1 += '\n'+model_text5 + self.ui.Harmonized_Data_Information_Lbl.setText(model_text1) else: + self.ui.Harmonized_Data_Information_Lbl.setText(model_text1) self.ui.apply_model_to_dataset_Btn.setEnabled(True) self.ui.apply_model_to_dataset_Btn.setStyleSheet("background-color: rgb(230,255,230); color: black") + self.datamodel.SetDataFilePath(filename) + self.datamodel.SetHarmonizationModel(self.datamodel.harmonization_model) self.ui.stackedWidget.setCurrentIndex(0) def OnLoadHarmonizationModelBtnClicked(self): @@ -339,9 +340,8 @@ def OnDataChanged(self): self.ui.show_data_Btn.setEnabled(False) if self.datamodel.data is None: - self.ui.load_harmonization_model_Btn.setEnabled(False) self.ui.apply_model_to_dataset_Btn.setEnabled(False) - self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.') + self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.') else: self.ui.load_harmonization_model_Btn.setEnabled(True) if self.datamodel.harmonization_model is None: