Skip to content
Open

Test #30

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ Training/DeepSurv*.ckpt
*.txt
wandb/
Training/*.xml
Configs/Classification/*.ini
Configs/Regression/*.ini
*.png
Training/test*.py
Empty file.
177 changes: 56 additions & 121 deletions DataGenerator/DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@


class DataGenerator(torch.utils.data.Dataset):
def __init__(self, SubjectList,config=None, keys=['CT'], transform=None, inference=False,
def __init__(self, SubjectList, config=None, keys=['CT'], transform=None, inference=False,
clinical_cols=None, session=None, **kwargs):
super().__init__()
self.config = config
self.session = session
self.SubjectList = SubjectList
self.keys = keys
self.transform = transform
Expand All @@ -44,133 +43,83 @@ def __getitem__(self, i):
meta = {}
subject_id = self.SubjectList.loc[i, 'subjectid']
slabel = self.SubjectList.loc[i, 'subject_label']
data['slabel'] = slabel
## Load CT
if 'CT' in self.keys:
CTPath = self.SubjectList.loc[i, 'CT_Path']
if self.config['DATA']['Nifty']:
CTPath = Path(CTPath, 'ct.nii.gz')
data['CT'], meta['CT'] = LoadImage()(CTPath)
else:
data['CT'], meta['CT'] = LoadImage()(CTPath)
CTSession = ReadDicom(CTPath)
CTArray = sitk.GetArrayFromImage(CTSession)
if not(CTArray.shape == data['CT'].shape):
CTArray = CTArray.transpose((2, 1, 0))
CTArray = np.flip(CTArray, axis=2)
mCT = MetaTensor(CTArray.copy(), meta=meta['CT'])
data['CT'] = mCT
CTPath = Path(CTPath, 'CT.nii.gz')
data['CT'], meta['CT'] = LoadImage()(CTPath)

## Load Dose
if 'Dose' in self.keys:
DosePath = self.SubjectList.loc[i, 'Dose_Path']
if self.config['DATA']['Nifty']:
DosePath = Path(DosePath, 'dose.nii.gz')
DosePath = Path(DosePath, 'Dose.nii.gz')
data['Dose'], meta['Dose'] = LoadImage()(DosePath)
data['Dose'] = data['Dose']/67
if not self.config['DATA']['Nifty']:
data['Dose'] = data['Dose'] * np.double(meta['Dose']['3004|000e'])/67
data['Dose'] = data['Dose'] / 67 ## Probably need to make it a variable

## Load PET
if 'PET' in self.keys:
PETPath = self.SubjectList.loc[i, 'PET_Path']
if self.config['DATA']['Nifty']:
PETPath = Path(PETPath, 'dose.nii.gz')
PETPath = Path(PETPath, 'pet.nii.gz')
data['PET'], meta['PET'] = LoadImage()(PETPath)

## Load Mask
if 'Structs' in self.keys:
RSPath = self.SubjectList.loc[i, 'Structs_Path']
if self.config['DATA']['Nifty']:
#for roi in self.config['DATA']['Structs']:
# data['Struct_' + roi], meta['Struct_' + roi] = LoadImage()(Path(RSPath,roi+'.nii.gz'))
# dt = distance_transform_edt(data['Struct_' + roi])
# data['Struct_' + roi] = MetaTensor(dt, meta = meta['CT'])
masks_img = np.zeros_like(data['CT'])
masks_img = get_nii_masks(slabel, masks_img, RSPath, self.config['DATA']['Structs'])
masks_img = MetaTensor(masks_img.copy(), meta=meta['CT'])
data['Structs'] = masks_img
else:
## mask in multichannel
RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSPath)
#roi_names = RS.get_roi_names()
#for roi in self.config['DATA']['Structs']:
# if roi in roi_names:
# mask_img = RS.get_roi_mask_by_name(roi)
# mask_img = distance_transform_edt(mask_img)
# else:
# message = "No ROI of name " + self.targetROI + " found in RTStruct"
# raise ValueError(message)
# mask_img = np.rot90(mask_img)
# mask_img = np.flip(mask_img, 2)
# mask_img = np.flip(mask_img, 0)
# mask = MetaTensor(mask_img.copy(), meta = meta['CT'])
# data['Struct_' + roi] = mask

### masks images
masks_img = np.zeros_like(data['CT'])
masks_img = get_RS_masks(slabel, CTPath, masks_img, RSPath, self.config['DATA']['Structs'])
masks_img = np.rot90(masks_img)
masks_img = np.flip(masks_img, 0)
masks_img = MetaTensor(masks_img.copy(), meta = meta['CT'])
data['Structs'] = masks_img
else:
data['Structs'] = np.ones_like(data['CT']) ## No ROI target defined
data['Structs'], meta['Structs'] = LoadImage()(Path(RSPath, self.config['DATA']['Structs']))

## Apply transforms on all
if self.transform: data = self.transform(data)

# mask_imgs = np.zeros_like(CTArray)
# for key in data.keys():
# if 'Mask' in key:
# mask_imgs = mask_imgs + data[key]

#for key in data.keys():
# data[key] = get_masked_img_voxel(data[key], data['Mask'])
# Decide between multi-branch single-channel/multi-channel single-branch
if self.config['DATA']['Multichannel']:
old_keys = list(data.keys())
data['Image'] = np.concatenate([data[key] for key in data.keys()], axis=0)
old_keys = list(self.keys)
data['Image'] = np.concatenate([data[key] for key in old_keys], axis=0)
for key in old_keys: data.pop(key)
else:
data.pop('Structs') ## No need for mask in single-channel multi-branch

#data = ResizeWithPadOrCropd(keys=data.keys(), spatial_size=self.config['DATA']['dim'])(data)
if 'Structs' in data.keys():
data.pop('Structs') ## No need for mask in single-channel multi-branch

## Add clinical record at the end
if 'Records' in self.config.keys(): data['Records'] = torch.tensor(self.SubjectList.loc[i, self.clinical_cols],
dtype=torch.float32)
dtype=torch.float32)
if self.inference:
return data
else:
label = torch.tensor(np.float(self.SubjectList.loc[i, "xnat_subjectdata_field_map_" + self.config['DATA']['target']]))
if self.config['DATA']['threshold'] is not None: label = torch.where(
label > self.config['DATA']['threshold'], 1, 0)
label = torch.as_tensor(label, dtype=torch.float32)
return data, label
else: ##Training
label = torch.tensor(
np.float(self.SubjectList.loc[i, "xnat_subjectdata_field_map_" + self.config['DATA']['target']]))
censor_status = not (np.int8(
self.SubjectList.loc[i, 'xnat_subjectdata_field_map_' + self.config['DATA']['censor_label']]).astype(
'bool'))
if 'threshold' in self.config['DATA'].keys(): ## Classification
label = torch.where(label > self.config['DATA']['threshold'], 1, 0)
label = torch.as_tensor(label, dtype=torch.float32)
return data, censor_status, label


### DataLoader
class DataModule(LightningDataModule):
def __init__(self, SubjectList, config=None, train_transform=None, val_transform=None, train_size=0.7,
val_size=0.2, test_size=0.1, num_workers=10, **kwargs):
def __init__(self, SubjectList, config=None, train_transform=None, val_transform=None, train_size=0.85,
num_workers=10, **kwargs):
super().__init__()
self.batch_size = config['MODEL']['batch_size']
self.num_workers = num_workers
data_trans = class_stratify(SubjectList, config)
## Split Test with fixed seed
train_val_list, test_list = train_test_split(SubjectList, test_size=0.15, random_state=42, stratify=data_trans)
train_val_list, test_list = train_test_split(SubjectList, train_size=train_size, random_state=42,
stratify=data_trans)

data_trans = class_stratify(train_val_list, config)
## Split train-val with random seed
train_list, val_list = train_test_split(train_val_list, test_size=0.15, random_state=np.random.randint(10000),
train_list, val_list = train_test_split(train_val_list, train_size=train_size,
random_state=np.random.randint(10000),
stratify=data_trans)

train_list = train_list.reset_index(drop=True)
val_list = val_list.reset_index(drop=True)
test_list = test_list.reset_index(drop=True)

self.train_data = DataGenerator(train_list, config=config, transform=train_transform, **kwargs)
self.val_data = DataGenerator(val_list,config=config, transform=val_transform, **kwargs)
self.val_data = DataGenerator(val_list, config=config, transform=val_transform, **kwargs)
self.test_data = DataGenerator(test_list, config=config, transform=val_transform, **kwargs)

def train_dataloader(self): return DataLoader(self.train_data, batch_size=self.batch_size,
Expand All @@ -193,6 +142,11 @@ def QuerySubjectList(config, session):
XML.Add_search_field(
{"element_name": "xnat:subjectData", "field_ID": "XNAT_SUBJECTDATA_FIELD_MAP=" + str(config['DATA']['target']),
"sequence": "1", "type": "int"})
if 'censor_label' in config['DATA'].keys():
XML.Add_search_field(
{"element_name": "xnat:subjectData",
"field_ID": "XNAT_SUBJECTDATA_FIELD_MAP=" + str(config['DATA']['censor_label']),
"sequence": "1", "type": "int"})
## Label
XML.Add_search_field(
{"element_name": "xnat:subjectData", "field_ID": "SUBJECT_LABEL", "sequence": "1", "type": "string"})
Expand Down Expand Up @@ -237,6 +191,14 @@ def QuerySubjectList(config, session):
return SubjectList


def class_stratify(SubjectList, config):
    """Return ordinal bin labels of the continuous target, for stratified splitting.

    The target column ('xnat_subjectdata_field_map_' + target) is discretized
    into 15 uniform-width bins so that `train_test_split(..., stratify=...)`
    can balance a continuous outcome across splits.
    """
    target_col = 'xnat_subjectdata_field_map_' + config['DATA']['target']
    values = np.array(SubjectList[target_col]).reshape((len(SubjectList[target_col]), 1))
    binner = KBinsDiscretizer(n_bins=15, encode='ordinal', strategy='uniform')
    return binner.fit_transform(values)


def SynchronizeData(config, SubjectList):
session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'],
password=config['SERVER']['Password'])
Expand All @@ -253,46 +215,15 @@ def get_subject_info(config, session, subjectid):
return data


def QuerySubjectInfo(config, SubjectList, session):
if config['DATA']['Nifty']:
for i in range(len(SubjectList)):
subject_label = SubjectList.loc[i,'subject_label']
for key in config['MODALITY'].keys():
def QuerySubjectInfo(config, SubjectList):
    """Populate '<Modality>_Path' columns in SubjectList with local data folders.

    For every subject row, each modality listed in config['MODALITY'] gets a
    path under config['DATA']['DataFolder']/<subject_label>; structure sets
    ('Structs') additionally live in a 'struct_TS' subfolder.
    """
    data_root = config['DATA']['DataFolder']
    for idx in range(len(SubjectList)):
        label = SubjectList.loc[idx, 'subject_label']
        for modality in config['MODALITY'].keys():
            if modality == 'Structs':
                # Structure sets are stored in their own subfolder.
                SubjectList.loc[idx, modality + '_Path'] = Path(data_root, label, 'struct_TS')
            else:
                SubjectList.loc[idx, modality + '_Path'] = Path(data_root, label)
else:
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_url = {executor.submit(get_subject_info, config, session, subjectid) for subjectid in
SubjectList['subjectid']}
executor.shutdown(wait=True)
for future in concurrent.futures.as_completed(future_to_url):
subjectdata = future.result()
subjectid = subjectdata["xnat:Subject"][0]["@ID"]
for key in config['MODALITY'].keys():
path = GeneratePath(subjectdata, Modality=key, config=config)
if key == 'CT':
SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = path
else:
spath = glob.glob(path + '/*dcm')
SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = spath[0]

def GeneratePath(subjectdata, Modality, config):
    """Build the local path to a subject's scan files for the given modality.

    Walks the XNAT subject record (REST API JSON) looking for scans whose
    '@type' matches *Modality* and assembles
    <DataFolder>/<subject>/<experiment>/scans/<id-type>/resources/<label>/files.
    If several scans match, the last match wins (original behavior preserved).

    Raises:
        ValueError: if no scan of the requested modality is found (previously
            this fell through to a confusing UnboundLocalError on `path`).
    """
    subject = subjectdata['xnat:Subject'][0]
    subject_label = subject['@label']
    experiments = subject['xnat:experiments'][0]['xnat:experiment']

    path = None
    ## Won't work with many experiments yet
    for experiment in experiments:
        experiment_label = experiment['@label']
        scans = experiment['xnat:scans'][0]['xnat:scan']
        for scan in scans:
            if (scan['@type'] in Modality):
                scan_label = scan['@ID'] + '-' + scan['@type']
                resources_label = scan['xnat:file'][0]['@label']
                # Skip the auto-generated snapshot resource; use the next one.
                if resources_label == 'SNAPSHOTS':
                    resources_label = scan['xnat:file'][1]['@label']
                path = os.path.join(config['DATA']['DataFolder'], subject_label, experiment_label, 'scans',
                                    scan_label, 'resources', resources_label, 'files')
    if path is None:
        raise ValueError("No scan of modality '%s' found for subject '%s'" % (Modality, subject_label))
    return path


def LoadClinicalData(config, PatientList):
category_cols = []
Expand All @@ -313,7 +244,7 @@ def LoadClinicalData(config, PatientList):
yc = X[category_cols].astype('float32')
X[category_cols] = yc.fillna(yc.mean().astype('int'))
yn = X[numerical_cols].astype('float32')
X[numerical_cols] = yn.fillna(yn.mean()) #X.loc[:, numerical_cols] = yn.fillna(yn.mean())
X[numerical_cols] = yn.fillna(yn.mean()) # X.loc[:, numerical_cols] = yn.fillna(yn.mean())
X_trans = ct.fit_transform(X)
if not isinstance(X_trans, (np.ndarray, np.generic)): X_trans = X_trans.toarray()

Expand All @@ -322,4 +253,8 @@ def LoadClinicalData(config, PatientList):
df_trans['xnat_subjectdata_field_map_' + target] = PatientList.loc[:, 'xnat_subjectdata_field_map_' + target]
df_trans['subject_label'] = PatientList.loc[:, 'subject_label']
df_trans['subjectid'] = PatientList.loc[:, 'subjectid']
if 'censor_label' in config['DATA'].keys():
df_trans['xnat_subjectdata_field_map_' + config['DATA']['censor_label']] = PatientList.loc[:,
'xnat_subjectdata_field_map_' +
config['DATA']['censor_label']]
return df_trans, clinical_col
65 changes: 0 additions & 65 deletions DefaultConfiguration.ini

This file was deleted.

Loading