diff --git a/.gitignore b/.gitignore index b889e93..ea45e5d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.ipynb* training/.flag* prediction*.csv +.idea diff --git a/config_submit.py b/config_submit.py index a8d8012..bdafa8d 100644 --- a/config_submit.py +++ b/config_submit.py @@ -1,13 +1,13 @@ -config = {'datapath':'/work/DataBowl3/stage2/stage2/', - 'preprocess_result_path':'./prep_result/', - 'outputfile':'prediction.csv', - - 'detector_model':'net_detector', - 'detector_param':'./model/detector.ckpt', - 'classifier_model':'net_classifier', - 'classifier_param':'./model/classifier.ckpt', - 'n_gpu':8, - 'n_worker_preprocessing':None, - 'use_exsiting_preprocessing':False, - 'skip_preprocessing':False, - 'skip_detect':False} +config = {'datapath': '/work/DataBowl3/stage2/stage2/', + 'preprocess_result_path': './prep_result/', + 'outputfile': 'prediction.csv', + + 'detector_model': 'net_detector', + 'detector_param': './model/detector.ckpt', + 'classifier_model': 'net_classifier', + 'classifier_param': './model/classifier.ckpt', + 'n_gpu': 8, + 'n_worker_preprocessing': None, + 'use_exsiting_preprocessing': False, + 'skip_preprocessing': False, + 'skip_detect': False} diff --git a/data_classifier.py b/data_classifier.py index 58b0e73..83e55e4 100644 --- a/data_classifier.py +++ b/data_classifier.py @@ -1,51 +1,49 @@ -import numpy as np -import torch -from torch.utils.data import Dataset import os import time -import collections -import random -from layers import iou -from scipy.ndimage import zoom import warnings + +import numpy as np +import torch +from scipy.ndimage import zoom from scipy.ndimage.interpolation import rotate -from layers import nms,iou -import pandas +from torch.utils.data import Dataset + +from layers import nms, iou + class DataBowl3Classifier(Dataset): - def __init__(self, split, config, phase = 'train'): - assert(phase == 'train' or phase == 'val' or phase == 'test') - + def __init__(self, split, config, phase='train'): + assert (phase == 'train' or phase == 'val' or phase == 'test') + self.random_sample = config['random_sample'] self.T = config['T'] self.topk = config['topk'] self.crop_size = config['crop_size'] self.stride = config['stride'] - self.augtype = config['augtype'] + self.augtype = config['augtype'] self.filling_value = config['filling_value'] - - #self.labels = np.array(pandas.read_csv(config['labelfile'])) - + + # self.labels = np.array(pandas.read_csv(config['labelfile'])) + datadir = config['datadir'] - bboxpath = config['bboxpath'] + bboxpath = config['bboxpath'] self.phase = phase self.candidate_box = [] self.pbb_label = [] - + idcs = split self.filenames = [os.path.join(datadir, '%s_clean.npy' % idx.split('-')[0]) for idx in idcs] - if self.phase!='test': - self.yset = 1-np.array([f.split('-')[1][2] for f in idcs]).astype('int') - - + if self.phase != 'test': + self.yset = 1 - np.array([f.split('-')[1][2] for f in idcs]).astype('int') + for idx in idcs: - pbb = np.load(os.path.join(bboxpath,idx+'_pbb.npy')) - pbb = pbb[pbb[:,0]>config['conf_th']] + pbb = np.load(os.path.join(bboxpath, idx + '_pbb.npy')) + pbb = pbb[pbb[:, 0] > config['conf_th']] pbb = nms(pbb, config['nms_th']) - - lbb = np.load(os.path.join(bboxpath,idx+'_lbb.npy')) + + lbb = np.load(os.path.join(bboxpath, idx + '_lbb.npy')) pbb_label = [] - + for p in pbb: isnod = False for l in lbb: @@ -54,166 +52,171 @@ def __init__(self, split, config, phase = 'train'): isnod = True break pbb_label.append(isnod) -# if idx.startswith() + # if idx.startswith() self.candidate_box.append(pbb) self.pbb_label.append(np.array(pbb_label)) - self.crop = simpleCrop(config,phase) - + self.crop = simpleCrop(config, phase) - def __getitem__(self, idx,split=None): + def __getitem__(self, idx, split=None): t = time.time() - np.random.seed(int(str(t%1)[2:7]))#seed according to time + np.random.seed(int(str(t % 1)[2:7])) # seed according to time pbb = self.candidate_box[idx] pbb_label = self.pbb_label[idx] - conf_list = pbb[:,0] + conf_list = pbb[:, 0] T = self.T topk = self.topk img = np.load(self.filenames[idx]) - if self.random_sample and self.phase=='train': - chosenid = sample(conf_list,topk,T=T) - #chosenid = conf_list.argsort()[::-1][:topk] + if self.random_sample and self.phase == 'train': + chosenid = sample(conf_list, topk, T=T) + # chosenid = conf_list.argsort()[::-1][:topk] else: chosenid = conf_list.argsort()[::-1][:topk] - croplist = np.zeros([topk,1,self.crop_size[0],self.crop_size[1],self.crop_size[2]]).astype('float32') - coordlist = np.zeros([topk,3,self.crop_size[0]/self.stride,self.crop_size[1]/self.stride,self.crop_size[2]/self.stride]).astype('float32') - padmask = np.concatenate([np.ones(len(chosenid)),np.zeros(self.topk-len(chosenid))]) + croplist = np.zeros([topk, 1, self.crop_size[0], self.crop_size[1], self.crop_size[2]]).astype('float32') + coordlist = np.zeros([topk, 3, self.crop_size[0] / self.stride, self.crop_size[1] / self.stride, + self.crop_size[2] / self.stride]).astype('float32') + padmask = np.concatenate([np.ones(len(chosenid)), np.zeros(self.topk - len(chosenid))]) isnodlist = np.zeros([topk]) - - for i,id in enumerate(chosenid): - target = pbb[id,1:] + for i, id in enumerate(chosenid): + target = pbb[id, 1:] isnod = pbb_label[id] - crop,coord = self.crop(img,target) - if self.phase=='train': - crop,coord = augment(crop,coord, - ifflip=self.augtype['flip'],ifrotate=self.augtype['rotate'], - ifswap = self.augtype['swap'],filling_value = self.filling_value) + crop, coord = self.crop(img, target) + if self.phase == 'train': + crop, coord = augment(crop, coord, + ifflip=self.augtype['flip'], ifrotate=self.augtype['rotate'], + ifswap=self.augtype['swap'], filling_value=self.filling_value) crop = crop.astype(np.float32) croplist[i] = crop coordlist[i] = coord isnodlist[i] = isnod - - if self.phase!='test': + + if self.phase != 'test': y = np.array([self.yset[idx]]) - return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy(isnodlist).int(), torch.from_numpy(y) + return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy( + isnodlist).int(), torch.from_numpy(y) else: return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float() + def __len__(self): if self.phase != 'test': return len(self.candidate_box) else: return len(self.candidate_box) - - + class simpleCrop(): - def __init__(self,config,phase): + def __init__(self, config, phase): self.crop_size = config['crop_size'] self.scaleLim = config['scaleLim'] self.radiusLim = config['radiusLim'] self.jitter_range = config['jitter_range'] - self.isScale = config['augtype']['scale'] and phase=='train' + self.isScale = config['augtype']['scale'] and phase == 'train' self.stride = config['stride'] self.filling_value = config['filling_value'] self.phase = phase - - def __call__(self,imgs,target): + + def __call__(self, imgs, target): if self.isScale: radiusLim = self.radiusLim scaleLim = self.scaleLim - scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1]) - ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])] - scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0] - crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int') + scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]) + , np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])] + scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0] + crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int') else: crop_size = np.array(self.crop_size).astype('int') - if self.phase=='train': - jitter_range = target[3]*self.jitter_range - jitter = (np.random.rand(3)-0.5)*jitter_range + if self.phase == 'train': + jitter_range = target[3] * self.jitter_range + jitter = (np.random.rand(3) - 0.5) * jitter_range else: jitter = 0 - start = (target[:3]- crop_size/2 + jitter).astype('int') - pad = [[0,0]] + start = (target[:3] - crop_size / 2 + jitter).astype('int') + pad = [[0, 0]] for i in range(3): - if start[i]<0: + if start[i] < 0: leftpad = -start[i] start[i] = 0 else: leftpad = 0 - if start[i]+crop_size[i]>imgs.shape[i+1]: - rightpad = start[i]+crop_size[i]-imgs.shape[i+1] + if start[i] + crop_size[i] > imgs.shape[i + 1]: + rightpad = start[i] + crop_size[i] - imgs.shape[i + 1] else: rightpad = 0 - pad.append([leftpad,rightpad]) - imgs = np.pad(imgs,pad,'constant',constant_values =self.filling_value) - crop = imgs[:,start[0]:start[0]+crop_size[0],start[1]:start[1]+crop_size[1],start[2]:start[2]+crop_size[2]] - - normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5 - normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:]) - xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride), - np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride), - np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + pad.append([leftpad, rightpad]) + imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value) + crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1], + start[2]:start[2] + crop_size[2]] + + normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 + normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) + xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride), + np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride), + np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride), + indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') if self.isScale: with warnings.catch_warnings(): warnings.simplefilter("ignore") - crop = zoom(crop,[1,scale,scale,scale],order=1) - newpad = self.crop_size[0]-crop.shape[1:][0] - if newpad<0: - crop = crop[:,:-newpad,:-newpad,:-newpad] - elif newpad>0: - pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]] - crop = np.pad(crop,pad2,'constant',constant_values =self.filling_value) - - return crop,coord - -def sample(conf,N,T=1): - if len(conf)>N: + crop = zoom(crop, [1, scale, scale, scale], order=1) + newpad = self.crop_size[0] - crop.shape[1:][0] + if newpad < 0: + crop = crop[:, :-newpad, :-newpad, :-newpad] + elif newpad > 0: + pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] + crop = np.pad(crop, pad2, 'constant', constant_values=self.filling_value) + + return crop, coord + + +def sample(conf, N, T=1): + if len(conf) > N: target = range(len(conf)) chosen_list = [] for i in range(N): - chosenidx = sampleone(target,conf,T) + chosenidx = sampleone(target, conf, T) chosen_list.append(target[chosenidx]) target.pop(chosenidx) conf = np.delete(conf, chosenidx) - return chosen_list else: return np.arange(len(conf)) -def sampleone(target,conf,T): - assert len(conf)>1 - p = softmax(conf/T) - p = np.max([np.ones_like(p)*0.00001,p],axis=0) - p = p/np.sum(p) - return np.random.choice(np.arange(len(target)),size=1,replace = False, p=p)[0] + +def sampleone(target, conf, T): + assert len(conf) > 1 + p = softmax(conf / T) + p = np.max([np.ones_like(p) * 0.00001, p], axis=0) + p = p / np.sum(p) + return np.random.choice(np.arange(len(target)), size=1, replace=False, p=p)[0] + def softmax(x): maxx = np.max(x) - return np.exp(x-maxx)/np.sum(np.exp(x-maxx)) + return np.exp(x - maxx) / np.sum(np.exp(x - maxx)) -def augment(sample, coord, ifflip = True, ifrotate=True, ifswap = True,filling_value=0): +def augment(sample, coord, ifflip=True, ifrotate=True, ifswap=True, filling_value=0): # angle1 = np.random.rand()*180 if ifrotate: validrot = False counter = 0 - angle1 = np.random.rand()*180 + angle1 = np.random.rand() * 180 size = np.array(sample.shape[2:4]).astype('float') - rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) - sample = rotate(sample,angle1,axes=(2,3),reshape=False,cval=filling_value) - + rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)], + [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]]) + sample = rotate(sample, angle1, axes=(2, 3), reshape=False, cval=filling_value) + if ifswap: - if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: + if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]: axisorder = np.random.permutation(3) - sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) - coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) - + sample = np.transpose(sample, np.concatenate([[0], axisorder + 1])) + coord = np.transpose(coord, np.concatenate([[0], axisorder + 1])) + if ifflip: - flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 - sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) - coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) - return sample, coord + flipid = np.array([np.random.randint(2), np.random.randint(2), np.random.randint(2)]) * 2 - 1 + sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + return sample, coord diff --git a/data_detector.py b/data_detector.py index 35f0874..cb910b4 100644 --- a/data_detector.py +++ b/data_detector.py @@ -1,115 +1,118 @@ -import numpy as np -import torch -from torch.utils.data import Dataset -import os -import time import collections +import os import random -from layers import iou -from scipy.ndimage import zoom +import time import warnings + +import numpy as np +import torch +from scipy.ndimage import zoom from scipy.ndimage.interpolation import rotate -from scipy.ndimage.morphology import binary_dilation,generate_binary_structure +from scipy.ndimage.morphology import binary_dilation, generate_binary_structure +from torch.utils.data import Dataset + class DataBowl3Detector(Dataset): - def __init__(self, split, config, phase = 'train',split_comber=None): - assert(phase == 'train' or phase == 'val' or phase == 'test') + def __init__(self, split, config, phase='train', split_comber=None): + assert (phase == 'train' or phase == 'val' or phase == 'test') self.phase = phase - self.max_stride = config['max_stride'] - self.stride = config['stride'] - sizelim = config['sizelim']/config['reso'] - sizelim2 = config['sizelim2']/config['reso'] - sizelim3 = config['sizelim3']/config['reso'] + self.max_stride = config['max_stride'] + self.stride = config['stride'] + sizelim = config['sizelim'] / config['reso'] + sizelim2 = config['sizelim2'] / config['reso'] + sizelim3 = config['sizelim3'] / config['reso'] self.blacklist = config['blacklist'] self.isScale = config['aug_scale'] self.r_rand = config['r_rand_crop'] self.augtype = config['augtype'] data_dir = config['datadir'] - self.pad_value = config['pad_value'] - + self.pad_value = config['pad_value'] + self.split_comber = split_comber idcs = split - if phase!='test': + if phase != 'test': idcs = [f for f in idcs if f not in self.blacklist] self.channel = config['chanel'] - if self.channel==2: + if self.channel == 2: self.filenames = [os.path.join(data_dir, '%s_merge.npy' % idx) for idx in idcs] - elif self.channel ==1: - if 'cleanimg' in config and config['cleanimg']: + elif self.channel == 1: + if 'cleanimg' in config and config['cleanimg']: self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs] else: self.filenames = [os.path.join(data_dir, '%s_img.npy' % idx) for idx in idcs] - self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])>20] - self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])<20] - + self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) > 20] + self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) < 20] + labels = [] - + for idx in idcs: - if config['luna_raw'] ==True: + if config['luna_raw'] == True: try: l = np.load(os.path.join(data_dir, '%s_label_raw.npy' % idx)) except: - l = np.load(os.path.join(data_dir, '%s_label.npy' %idx)) + l = np.load(os.path.join(data_dir, '%s_label.npy' % idx)) else: - l = np.load(os.path.join(data_dir, '%s_label.npy' %idx)) + l = np.load(os.path.join(data_dir, '%s_label.npy' % idx)) labels.append(l) self.sample_bboxes = labels - if self.phase!='test': + if self.phase != 'test': self.bboxes = [] for i, l in enumerate(labels): - if len(l) > 0 : + if len(l) > 0: for t in l: - if t[3]>sizelim: - self.bboxes.append([np.concatenate([[i],t])]) - if t[3]>sizelim2: - self.bboxes+=[[np.concatenate([[i],t])]]*2 - if t[3]>sizelim3: - self.bboxes+=[[np.concatenate([[i],t])]]*4 - self.bboxes = np.concatenate(self.bboxes,axis = 0) + if t[3] > sizelim: + self.bboxes.append([np.concatenate([[i], t])]) + if t[3] > sizelim2: + self.bboxes += [[np.concatenate([[i], t])]] * 2 + if t[3] > sizelim3: + self.bboxes += [[np.concatenate([[i], t])]] * 4 + self.bboxes = np.concatenate(self.bboxes, axis=0) self.crop = Crop(config) self.label_mapping = LabelMapping(config, self.phase) - def __getitem__(self, idx,split=None): + def __getitem__(self, idx, split=None): t = time.time() - np.random.seed(int(str(t%1)[2:7]))#seed according to time + np.random.seed(int(str(t % 1)[2:7])) # seed according to time - isRandomImg = False - if self.phase !='test': - if idx>=len(self.bboxes): + isRandomImg = False + if self.phase != 'test': + if idx >= len(self.bboxes): isRandom = True - idx = idx%len(self.bboxes) + idx = idx % len(self.bboxes) isRandomImg = np.random.randint(2) else: isRandom = False else: isRandom = False - + if self.phase != 'test': if not isRandomImg: bbox = self.bboxes[idx] - filename = self.filenames[int(bbox[0])] + filename = self.filenames[int(bbox[0])] imgs = np.load(filename)[0:self.channel] bboxes = self.sample_bboxes[int(bbox[0])] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes,isScale,isRandom) - if self.phase=='train' and not isRandom: - sample, target, bboxes, coord = augment(sample, target, bboxes, coord, - ifflip = self.augtype['flip'], ifrotate=self.augtype['rotate'], ifswap = self.augtype['swap']) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale, isRandom) + if self.phase == 'train' and not isRandom: + sample, target, bboxes, coord = augment(sample, target, bboxes, coord, + ifflip=self.augtype['flip'], + ifrotate=self.augtype['rotate'], + ifswap=self.augtype['swap']) else: randimid = np.random.randint(len(self.kagglenames)) - filename = self.kagglenames[randimid] + filename = self.kagglenames[randimid] imgs = np.load(filename)[0:self.channel] bboxes = self.sample_bboxes[randimid] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, [], bboxes,isScale=False,isRand=True) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True) label = self.label_mapping(sample.shape[1:], target, bboxes) sample = sample.astype(np.float32) - #if filename in self.kagglenames: - # label[label==-1]=0 - sample = (sample.astype(np.float32)-128)/128 + # if filename in self.kagglenames: + # label[label==-1]=0 + sample = (sample.astype(np.float32) - 128) / 128 return torch.from_numpy(sample), torch.from_numpy(label), coord else: imgs = np.load(self.filenames[idx]) @@ -118,148 +121,153 @@ def __getitem__(self, idx,split=None): pz = int(np.ceil(float(nz) / self.stride)) * self.stride ph = int(np.ceil(float(nh) / self.stride)) * self.stride pw = int(np.ceil(float(nw) / self.stride)) * self.stride - imgs = np.pad(imgs, [[0,0],[0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',constant_values = self.pad_value) - xx,yy,zz = np.meshgrid(np.linspace(-0.5,0.5,imgs.shape[1]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[2]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[3]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant', + constant_values=self.pad_value) + xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[2] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[3] / self.stride), indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') imgs, nzhw = self.split_comber.split(imgs) coord2, nzhw2 = self.split_comber.split(coord, - side_len = self.split_comber.side_len/self.stride, - max_stride = self.split_comber.max_stride/self.stride, - margin = self.split_comber.margin/self.stride) - assert np.all(nzhw==nzhw2) - imgs = (imgs.astype(np.float32)-128)/128 - return torch.from_numpy(imgs.astype(np.float32)), bboxes, torch.from_numpy(coord2.astype(np.float32)), np.array(nzhw) + side_len=self.split_comber.side_len / self.stride, + max_stride=self.split_comber.max_stride / self.stride, + margin=self.split_comber.margin / self.stride) + assert np.all(nzhw == nzhw2) + imgs = (imgs.astype(np.float32) - 128) / 128 + return torch.from_numpy(imgs.astype(np.float32)), bboxes, torch.from_numpy( + coord2.astype(np.float32)), np.array(nzhw) def __len__(self): if self.phase == 'train': - return len(self.bboxes)/(1-self.r_rand) - elif self.phase =='val': - return len(self.bboxes) + return len(self.bboxes) / (1 - self.r_rand) + elif self.phase == 'val': + return len(self.bboxes) else: return len(self.filenames) - - -def augment(sample, target, bboxes, coord, ifflip = True, ifrotate=True, ifswap = True): + + +def augment(sample, target, bboxes, coord, ifflip=True, ifrotate=True, ifswap=True): # angle1 = np.random.rand()*180 if ifrotate: validrot = False counter = 0 while not validrot: newtarget = np.copy(target) - angle1 = (np.random.rand()-0.5)*20 + angle1 = (np.random.rand() - 0.5) * 20 size = np.array(sample.shape[2:4]).astype('float') - rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) - newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2 - if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]): + rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)], + [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]]) + newtarget[1:3] = np.dot(rotmat, target[1:3] - size / 2) + size / 2 + if np.all(newtarget[:3] > target[3]) and np.all(newtarget[:3] < np.array(sample.shape[1:4]) - newtarget[3]): validrot = True target = newtarget - sample = rotate(sample,angle1,axes=(2,3),reshape=False) - coord = rotate(coord,angle1,axes=(2,3),reshape=False) + sample = rotate(sample, angle1, axes=(2, 3), reshape=False) + coord = rotate(coord, angle1, axes=(2, 3), reshape=False) for box in bboxes: - box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2 + box[1:3] = np.dot(rotmat, box[1:3] - size / 2) + size / 2 else: counter += 1 - if counter ==3: + if counter == 3: break if ifswap: - if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: + if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]: axisorder = np.random.permutation(3) - sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) - coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) + sample = np.transpose(sample, np.concatenate([[0], axisorder + 1])) + coord = np.transpose(coord, np.concatenate([[0], axisorder + 1])) target[:3] = target[:3][axisorder] - bboxes[:,:3] = bboxes[:,:3][:,axisorder] - + bboxes[:, :3] = bboxes[:, :3][:, axisorder] + if ifflip: -# flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 - flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1 - sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) - coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) + # flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 + flipid = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1 + sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]]) for ax in range(3): - if flipid[ax]==-1: - target[ax] = np.array(sample.shape[ax+1])-target[ax] - bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax] - return sample, target, bboxes, coord + if flipid[ax] == -1: + target[ax] = np.array(sample.shape[ax + 1]) - target[ax] + bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax] + return sample, target, bboxes, coord + class Crop(object): def __init__(self, config): self.crop_size = config['crop_size'] self.bound_size = config['bound_size'] self.stride = config['stride'] - self.pad_value = config['pad_value'] + self.pad_value = config['pad_value'] - def __call__(self, imgs, target, bboxes,isScale=False,isRand=False): + def __call__(self, imgs, target, bboxes, isScale=False, isRand=False): if isScale: - radiusLim = [8.,100.] - scaleLim = [0.75,1.25] - scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1]) - ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])] - scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0] - crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int') + radiusLim = [8., 100.] + scaleLim = [0.75, 1.25] + scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]) + , np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])] + scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0] + crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int') else: - crop_size=self.crop_size + crop_size = self.crop_size bound_size = self.bound_size target = np.copy(target) bboxes = np.copy(bboxes) - + start = [] for i in range(3): if not isRand: r = target[3] / 2 - s = np.floor(target[i] - r)+ 1 - bound_size - e = np.ceil (target[i] + r)+ 1 + bound_size - crop_size[i] + s = np.floor(target[i] - r) + 1 - bound_size + e = np.ceil(target[i] + r) + 1 + bound_size - crop_size[i] else: - s = np.max([imgs.shape[i+1]-crop_size[i]/2,imgs.shape[i+1]/2+bound_size]) - e = np.min([crop_size[i]/2, imgs.shape[i+1]/2-bound_size]) - target = np.array([np.nan,np.nan,np.nan,np.nan]) - if s>e: - start.append(np.random.randint(e,s))#! + s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + bound_size]) + e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - bound_size]) + target = np.array([np.nan, np.nan, np.nan, np.nan]) + if s > e: + start.append(np.random.randint(e, s)) # ! else: - start.append(int(target[i])-crop_size[i]/2+np.random.randint(-bound_size/2,bound_size/2)) - - - normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5 - normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:]) - xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride), - np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride), - np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + start.append(int(target[i]) - crop_size[i] / 2 + np.random.randint(-bound_size / 2, bound_size / 2)) + + normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 + normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) + xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride), + np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride), + np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride), + indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') pad = [] - pad.append([0,0]) + pad.append([0, 0]) for i in range(3): - leftpad = max(0,-start[i]) - rightpad = max(0,start[i]+crop_size[i]-imgs.shape[i+1]) - pad.append([leftpad,rightpad]) + leftpad = max(0, -start[i]) + rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1]) + pad.append([leftpad, rightpad]) crop = imgs[:, - max(start[0],0):min(start[0] + crop_size[0],imgs.shape[1]), - max(start[1],0):min(start[1] + crop_size[1],imgs.shape[2]), - max(start[2],0):min(start[2] + crop_size[2],imgs.shape[3])] - crop = np.pad(crop,pad,'constant',constant_values =self.pad_value) + max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]), + max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]), + max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])] + crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value) for i in range(3): - target[i] = target[i] - start[i] + target[i] = target[i] - start[i] for i in range(len(bboxes)): for j in range(3): - bboxes[i][j] = bboxes[i][j] - start[j] - + bboxes[i][j] = bboxes[i][j] - start[j] + if isScale: with warnings.catch_warnings(): warnings.simplefilter("ignore") - crop = zoom(crop,[1,scale,scale,scale],order=1) - newpad = self.crop_size[0]-crop.shape[1:][0] - if newpad<0: - crop = crop[:,:-newpad,:-newpad,:-newpad] - elif newpad>0: - pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]] - crop = np.pad(crop,pad2,'constant',constant_values =self.pad_value) + crop = zoom(crop, [1, scale, scale, scale], order=1) + newpad = self.crop_size[0] - crop.shape[1:][0] + if newpad < 0: + crop = crop[:, :-newpad, :-newpad, :-newpad] + elif newpad > 0: + pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] + crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value) for i in range(4): - target[i] = target[i]*scale + target[i] = target[i] * scale for i in range(len(bboxes)): for j in range(4): - bboxes[i][j] = bboxes[i][j]*scale + bboxes[i][j] = bboxes[i][j] * scale return crop, target, bboxes, coord - + + class LabelMapping(object): def __init__(self, config, phase): self.stride = np.array(config['stride']) @@ -272,20 +280,19 @@ def __init__(self, config, phase): elif phase == 'val': self.th_pos = config['th_pos_val'] - def __call__(self, input_size, target, bboxes): stride = self.stride num_neg = self.num_neg th_neg = self.th_neg anchors = self.anchors th_pos = self.th_pos - struct = generate_binary_structure(3,1) - + struct = generate_binary_structure(3, 1) + output_size = [] for i in range(3): - assert(input_size[i] % stride == 0) + assert (input_size[i] % stride == 0) output_size.append(input_size[i] / stride) - + label = np.zeros(output_size + [len(anchors), 5], np.float32) offset = ((stride.astype('float')) - 1) / 2 oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) @@ -296,10 +303,10 @@ def __call__(self, input_size, target, bboxes): for i, anchor in enumerate(anchors): iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow) label[iz, ih, iw, i, 0] = 1 - label[:,:,:, i, 0] = binary_dilation(label[:,:,:, i, 0].astype('bool'),structure=struct,iterations=1).astype('float32') - - - label = label-1 + label[:, :, :, i, 0] = binary_dilation(label[:, :, :, i, 0].astype('bool'), structure=struct, + iterations=1).astype('float32') + + label = label - 1 if self.phase == 'train' and self.num_neg > 0: neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1) @@ -321,7 +328,7 @@ def __call__(self, input_size, target, bboxes): ih = np.concatenate(ih, 0) iw = np.concatenate(iw, 0) ia = np.concatenate(ia, 0) - flag = True + flag = True if len(iz) == 0: pos = [] for i in range(3): @@ -337,7 +344,8 @@ def __call__(self, input_size, target, bboxes): dw = (target[2] - ow[pos[2]]) / anchors[pos[3]] dd = np.log(target[3] / anchors[pos[3]]) label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd] - return label + return label + def select_samples(bbox, anchor, th, oz, oh, ow): z, h, w, d = bbox @@ -350,12 +358,12 @@ def select_samples(bbox, anchor, th, oz, oh, ow): e = z + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mz = np.logical_and(oz >= s, oz <= e) iz = np.where(mz)[0] - + s = h - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = h + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mh = np.logical_and(oh >= s, oh <= e) ih = np.where(mh)[0] - + s = w - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = w + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mw = np.logical_and(ow >= s, ow <= e) @@ -363,7 +371,7 @@ def select_samples(bbox, anchor, th, oz, oh, ow): if len(iz) == 0 or len(ih) == 0 or len(iw) == 0: return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) - + lz, lh, lw = len(iz), len(ih), len(iw) iz = iz.reshape((-1, 1, 1)) ih = ih.reshape((1, -1, 1)) @@ -374,36 +382,37 @@ def select_samples(bbox, anchor, th, oz, oh, ow): centers = np.concatenate([ oz[iz].reshape((-1, 1)), oh[ih].reshape((-1, 1)), - ow[iw].reshape((-1, 1))], axis = 1) - + ow[iw].reshape((-1, 1))], axis=1) + r0 = anchor / 2 s0 = centers - r0 e0 = centers + r0 - + r1 = d / 2 s1 = bbox[:3] - r1 s1 = s1.reshape((1, -1)) e1 = bbox[:3] + r1 e1 = e1.reshape((1, -1)) - + overlap = np.maximum(0, np.minimum(e0, e1) - np.maximum(s0, s1)) - + intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2] union = anchor * anchor * anchor + d * d * d - intersection iou = intersection / union mask = iou >= th - #if th > 0.4: - # if np.sum(mask) == 0: - # print(['iou not large', iou.max()]) - # else: - # print(['iou large', iou[mask]]) + # if th > 0.4: + # if np.sum(mask) == 0: + # print(['iou not large', iou.max()]) + # else: + # print(['iou large', iou[mask]]) iz = iz[mask] ih = ih[mask] iw = iw[mask] return iz, ih, iw + def collate(batch): if torch.is_tensor(batch[0]): return [b.unsqueeze(0) for b in batch] @@ -414,4 +423,3 @@ def collate(batch): elif isinstance(batch[0], collections.Iterable): transposed = zip(*batch) return [collate(samples) for samples in transposed] - diff --git a/layers.py b/layers.py index 939b7be..aadb473 100644 --- a/layers.py +++ b/layers.py @@ -2,20 +2,20 @@ import torch from torch import nn -import math + class PostRes2d(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes2d, self).__init__() - self.conv1 = nn.Conv2d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv2d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv2d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv2d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv2d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm2d(n_out)) else: self.shortcut = None @@ -29,23 +29,24 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out - + + class PostRes(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes, self).__init__() - self.conv1 = nn.Conv3d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm3d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv3d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm3d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv3d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm3d(n_out)) else: self.shortcut = None @@ -59,72 +60,73 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out + class Rec3(nn.Module): - def __init__(self, n0, n1, n2, n3, p = 0.0, integrate = True): + def __init__(self, n0, n1, n2, n3, p=0.0, integrate=True): super(Rec3, self).__init__() - + self.block01 = nn.Sequential( - nn.Conv3d(n0, n1, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n0, n1, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) self.block11 = nn.Sequential( - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block21 = nn.Sequential( - nn.ConvTranspose3d(n2, n1, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n2, n1, kernel_size=2, stride=2), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block12 = nn.Sequential( - nn.Conv3d(n1, n2, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n1, n2, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block22 = nn.Sequential( - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block32 = nn.Sequential( - nn.ConvTranspose3d(n3, n2, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n3, n2, kernel_size=2, stride=2), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block23 = nn.Sequential( - nn.Conv3d(n2, n3, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n2, n3, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) self.block33 = nn.Sequential( - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) - self.relu = nn.ReLU(inplace = True) + self.relu = nn.ReLU(inplace=True) self.p = p self.integrate = integrate @@ -146,25 +148,27 @@ def forward(self, x0, x1, x2, x3): return x0, self.relu(out1), self.relu(out2), self.relu(out3) + def hard_mining(neg_output, neg_labels, num_hard): _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output))) neg_output = torch.index_select(neg_output, 0, idcs) neg_labels = torch.index_select(neg_labels, 0, idcs) return neg_output, neg_labels + class Loss(nn.Module): - def __init__(self, num_hard = 0): + def __init__(self, num_hard=0): super(Loss, self).__init__() self.sigmoid = nn.Sigmoid() self.classify_loss = nn.BCELoss() self.regress_loss = nn.SmoothL1Loss() self.num_hard = num_hard - def forward(self, output, labels, train = True): + def forward(self, output, labels, train=True): batch_size = labels.size(0) output = output.view(-1, 5) labels = labels.view(-1, 5) - + pos_idcs = labels[:, 0] > 0.5 pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5) pos_output = output[pos_idcs].view(-1, 5) @@ -173,15 +177,15 @@ def forward(self, output, labels, train = True): neg_idcs = labels[:, 0] < -0.5 neg_output = output[:, 0][neg_idcs] neg_labels = labels[:, 0][neg_idcs] - + if self.num_hard > 0 and train: neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.num_hard * batch_size) neg_prob = self.sigmoid(neg_output) - #classify_loss = self.classify_loss( - # torch.cat((pos_prob, neg_prob), 0), - # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) - if len(pos_output)>0: + # classify_loss = self.classify_loss( + # torch.cat((pos_prob, neg_prob), 0), + # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) + if len(pos_output) > 0: pos_prob = self.sigmoid(pos_output[:, 0]) pz, ph, pw, pd = pos_output[:, 1], pos_output[:, 2], pos_output[:, 3], pos_output[:, 4] lz, lh, lw, ld = pos_labels[:, 1], pos_labels[:, 2], pos_labels[:, 3], pos_labels[:, 4] @@ -193,18 +197,18 @@ def forward(self, output, labels, train = True): self.regress_loss(pd, ld)] regress_losses_data = [l.data[0] for l in regress_losses] classify_loss = 0.5 * self.classify_loss( - pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = (pos_prob.data >= 0.5).sum() pos_total = len(pos_prob) else: - regress_losses = [0,0,0,0] - classify_loss = 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + regress_losses = [0, 0, 0, 0] + classify_loss = 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = 0 pos_total = 0 - regress_losses_data = [0,0,0,0] + regress_losses_data = [0, 0, 0, 0] classify_loss_data = classify_loss.data[0] loss = classify_loss @@ -216,12 +220,13 @@ def forward(self, output, labels, train = True): return [loss, classify_loss_data] + regress_losses_data + [pos_correct, pos_total, neg_correct, neg_total] + class GetPBB(object): def __init__(self, config): self.stride = config['stride'] self.anchors = np.asarray(config['anchors']) - def __call__(self, output,thresh = -3, ismask=False): + def __call__(self, output, thresh=-3, ismask=False): stride = self.stride anchors = self.anchors output = np.copy(output) @@ -230,29 +235,31 @@ def __call__(self, output,thresh = -3, ismask=False): oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride) ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride) - + output[:, :, :, :, 1] = oz.reshape((-1, 1, 1, 1)) + output[:, :, :, :, 1] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 2] = oh.reshape((1, -1, 1, 1)) + output[:, :, :, :, 2] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 3] = ow.reshape((1, 1, -1, 1)) + output[:, :, :, :, 3] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1)) mask = output[..., 0] > thresh - xx,yy,zz,aa = np.where(mask) - - output = output[xx,yy,zz,aa] + xx, yy, zz, aa = np.where(mask) + + output = output[xx, yy, zz, aa] if ismask: - return output,[xx,yy,zz,aa] + return output, [xx, yy, zz, aa] else: return output - #output = output[output[:, 0] >= self.conf_th] - #bboxes = nms(output, self.nms_th) + # output = output[output[:, 0] >= self.conf_th] + # bboxes = nms(output, self.nms_th) + + def nms(output, nms_th): if len(output) == 0: return output output = output[np.argsort(-output[:, 0])] bboxes = [output[0]] - + for i in np.arange(1, len(output)): bbox = output[i] flag = 1 @@ -262,12 +269,12 @@ def nms(output, nms_th): break if flag == 1: bboxes.append(bbox) - + bboxes = np.asarray(bboxes, np.float32) return bboxes + def iou(box0, box1): - r0 = box0[3] / 2 s0 = box0[:3] - r0 e0 = box0[:3] + r0 @@ -284,8 +291,9 @@ def iou(box0, box1): union = box0[3] * box0[3] * box0[3] + box1[3] * box1[3] * box1[3] - intersection return intersection / union + def acc(pbb, lbb, conf_th, nms_th, detect_th): - pbb = pbb[pbb[:, 0] >= conf_th] + pbb = pbb[pbb[:, 0] >= conf_th] pbb = nms(pbb, nms_th) tp = [] @@ -297,63 +305,64 @@ def acc(pbb, lbb, conf_th, nms_th, detect_th): bestscore = 0 for i, l in enumerate(lbb): score = iou(p[1:5], l) - if score>bestscore: + if score > bestscore: bestscore = score besti = i if bestscore > detect_th: flag = 1 if l_flag[besti] == 0: l_flag[besti] = 1 - tp.append(np.concatenate([p,[bestscore]],0)) + tp.append(np.concatenate([p, [bestscore]], 0)) else: - fp.append(np.concatenate([p,[bestscore]],0)) + fp.append(np.concatenate([p, [bestscore]], 0)) if flag == 0: - fp.append(np.concatenate([p,[bestscore]],0)) - for i,l in enumerate(lbb): - if l_flag[i]==0: + fp.append(np.concatenate([p, [bestscore]], 0)) + for i, l in enumerate(lbb): + if l_flag[i] == 0: score = [] for p in pbb: - score.append(iou(p[1:5],l)) - if len(score)!=0: + score.append(iou(p[1:5], l)) + if len(score) != 0: bestscore = np.max(score) else: bestscore = 0 - if bestscore0: - fn = np.concatenate([fn,tp[fn_i,:5]]) + if len(fn_i) > 0: + fn = np.concatenate([fn, tp[fn_i, :5]]) else: fn = fn - if len(tp_in_topk)>0: + if len(tp_in_topk) > 0: tp = tp[tp_in_topk] else: tp = [] - if len(fp_in_topk)>0: + if len(fp_in_topk) > 0: fp = newallp[fp_in_topk] else: fp = [] - return tp, fp , fn + return tp, fp, fn diff --git a/main.py b/main.py index da479e6..27b15e8 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,19 @@ -from preprocessing import full_prep -from config_submit import config as config_submit +from importlib import import_module +import pandas import torch -from torch.nn import DataParallel +from torch.autograd import Variable from torch.backends import cudnn +from torch.nn import DataParallel from torch.utils.data import DataLoader -from torch import optim -from torch.autograd import Variable -from layers import acc -from data_detector import DataBowl3Detector,collate +from config_submit import config as config_submit from data_classifier import DataBowl3Classifier - -from utils import * +from data_detector import DataBowl3Detector, collate +from preprocessing import full_prep from split_combine import SplitComb from test_detect import test_detect -from importlib import import_module -import pandas +from utils import * datapath = config_submit['datapath'] prep_result_path = config_submit['preprocess_result_path'] @@ -24,8 +21,8 @@ skip_detect = config_submit['skip_detect'] if not skip_prep: - testsplit = full_prep(datapath,prep_result_path, - n_worker = config_submit['n_worker_preprocessing'], + testsplit = full_prep(datapath, prep_result_path, + n_worker=config_submit['n_worker_preprocessing'], use_existing=config_submit['use_exsiting_preprocessing']) else: testsplit = os.listdir(datapath) @@ -43,22 +40,19 @@ bbox_result_path = './bbox_result' if not os.path.exists(bbox_result_path): os.mkdir(bbox_result_path) -#testsplit = [f.split('_clean')[0] for f in os.listdir(prep_result_path) if '_clean' in f] +# testsplit = [f.split('_clean')[0] for f in os.listdir(prep_result_path) if '_clean' in f] if not skip_detect: margin = 32 sidelen = 144 config1['datadir'] = prep_result_path - split_comber = SplitComb(sidelen,config1['max_stride'],config1['stride'],margin,pad_value= config1['pad_value']) + split_comber = SplitComb(sidelen, config1['max_stride'], config1['stride'], margin, pad_value=config1['pad_value']) - dataset = DataBowl3Detector(testsplit,config1,phase='test',split_comber=split_comber) - test_loader = DataLoader(dataset,batch_size = 1, - shuffle = False,num_workers = 32,pin_memory=False,collate_fn =collate) - - test_detect(test_loader, nod_net, get_pbb, bbox_result_path,config1,n_gpu=config_submit['n_gpu']) - - + dataset = DataBowl3Detector(testsplit, config1, phase='test', split_comber=split_comber) + test_loader = DataLoader(dataset, batch_size=1, + shuffle=False, num_workers=32, pin_memory=False, collate_fn=collate) + test_detect(test_loader, nod_net, get_pbb, bbox_result_path, config1, n_gpu=config_submit['n_gpu']) casemodel = import_module(config_submit['classifier_model'].split('.py')[0]) casenet = casemodel.CaseNet(topk=5) @@ -74,36 +68,34 @@ filename = config_submit['outputfile'] - -def test_casenet(model,testset): +def test_casenet(model, testset): data_loader = DataLoader( testset, - batch_size = 1, - shuffle = False, - num_workers = 32, + batch_size=1, + shuffle=False, + num_workers=32, pin_memory=True) - #model = model.cuda() + # model = model.cuda() model.eval() predlist = [] - - # weight = torch.from_numpy(np.ones_like(y).float().cuda() - for i,(x,coord) in enumerate(data_loader): + # weight = torch.from_numpy(np.ones_like(y).float().cuda() + for i, (x, coord) in enumerate(data_loader): coord = Variable(coord).cuda() x = Variable(x).cuda() - nodulePred,casePred,_ = model(x,coord) + nodulePred, casePred, _ = model(x, coord) predlist.append(casePred.data.cpu().numpy()) - #print([i,data_loader.dataset.split[i,1],casePred.data.cpu().numpy()]) + # print([i,data_loader.dataset.split[i,1],casePred.data.cpu().numpy()]) predlist = np.concatenate(predlist) - return predlist -config2['bboxpath'] = bbox_result_path -config2['datadir'] = prep_result_path + return predlist +config2['bboxpath'] = bbox_result_path +config2['datadir'] = prep_result_path -dataset = DataBowl3Classifier(testsplit, config2, phase = 'test') -predlist = test_casenet(casenet,dataset).T -anstable = np.concatenate([[testsplit],predlist],0).T +dataset = DataBowl3Classifier(testsplit, config2, phase='test') +predlist = test_casenet(casenet, dataset).T +anstable = np.concatenate([[testsplit], predlist], 0).T df = pandas.DataFrame(anstable) -df.columns={'id','cancer'} -df.to_csv(filename,index=False) +df.columns = {'id', 'cancer'} +df.to_csv(filename, index=False) diff --git a/net_classifier.py b/net_classifier.py index d467404..79afa21 100644 --- a/net_classifier.py +++ b/net_classifier.py @@ -1,15 +1,8 @@ +import numpy as np import torch from torch import nn + from layers import * -from torch.nn import DataParallel -from torch.backends import cudnn -from torch.utils.data import DataLoader -from torch import optim -from torch.autograd import Variable -from torch.utils.data import Dataset -from scipy.ndimage.interpolation import rotate -import numpy as np -import os config = {} config['topk'] = 5 @@ -22,9 +15,9 @@ config['padmask'] = False -config['crop_size'] = [96,96,96] -config['scaleLim'] = [0.85,1.15] -config['radiusLim'] = [6,100] +config['crop_size'] = [96, 96, 96] +config['scaleLim'] = [0.85, 1.15] +config['radiusLim'] = [6, 100] config['jitter_range'] = 0.15 config['isScale'] = True @@ -32,7 +25,7 @@ config['T'] = 1 config['topk'] = 5 config['stride'] = 4 -config['augtype'] = {'flip':True,'swap':False,'rotate':False,'scale':False} +config['augtype'] = {'flip': True, 'swap': False, 'rotate': False, 'scale': False} config['detect_th'] = 0.05 config['conf_th'] = -1 @@ -40,11 +33,12 @@ config['filling_value'] = 160 config['startepoch'] = 20 -config['lr_stage'] = np.array([50,100,140,160]) -config['lr'] = [0.01,0.001,0.0001,0.00001] +config['lr_stage'] = np.array([50, 100, 140, 160]) +config['lr'] = [0.01, 0.001, 0.0001, 0.00001] config['miss_ratio'] = 1 config['miss_thresh'] = 0.03 -config['anchors'] = [10,30,60] +config['anchors'] = [10, 30, 60] + class Net(nn.Module): def __init__(self): @@ -52,122 +46,123 @@ def __init__(self): # The first few layers consumes the most memory, so use simple convolution to save memory. # Call these layers preBlock, i.e., before the residual blocks of later layers. self.preBlock = nn.Sequential( - nn.Conv3d(1, 24, kernel_size = 3, padding = 1), + nn.Conv3d(1, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True), - nn.Conv3d(24, 24, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(24, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True)) - + nn.ReLU(inplace=True)) + # 3 poolings, each pooling downsamples the feature map by a factor 2. # 3 groups of blocks. The first block of each group has one pooling. - num_blocks_forw = [2,2,3,3] - num_blocks_back = [3,3] - self.featureNum_forw = [24,32,64,64,64] - self.featureNum_back = [128,64,64] + num_blocks_forw = [2, 2, 3, 3] + num_blocks_back = [3, 3] + self.featureNum_forw = [24, 32, 64, 64, 64] + self.featureNum_back = [128, 64, 64] for i in range(len(num_blocks_forw)): blocks = [] for j in range(num_blocks_forw[i]): if j == 0: - blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1])) else: - blocks.append(PostRes(self.featureNum_forw[i+1], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1])) setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks)) - for i in range(len(num_blocks_back)): blocks = [] for j in range(num_blocks_back[i]): if j == 0: - if i==0: + if i == 0: addition = 3 else: addition = 0 - blocks.append(PostRes(self.featureNum_back[i+1]+self.featureNum_forw[i+2]+addition, self.featureNum_back[i])) + blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition, + self.featureNum_back[i])) else: blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i])) setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks)) - self.maxpool1 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool2 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool3 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool4 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2,stride=2) - self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2,stride=2) + self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2) + self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2) self.path1 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path2 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) - self.drop = nn.Dropout3d(p = 0.2, inplace = False) - self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size = 1), + nn.ReLU(inplace=True)) + self.drop = nn.Dropout3d(p=0.2, inplace=False) + self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1), nn.ReLU(), - #nn.Dropout3d(p = 0.3), - nn.Conv3d(64, 5 * len(config['anchors']), kernel_size = 1)) + # nn.Dropout3d(p = 0.3), + nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1)) def forward(self, x, coord): - out = self.preBlock(x)#16 - out_pool,indices0 = self.maxpool1(out) - out1 = self.forw1(out_pool)#32 - out1_pool,indices1 = self.maxpool2(out1) - out2 = self.forw2(out1_pool)#64 - #out2 = self.drop(out2) - out2_pool,indices2 = self.maxpool3(out2) - out3 = self.forw3(out2_pool)#96 - out3_pool,indices3 = self.maxpool4(out3) - out4 = self.forw4(out3_pool)#96 - #out4 = self.drop(out4) - + out = self.preBlock(x) # 16 + out_pool, indices0 = self.maxpool1(out) + out1 = self.forw1(out_pool) # 32 + out1_pool, indices1 = self.maxpool2(out1) + out2 = self.forw2(out1_pool) # 64 + # out2 = self.drop(out2) + out2_pool, indices2 = self.maxpool3(out2) + out3 = self.forw3(out2_pool) # 96 + out3_pool, indices3 = self.maxpool4(out3) + out4 = self.forw4(out3_pool) # 96 + # out4 = self.drop(out4) + rev3 = self.path1(out4) - comb3 = self.back3(torch.cat((rev3, out3), 1))#96+96 - #comb3 = self.drop(comb3) + comb3 = self.back3(torch.cat((rev3, out3), 1)) # 96+96 + # comb3 = self.drop(comb3) rev2 = self.path2(comb3) - - feat = self.back2(torch.cat((rev2, out2,coord), 1))#64+64 + + feat = self.back2(torch.cat((rev2, out2, coord), 1)) # 64+64 comb2 = self.drop(feat) out = self.output(comb2) size = out.size() out = out.view(out.size(0), out.size(1), -1) - #out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() + # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5) - #out = out.view(-1, 5) - return feat,out + # out = out.view(-1, 5) + return feat, out + - class CaseNet(nn.Module): - def __init__(self,topk): - super(CaseNet,self).__init__() - self.NoduleNet = Net() - self.fc1 = nn.Linear(128,64) - self.fc2 = nn.Linear(64,1) + def __init__(self, topk): + super(CaseNet, self).__init__() + self.NoduleNet = Net() + self.fc1 = nn.Linear(128, 64) + self.fc2 = nn.Linear(64, 1) self.pool = nn.MaxPool3d(kernel_size=2) self.dropout = nn.Dropout(0.5) self.baseline = nn.Parameter(torch.Tensor([-30.0]).float()) self.Relu = nn.ReLU() - def forward(self,xlist,coordlist): -# xlist: n x k x 1x 96 x 96 x 96 -# coordlist: n x k x 3 x 24 x 24 x 24 + + def forward(self, xlist, coordlist): + # xlist: n x k x 1x 96 x 96 x 96 + # coordlist: n x k x 3 x 24 x 24 x 24 xsize = xlist.size() corrdsize = coordlist.size() - xlist = xlist.view(-1,xsize[2],xsize[3],xsize[4],xsize[5]) - coordlist = coordlist.view(-1,corrdsize[2],corrdsize[3],corrdsize[4],corrdsize[5]) - - noduleFeat,nodulePred = self.NoduleNet(xlist,coordlist) - nodulePred = nodulePred.contiguous().view(corrdsize[0],corrdsize[1],-1) - - featshape = noduleFeat.size()#nk x 128 x 24 x 24 x24 - centerFeat = self.pool(noduleFeat[:,:,featshape[2]/2-1:featshape[2]/2+1, - featshape[3]/2-1:featshape[3]/2+1, - featshape[4]/2-1:featshape[4]/2+1]) - centerFeat = centerFeat[:,:,0,0,0] + xlist = xlist.view(-1, xsize[2], xsize[3], xsize[4], xsize[5]) + coordlist = coordlist.view(-1, corrdsize[2], corrdsize[3], corrdsize[4], corrdsize[5]) + + noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist) + nodulePred = nodulePred.contiguous().view(corrdsize[0], corrdsize[1], -1) + + featshape = noduleFeat.size() # nk x 128 x 24 x 24 x24 + centerFeat = self.pool(noduleFeat[:, :, featshape[2] / 2 - 1:featshape[2] / 2 + 1, + featshape[3] / 2 - 1:featshape[3] / 2 + 1, + featshape[4] / 2 - 1:featshape[4] / 2 + 1]) + centerFeat = centerFeat[:, :, 0, 0, 0] out = self.dropout(centerFeat) out = self.Relu(self.fc1(out)) out = torch.sigmoid(self.fc2(out)) - out = out.view(xsize[0],xsize[1]) + out = out.view(xsize[0], xsize[1]) base_prob = torch.sigmoid(self.baseline) - casePred = 1-torch.prod(1-out,dim=1)*(1-base_prob.expand(out.size()[0])) - return nodulePred,casePred,out + casePred = 1 - torch.prod(1 - out, dim=1) * (1 - base_prob.expand(out.size()[0])) + return nodulePred, casePred, out diff --git a/net_detector.py b/net_detector.py index e993c7c..d648c22 100644 --- a/net_detector.py +++ b/net_detector.py @@ -1,9 +1,10 @@ import torch from torch import nn + from layers import * config = {} -config['anchors'] = [ 10.0, 30.0, 60.] +config['anchors'] = [10.0, 30.0, 60.] config['chanel'] = 1 config['crop_size'] = [128, 128, 128] config['stride'] = 4 @@ -17,7 +18,7 @@ config['num_hard'] = 2 config['bound_size'] = 12 config['reso'] = 1 -config['sizelim'] = 6. #mm +config['sizelim'] = 6. # mm config['sizelim2'] = 30 config['sizelim3'] = 40 config['aug_scale'] = True @@ -25,14 +26,15 @@ config['pad_value'] = 170 config['luna_raw'] = True config['cleanimg'] = True -config['augtype'] = {'flip':True,'swap':False,'scale':True,'rotate':False} -config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','990fbe3f0a1b53878669967b9afd1441','adc3bbc63d40f8761c59be10f1e504c3'] +config['augtype'] = {'flip': True, 'swap': False, 'scale': True, 'rotate': False} +config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38', '990fbe3f0a1b53878669967b9afd1441', + 'adc3bbc63d40f8761c59be10f1e504c3'] +config['lr_stage'] = np.array([50, 100, 120]) +config['lr'] = [0.01, 0.001, 0.0001] -config['lr_stage'] = np.array([50,100,120]) -config['lr'] = [0.01,0.001,0.0001] -#config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', +# config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', # '417','077','188','876','057','087','130','468'] class Net(nn.Module): @@ -41,91 +43,92 @@ def __init__(self): # The first few layers consumes the most memory, so use simple convolution to save memory. # Call these layers preBlock, i.e., before the residual blocks of later layers. self.preBlock = nn.Sequential( - nn.Conv3d(1, 24, kernel_size = 3, padding = 1), + nn.Conv3d(1, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True), - nn.Conv3d(24, 24, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(24, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True)) - + nn.ReLU(inplace=True)) + # 3 poolings, each pooling downsamples the feature map by a factor 2. # 3 groups of blocks. The first block of each group has one pooling. - num_blocks_forw = [2,2,3,3] - num_blocks_back = [3,3] - self.featureNum_forw = [24,32,64,64,64] - self.featureNum_back = [128,64,64] + num_blocks_forw = [2, 2, 3, 3] + num_blocks_back = [3, 3] + self.featureNum_forw = [24, 32, 64, 64, 64] + self.featureNum_back = [128, 64, 64] for i in range(len(num_blocks_forw)): blocks = [] for j in range(num_blocks_forw[i]): if j == 0: - blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1])) else: - blocks.append(PostRes(self.featureNum_forw[i+1], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1])) setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks)) - for i in range(len(num_blocks_back)): blocks = [] for j in range(num_blocks_back[i]): if j == 0: - if i==0: + if i == 0: addition = 3 else: addition = 0 - blocks.append(PostRes(self.featureNum_back[i+1]+self.featureNum_forw[i+2]+addition, self.featureNum_back[i])) + blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition, + self.featureNum_back[i])) else: blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i])) setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks)) - self.maxpool1 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool2 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool3 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool4 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2,stride=2) - self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2,stride=2) + self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2) + self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2) self.path1 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path2 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) - self.drop = nn.Dropout3d(p = 0.2, inplace = False) - self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size = 1), + nn.ReLU(inplace=True)) + self.drop = nn.Dropout3d(p=0.2, inplace=False) + self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1), nn.ReLU(), - #nn.Dropout3d(p = 0.3), - nn.Conv3d(64, 5 * len(config['anchors']), kernel_size = 1)) + # nn.Dropout3d(p = 0.3), + nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1)) def forward(self, x, coord): - out = self.preBlock(x)#16 - out_pool,indices0 = self.maxpool1(out) - out1 = self.forw1(out_pool)#32 - out1_pool,indices1 = self.maxpool2(out1) - out2 = self.forw2(out1_pool)#64 - #out2 = self.drop(out2) - out2_pool,indices2 = self.maxpool3(out2) - out3 = self.forw3(out2_pool)#96 - out3_pool,indices3 = self.maxpool4(out3) - out4 = self.forw4(out3_pool)#96 - #out4 = self.drop(out4) - + out = self.preBlock(x) # 16 + out_pool, indices0 = self.maxpool1(out) + out1 = self.forw1(out_pool) # 32 + out1_pool, indices1 = self.maxpool2(out1) + out2 = self.forw2(out1_pool) # 64 + # out2 = self.drop(out2) + out2_pool, indices2 = self.maxpool3(out2) + out3 = self.forw3(out2_pool) # 96 + out3_pool, indices3 = self.maxpool4(out3) + out4 = self.forw4(out3_pool) # 96 + # out4 = self.drop(out4) + rev3 = self.path1(out4) - comb3 = self.back3(torch.cat((rev3, out3), 1))#96+96 - #comb3 = self.drop(comb3) + comb3 = self.back3(torch.cat((rev3, out3), 1)) # 96+96 + # comb3 = self.drop(comb3) rev2 = self.path2(comb3) - - feat = self.back2(torch.cat((rev2, out2,coord), 1))#64+64 + + feat = self.back2(torch.cat((rev2, out2, coord), 1)) # 64+64 comb2 = self.drop(feat) out = self.output(comb2) size = out.size() out = out.view(out.size(0), out.size(1), -1) - #out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() + # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5) - #out = out.view(-1, 5) + # out = out.view(-1, 5) return out - + + def get_model(): net = Net() loss = Loss(config['num_hard']) diff --git a/preprocessing/__init__.py b/preprocessing/__init__.py index 9c31eaa..8b13789 100644 --- a/preprocessing/__init__.py +++ b/preprocessing/__init__.py @@ -1 +1 @@ -from full_prep import full_prep,savenpy + diff --git a/preprocessing/full_prep.py b/preprocessing/full_prep.py index d751600..31f6d04 100644 --- a/preprocessing/full_prep.py +++ b/preprocessing/full_prep.py @@ -1,127 +1,128 @@ import os +import warnings +from functools import partial +from multiprocessing import Pool + import numpy as np -from scipy.io import loadmat -import h5py from scipy.ndimage.interpolation import zoom -from skimage import measure -import warnings -from scipy.ndimage.morphology import binary_dilation,generate_binary_structure +from scipy.ndimage.morphology import binary_dilation, generate_binary_structure from skimage.morphology import convex_hull_image -from multiprocessing import Pool -from functools import partial from step1 import step1_python -import warnings + def process_mask(mask): convex_mask = np.copy(mask) for i_layer in range(convex_mask.shape[0]): - mask1 = np.ascontiguousarray(mask[i_layer]) - if np.sum(mask1)>0: + mask1 = np.ascontiguousarray(mask[i_layer]) + if np.sum(mask1) > 0: mask2 = convex_hull_image(mask1) - if np.sum(mask2)>2*np.sum(mask1): + if np.sum(mask2) > 2 * np.sum(mask1): mask2 = mask1 else: mask2 = mask1 convex_mask[i_layer] = mask2 - struct = generate_binary_structure(3,1) - dilatedMask = binary_dilation(convex_mask,structure=struct,iterations=10) + struct = generate_binary_structure(3, 1) + dilatedMask = binary_dilation(convex_mask, structure=struct, iterations=10) return dilatedMask + # def savenpy(id): id = 1 + def lumTrans(img): - lungwin = np.array([-1200.,600.]) - newimg = (img-lungwin[0])/(lungwin[1]-lungwin[0]) - newimg[newimg<0]=0 - newimg[newimg>1]=1 - newimg = (newimg*255).astype('uint8') + lungwin = np.array([-1200., 600.]) + newimg = (img - lungwin[0]) / (lungwin[1] - lungwin[0]) + newimg[newimg < 0] = 0 + newimg[newimg > 1] = 1 + newimg = (newimg * 255).astype('uint8') return newimg -def resample(imgs, spacing, new_spacing,order = 2): - if len(imgs.shape)==3: + +def resample(imgs, spacing, new_spacing, order=2): + if len(imgs.shape) == 3: new_shape = np.round(imgs.shape * spacing / new_spacing) true_spacing = spacing * imgs.shape / new_shape resize_factor = new_shape / imgs.shape with warnings.catch_warnings(): warnings.simplefilter("ignore") - imgs = zoom(imgs, resize_factor, mode = 'nearest',order=order) + imgs = zoom(imgs, resize_factor, mode='nearest', order=order) return imgs, true_spacing - elif len(imgs.shape)==4: + elif len(imgs.shape) == 4: n = imgs.shape[-1] newimg = [] for i in range(n): - slice = imgs[:,:,:,i] - newslice,true_spacing = resample(slice,spacing,new_spacing) + slice = imgs[:, :, :, i] + newslice, true_spacing = resample(slice, spacing, new_spacing) newimg.append(newslice) - newimg=np.transpose(np.array(newimg),[1,2,3,0]) - return newimg,true_spacing + newimg = np.transpose(np.array(newimg), [1, 2, 3, 0]) + return newimg, true_spacing else: raise ValueError('wrong shape') -def savenpy(id,filelist,prep_folder,data_path,use_existing=True): - resolution = np.array([1,1,1]) + +def savenpy(id, filelist, prep_folder, data_path, use_existing=True): + resolution = np.array([1, 1, 1]) name = filelist[id] if use_existing: - if os.path.exists(os.path.join(prep_folder,name+'_label.npy')) and os.path.exists(os.path.join(prep_folder,name+'_clean.npy')): - print(name+' had been done') + if os.path.exists(os.path.join(prep_folder, name + '_label.npy')) and os.path.exists( + os.path.join(prep_folder, name + '_clean.npy')): + print(name + ' had been done') return try: - im, m1, m2, spacing = step1_python(os.path.join(data_path,name)) - Mask = m1+m2 - - newshape = np.round(np.array(Mask.shape)*spacing/resolution) - xx,yy,zz= np.where(Mask) - box = np.array([[np.min(xx),np.max(xx)],[np.min(yy),np.max(yy)],[np.min(zz),np.max(zz)]]) - box = box*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) + im, m1, m2, spacing = step1_python(os.path.join(data_path, name)) + Mask = m1 + m2 + + newshape = np.round(np.array(Mask.shape) * spacing / resolution) + xx, yy, zz = np.where(Mask) + box = np.array([[np.min(xx), np.max(xx)], [np.min(yy), np.max(yy)], [np.min(zz), np.max(zz)]]) + box = box * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1) box = np.floor(box).astype('int') margin = 5 - extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T + extendbox = np.vstack( + [np.max([[0, 0, 0], box[:, 0] - margin], 0), np.min([newshape, box[:, 1] + 2 * margin], axis=0).T]).T extendbox = extendbox.astype('int') - - convex_mask = m1 dm1 = process_mask(m1) dm2 = process_mask(m2) - dilatedMask = dm1+dm2 - Mask = m1+m2 + dilatedMask = dm1 + dm2 + Mask = m1 + m2 extramask = dilatedMask ^ Mask bone_thresh = 210 pad_value = 170 - im[np.isnan(im)]=-2000 + im[np.isnan(im)] = -2000 sliceim = lumTrans(im) - sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8') - bones = sliceim*extramask>bone_thresh + sliceim = sliceim * dilatedMask + pad_value * (1 - dilatedMask).astype('uint8') + bones = sliceim * extramask > bone_thresh sliceim[bones] = pad_value - sliceim1,_ = resample(sliceim,spacing,resolution,order=1) - sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1], - extendbox[1,0]:extendbox[1,1], - extendbox[2,0]:extendbox[2,1]] - sliceim = sliceim2[np.newaxis,...] - np.save(os.path.join(prep_folder,name+'_clean'),sliceim) - np.save(os.path.join(prep_folder,name+'_label'),np.array([[0,0,0,0]])) + sliceim1, _ = resample(sliceim, spacing, resolution, order=1) + sliceim2 = sliceim1[extendbox[0, 0]:extendbox[0, 1], + extendbox[1, 0]:extendbox[1, 1], + extendbox[2, 0]:extendbox[2, 1]] + sliceim = sliceim2[np.newaxis, ...] + np.save(os.path.join(prep_folder, name + '_clean'), sliceim) + np.save(os.path.join(prep_folder, name + '_label'), np.array([[0, 0, 0, 0]])) except: - print('bug in '+name) + print('bug in ' + name) raise - print(name+' done') + print(name + ' done') + - -def full_prep(data_path,prep_folder,n_worker = None,use_existing=True): +def full_prep(data_path, prep_folder, n_worker=None, use_existing=True): warnings.filterwarnings("ignore") if not os.path.exists(prep_folder): os.mkdir(prep_folder) - print('starting preprocessing') pool = Pool(n_worker) filelist = [f for f in os.listdir(data_path)] - partial_savenpy = partial(savenpy,filelist=filelist,prep_folder=prep_folder, - data_path=data_path,use_existing=use_existing) + partial_savenpy = partial(savenpy, filelist=filelist, prep_folder=prep_folder, + data_path=data_path, use_existing=use_existing) N = len(filelist) - _=pool.map(partial_savenpy,range(N)) + _ = pool.map(partial_savenpy, range(N)) pool.close() pool.join() print('end preprocessing') diff --git a/preprocessing/step1.py b/preprocessing/step1.py index ae75cc5..1fd1f96 100644 --- a/preprocessing/step1.py +++ b/preprocessing/step1.py @@ -1,71 +1,74 @@ -import numpy as np # linear algebra -import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) -import dicom import os -import scipy.ndimage -import matplotlib.pyplot as plt - -from skimage import measure, morphology +import dicom +import matplotlib.pyplot as plt +import numpy as np # linear algebra +import scipy.ndimage +from skimage import measure def load_scan(path): slices = [dicom.read_file(path + '/' + s) for s in os.listdir(path)] - slices.sort(key = lambda x: float(x.ImagePositionPatient[2])) + slices.sort(key=lambda x: float(x.ImagePositionPatient[2])) if slices[0].ImagePositionPatient[2] == slices[1].ImagePositionPatient[2]: sec_num = 2; while slices[0].ImagePositionPatient[2] == slices[sec_num].ImagePositionPatient[2]: - sec_num = sec_num+1; + sec_num = sec_num + 1; slice_num = int(len(slices) / sec_num) - slices.sort(key = lambda x:float(x.InstanceNumber)) + slices.sort(key=lambda x: float(x.InstanceNumber)) slices = slices[0:slice_num] - slices.sort(key = lambda x:float(x.ImagePositionPatient[2])) + slices.sort(key=lambda x: float(x.ImagePositionPatient[2])) try: slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2]) except: slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation) - + for s in slices: s.SliceThickness = slice_thickness - + return slices + def get_pixels_hu(slices): image = np.stack([s.pixel_array for s in slices]) # Convert to int16 (from sometimes int16), # should be possible as values should always be low enough (<32k) image = image.astype(np.int16) - + # Convert to Hounsfield units (HU) - for slice_number in range(len(slices)): + for slice_number in range(len(slices)): intercept = slices[slice_number].RescaleIntercept slope = slices[slice_number].RescaleSlope - + if slope != 1: image[slice_number] = slope * image[slice_number].astype(np.float64) image[slice_number] = image[slice_number].astype(np.int16) - + image[slice_number] += np.int16(intercept) - - return np.array(image, dtype=np.int16), np.array([slices[0].SliceThickness] + slices[0].PixelSpacing, dtype=np.float32) + + return np.array(image, dtype=np.int16), np.array([slices[0].SliceThickness] + slices[0].PixelSpacing, + dtype=np.float32) + def binarize_per_slice(image, spacing, intensity_th=-600, sigma=1, area_th=30, eccen_th=0.99, bg_patch_size=10): bw = np.zeros(image.shape, dtype=bool) - + # prepare a mask, with all corner values set to nan image_size = image.shape[1] - grid_axis = np.linspace(-image_size/2+0.5, image_size/2-0.5, image_size) + grid_axis = np.linspace(-image_size / 2 + 0.5, image_size / 2 - 0.5, image_size) x, y = np.meshgrid(grid_axis, grid_axis) - d = (x**2+y**2)**0.5 - nan_mask = (d 0: @@ -87,22 +91,23 @@ def all_slice_analysis(bw, spacing, cut_num=0, vol_limit=[0.68, 8.2], area_th=6e # remove components access to corners mid = int(label.shape[2] / 2) bg_label = set([label[0, 0, 0], label[0, 0, -1], label[0, -1, 0], label[0, -1, -1], \ - label[-1-cut_num, 0, 0], label[-1-cut_num, 0, -1], label[-1-cut_num, -1, 0], label[-1-cut_num, -1, -1], \ - label[0, 0, mid], label[0, -1, mid], label[-1-cut_num, 0, mid], label[-1-cut_num, -1, mid]]) + label[-1 - cut_num, 0, 0], label[-1 - cut_num, 0, -1], label[-1 - cut_num, -1, 0], + label[-1 - cut_num, -1, -1], \ + label[0, 0, mid], label[0, -1, mid], label[-1 - cut_num, 0, mid], label[-1 - cut_num, -1, mid]]) for l in bg_label: label[label == l] = 0 - + # select components based on volume properties = measure.regionprops(label) for prop in properties: if prop.area * spacing.prod() < vol_limit[0] * 1e6 or prop.area * spacing.prod() > vol_limit[1] * 1e6: label[label == prop.label] = 0 - + # prepare a distance map for further analysis - x_axis = np.linspace(-label.shape[1]/2+0.5, label.shape[1]/2-0.5, label.shape[1]) * spacing[1] - y_axis = np.linspace(-label.shape[2]/2+0.5, label.shape[2]/2-0.5, label.shape[2]) * spacing[2] + x_axis = np.linspace(-label.shape[1] / 2 + 0.5, label.shape[1] / 2 - 0.5, label.shape[1]) * spacing[1] + y_axis = np.linspace(-label.shape[2] / 2 + 0.5, label.shape[2] / 2 - 0.5, label.shape[2]) * spacing[2] x, y = np.meshgrid(x_axis, y_axis) - d = (x**2+y**2)**0.5 + d = (x ** 2 + y ** 2) ** 0.5 vols = measure.regionprops(label) valid_label = set() # select components based on their area and distance to center axis on all slices @@ -113,12 +118,12 @@ def all_slice_analysis(bw, spacing, cut_num=0, vol_limit=[0.68, 8.2], area_th=6e for i in range(label.shape[0]): slice_area[i] = np.sum(single_vol[i]) * np.prod(spacing[1:3]) min_distance[i] = np.min(single_vol[i] * d + (1 - single_vol[i]) * np.max(d)) - + if np.average([min_distance[i] for i in range(label.shape[0]) if slice_area[i] > area_th]) < dist_th: valid_label.add(vol.label) - + bw = np.in1d(label, list(valid_label)).reshape(label.shape) - + # fill back the parts removed earlier if cut_num > 0: # bw1 is bw with removed slices, bw2 is a dilated version of bw, part of their intersection is returned as final mask @@ -132,14 +137,15 @@ def all_slice_analysis(bw, spacing, cut_num=0, vol_limit=[0.68, 8.2], area_th=6e l_list = list(set(np.unique(label)) - {0}) valid_l3 = set() for l in l_list: - indices = np.nonzero(label==l) + indices = np.nonzero(label == l) l3 = label3[indices[0][0], indices[1][0], indices[2][0]] if l3 > 0: valid_l3.add(l3) bw = np.in1d(label3, list(valid_l3)).reshape(label3.shape) - + return bw, len(valid_label) + def fill_hole(bw): # fill 3d holes label = measure.label(~bw) @@ -147,13 +153,11 @@ def fill_hole(bw): bg_label = set([label[0, 0, 0], label[0, 0, -1], label[0, -1, 0], label[0, -1, -1], \ label[-1, 0, 0], label[-1, 0, -1], label[-1, -1, 0], label[-1, -1, -1]]) bw = ~np.in1d(label, list(bg_label)).reshape(label.shape) - - return bw - + return bw -def two_lung_only(bw, spacing, max_iter=22, max_ratio=4.8): +def two_lung_only(bw, spacing, max_iter=22, max_ratio=4.8): def extract_main(bw, cover=0.95): for i in range(bw.shape[0]): current_slice = bw[i] @@ -163,22 +167,22 @@ def extract_main(bw, cover=0.95): area = [prop.area for prop in properties] count = 0 sum = 0 - while sum < np.sum(area)*cover: - sum = sum+area[count] - count = count+1 + while sum < np.sum(area) * cover: + sum = sum + area[count] + count = count + 1 filter = np.zeros(current_slice.shape, dtype=bool) for j in range(count): bb = properties[j].bbox filter[bb[0]:bb[2], bb[1]:bb[3]] = filter[bb[0]:bb[2], bb[1]:bb[3]] | properties[j].convex_image bw[i] = bw[i] & filter - + label = measure.label(bw) properties = measure.regionprops(label) properties.sort(key=lambda x: x.area, reverse=True) - bw = label==properties[0].label + bw = label == properties[0].label return bw - + def fill_2d_hole(bw): for i in range(bw.shape[0]): current_slice = bw[i] @@ -190,7 +194,7 @@ def fill_2d_hole(bw): bw[i] = current_slice return bw - + found_flag = False iter_count = 0 bw0 = np.copy(bw) @@ -198,33 +202,34 @@ def fill_2d_hole(bw): label = measure.label(bw, connectivity=2) properties = measure.regionprops(label) properties.sort(key=lambda x: x.area, reverse=True) - if len(properties) > 1 and properties[0].area/properties[1].area < max_ratio: + if len(properties) > 1 and properties[0].area / properties[1].area < max_ratio: found_flag = True bw1 = label == properties[0].label bw2 = label == properties[1].label else: bw = scipy.ndimage.binary_erosion(bw) iter_count = iter_count + 1 - + if found_flag: d1 = scipy.ndimage.morphology.distance_transform_edt(bw1 == False, sampling=spacing) d2 = scipy.ndimage.morphology.distance_transform_edt(bw2 == False, sampling=spacing) bw1 = bw0 & (d1 < d2) bw2 = bw0 & (d1 > d2) - + bw1 = extract_main(bw1) bw2 = extract_main(bw2) - + else: bw1 = bw0 bw2 = np.zeros(bw.shape).astype('bool') - + bw1 = fill_2d_hole(bw1) bw2 = fill_2d_hole(bw2) bw = bw1 | bw2 return bw1, bw2, bw + def step1_python(case_path): case = load_scan(case_path) case_pixels, spacing = get_pixels_hu(case) @@ -235,37 +240,38 @@ def step1_python(case_path): bw0 = np.copy(bw) while flag == 0 and cut_num < bw.shape[0]: bw = np.copy(bw0) - bw, flag = all_slice_analysis(bw, spacing, cut_num=cut_num, vol_limit=[0.68,7.5]) + bw, flag = all_slice_analysis(bw, spacing, cut_num=cut_num, vol_limit=[0.68, 7.5]) cut_num = cut_num + cut_step bw = fill_hole(bw) bw1, bw2, bw = two_lung_only(bw, spacing) return case_pixels, bw1, bw2, spacing - + + if __name__ == '__main__': INPUT_FOLDER = '/work/DataBowl3/stage1/stage1/' patients = os.listdir(INPUT_FOLDER) patients.sort() - case_pixels, m1, m2, spacing = step1_python(os.path.join(INPUT_FOLDER,patients[25])) + case_pixels, m1, m2, spacing = step1_python(os.path.join(INPUT_FOLDER, patients[25])) plt.imshow(m1[60]) plt.figure() plt.imshow(m2[60]) -# first_patient = load_scan(INPUT_FOLDER + patients[25]) +# first_patient = load_scan(INPUT_FOLDER + patients[25]) # first_patient_pixels, spacing = get_pixels_hu(first_patient) # plt.hist(first_patient_pixels.flatten(), bins=80, color='c') # plt.xlabel("Hounsfield Units (HU)") # plt.ylabel("Frequency") # plt.show() - + # # Show some slice in the middle # h = 80 # plt.imshow(first_patient_pixels[h], cmap=plt.cm.gray) # plt.show() - + # bw = binarize_per_slice(first_patient_pixels, spacing) # plt.imshow(bw[h], cmap=plt.cm.gray) # plt.show() - + # flag = 0 # cut_num = 0 # while flag == 0: @@ -273,11 +279,11 @@ def step1_python(case_path): # cut_num = cut_num + 1 # plt.imshow(bw[h], cmap=plt.cm.gray) # plt.show() - + # bw = fill_hole(bw) # plt.imshow(bw[h], cmap=plt.cm.gray) # plt.show() - + # bw1, bw2, bw = two_lung_only(bw, spacing) # plt.imshow(bw[h], cmap=plt.cm.gray) # plt.show() diff --git a/split_combine.py b/split_combine.py index 3083744..028d3f0 100644 --- a/split_combine.py +++ b/split_combine.py @@ -1,24 +1,25 @@ -import torch import numpy as np + + class SplitComb(): - def __init__(self,side_len,max_stride,stride,margin,pad_value): + def __init__(self, side_len, max_stride, stride, margin, pad_value): self.side_len = side_len self.max_stride = max_stride self.stride = stride self.margin = margin self.pad_value = pad_value - - def split(self, data, side_len = None, max_stride = None, margin = None): - if side_len==None: + + def split(self, data, side_len=None, max_stride=None, margin=None): + if side_len == None: side_len = self.side_len if max_stride == None: max_stride = self.max_stride if margin == None: margin = self.margin - - assert(side_len > margin) - assert(side_len % max_stride == 0) - assert(margin % max_stride == 0) + + assert (side_len > margin) + assert (side_len % max_stride == 0) + assert (margin % max_stride == 0) splits = [] _, z, h, w = data.shape @@ -26,14 +27,14 @@ def split(self, data, side_len = None, max_stride = None, margin = None): nz = int(np.ceil(float(z) / side_len)) nh = int(np.ceil(float(h) / side_len)) nw = int(np.ceil(float(w) / side_len)) - - nzhw = [nz,nh,nw] + + nzhw = [nz, nh, nw] self.nzhw = nzhw - - pad = [ [0, 0], - [margin, nz * side_len - z + margin], - [margin, nh * side_len - h + margin], - [margin, nw * side_len - w + margin]] + + pad = [[0, 0], + [margin, nz * side_len - z + margin], + [margin, nh * side_len - h + margin], + [margin, nw * side_len - w + margin]] data = np.pad(data, pad, 'edge') for iz in range(nz): @@ -50,11 +51,11 @@ def split(self, data, side_len = None, max_stride = None, margin = None): splits.append(split) splits = np.concatenate(splits, 0) - return splits,nzhw + return splits, nzhw + + def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None): - def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): - - if side_len==None: + if side_len == None: side_len = self.side_len if stride == None: stride = self.stride @@ -65,9 +66,9 @@ def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): nh = self.nh nw = self.nw else: - nz,nh,nw = nzhw - assert(side_len % stride == 0) - assert(margin % stride == 0) + nz, nh, nw = nzhw + assert (side_len % stride == 0) + assert (margin % stride == 0) side_len /= stride margin /= stride @@ -97,4 +98,4 @@ def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): output[sz:ez, sh:eh, sw:ew] = split idx += 1 - return output + return output diff --git a/test_detect.py b/test_detect.py index 70163a0..4d9545a 100644 --- a/test_detect.py +++ b/test_detect.py @@ -1,23 +1,11 @@ -import argparse -import os import time -import numpy as np -from importlib import import_module -import shutil -from utils import * -import sys -from split_combine import SplitComb -import torch -from torch.nn import DataParallel -from torch.backends import cudnn -from torch.utils.data import DataLoader -from torch import optim from torch.autograd import Variable -from layers import acc +from utils import * + -def test_detect(data_loader, net, get_pbb, save_dir, config,n_gpu): +def test_detect(data_loader, net, get_pbb, save_dir, config, n_gpu): start_time = time.time() net.eval() split_comber = data_loader.dataset.split_comber @@ -36,41 +24,40 @@ def test_detect(data_loader, net, get_pbb, save_dir, config,n_gpu): isfeat = True n_per_run = n_gpu print(data.size()) - splitlist = range(0,len(data)+1,n_gpu) - if splitlist[-1]!=len(data): + splitlist = range(0, len(data) + 1, n_gpu) + if splitlist[-1] != len(data): splitlist.append(len(data)) outputlist = [] featurelist = [] - for i in range(len(splitlist)-1): - input = Variable(data[splitlist[i]:splitlist[i+1]], volatile = True).cuda() - inputcoord = Variable(coord[splitlist[i]:splitlist[i+1]], volatile = True).cuda() + for i in range(len(splitlist) - 1): + input = Variable(data[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() + inputcoord = Variable(coord[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() if isfeat: - output,feature = net(input,inputcoord) + output, feature = net(input, inputcoord) featurelist.append(feature.data.cpu().numpy()) else: - output = net(input,inputcoord) + output = net(input, inputcoord) outputlist.append(output.data.cpu().numpy()) - output = np.concatenate(outputlist,0) - output = split_comber.combine(output,nzhw=nzhw) + output = np.concatenate(outputlist, 0) + output = split_comber.combine(output, nzhw=nzhw) if isfeat: - feature = np.concatenate(featurelist,0).transpose([0,2,3,4,1])[:,:,:,:,:,np.newaxis] - feature = split_comber.combine(feature,sidelen)[...,0] + feature = np.concatenate(featurelist, 0).transpose([0, 2, 3, 4, 1])[:, :, :, :, :, np.newaxis] + feature = split_comber.combine(feature, sidelen)[..., 0] thresh = -3 - pbb,mask = get_pbb(output,thresh,ismask=True) + pbb, mask = get_pbb(output, thresh, ismask=True) if isfeat: - feature_selected = feature[mask[0],mask[1],mask[2]] - np.save(os.path.join(save_dir, shortname+'_feature.npy'), feature_selected) - #tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) - #print([len(tp),len(fp),len(fn)]) - print([i_name,shortname]) + feature_selected = feature[mask[0], mask[1], mask[2]] + np.save(os.path.join(save_dir, shortname + '_feature.npy'), feature_selected) + # tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) + # print([len(tp),len(fp),len(fn)]) + print([i_name, shortname]) e = time.time() - - np.save(os.path.join(save_dir, shortname+'_pbb.npy'), pbb) - np.save(os.path.join(save_dir, shortname+'_lbb.npy'), lbb) - end_time = time.time() + np.save(os.path.join(save_dir, shortname + '_pbb.npy'), pbb) + np.save(os.path.join(save_dir, shortname + '_lbb.npy'), lbb) + end_time = time.time() print('elapsed time is %3.2f seconds' % (end_time - start_time)) print diff --git a/training/classifier/adapt_ckpt.py b/training/classifier/adapt_ckpt.py index c2361e1..576c4e5 100644 --- a/training/classifier/adapt_ckpt.py +++ b/training/classifier/adapt_ckpt.py @@ -1,7 +1,8 @@ -import torch -import numpy as np import argparse from importlib import import_module + +import torch + parser = argparse.ArgumentParser(description='network surgery') parser.add_argument('--model1', '-m1', metavar='MODEL', default='base', help='model') @@ -24,6 +25,6 @@ args.lr_stage2 = config2['lr_stage'] args.lr_preset2 = config2['lr'] topk = config2['topk'] -case_net = casemodel.CaseNet(topk = topk,nodulenet=nod_net) +case_net = casemodel.CaseNet(topk=topk, nodulenet=nod_net) new_state_dict = case_net.state_dict() -torch.save({'state_dict': new_state_dict,'epoch':0},'results/start.ckpt') +torch.save({'state_dict': new_state_dict, 'epoch': 0}, 'results/start.ckpt') diff --git a/training/classifier/data_classifier.py b/training/classifier/data_classifier.py index d8335c9..d9c0dea 100644 --- a/training/classifier/data_classifier.py +++ b/training/classifier/data_classifier.py @@ -1,51 +1,51 @@ -import numpy as np -import torch -from torch.utils.data import Dataset import os import time -import collections -import random -from layers import iou -from scipy.ndimage import zoom import warnings -from scipy.ndimage.interpolation import rotate -from layers import nms,iou + +import numpy as np import pandas +import torch +from scipy.ndimage import zoom +from scipy.ndimage.interpolation import rotate +from torch.utils.data import Dataset + +from layers import nms, iou + class DataBowl3Classifier(Dataset): - def __init__(self, split, config, phase = 'train'): - assert(phase == 'train' or phase == 'val' or phase == 'test') - + def __init__(self, split, config, phase='train'): + assert (phase == 'train' or phase == 'val' or phase == 'test') + self.random_sample = config['random_sample'] self.T = config['T'] self.topk = config['topk'] self.crop_size = config['crop_size'] self.stride = config['stride'] - self.augtype = config['augtype'] - #self.labels = np.array(pandas.read_csv(config['labelfile'])) - + self.augtype = config['augtype'] + # self.labels = np.array(pandas.read_csv(config['labelfile'])) + datadir = config['datadir'] - bboxpath = config['bboxpath'] + bboxpath = config['bboxpath'] self.phase = phase self.candidate_box = [] self.pbb_label = [] - + idcs = split self.filenames = [os.path.join(datadir, '%s_clean.npy' % idx) for idx in idcs] labels = np.array(pandas.read_csv(config['labelfile'])) - if phase !='test': - self.yset = np.array([labels[labels[:,0]==f.split('-')[0].split('_')[0],1] for f in split]).astype('int') + if phase != 'test': + self.yset = np.array([labels[labels[:, 0] == f.split('-')[0].split('_')[0], 1] for f in split]).astype( + 'int') idcs = [f.split('-')[0] for f in idcs] - - + for idx in idcs: - pbb = np.load(os.path.join(bboxpath,idx+'_pbb.npy')) - pbb = pbb[pbb[:,0]>config['conf_th']] + pbb = np.load(os.path.join(bboxpath, idx + '_pbb.npy')) + pbb = pbb[pbb[:, 0] > config['conf_th']] pbb = nms(pbb, config['nms_th']) - - lbb = np.load(os.path.join(bboxpath,idx+'_lbb.npy')) + + lbb = np.load(os.path.join(bboxpath, idx + '_lbb.npy')) pbb_label = [] - + for p in pbb: isnod = False for l in lbb: @@ -54,165 +54,171 @@ def __init__(self, split, config, phase = 'train'): isnod = True break pbb_label.append(isnod) -# if idx.startswith() + # if idx.startswith() self.candidate_box.append(pbb) self.pbb_label.append(np.array(pbb_label)) - self.crop = simpleCrop(config,phase) - + self.crop = simpleCrop(config, phase) - def __getitem__(self, idx,split=None): + def __getitem__(self, idx, split=None): t = time.time() - np.random.seed(int(str(t%1)[2:7]))#seed according to time + np.random.seed(int(str(t % 1)[2:7])) # seed according to time pbb = self.candidate_box[idx] pbb_label = self.pbb_label[idx] - conf_list = pbb[:,0] + conf_list = pbb[:, 0] T = self.T topk = self.topk img = np.load(self.filenames[idx]) - if self.random_sample and self.phase=='train': - chosenid = sample(conf_list,topk,T=T) - #chosenid = conf_list.argsort()[::-1][:topk] + if self.random_sample and self.phase == 'train': + chosenid = sample(conf_list, topk, T=T) + # chosenid = conf_list.argsort()[::-1][:topk] else: chosenid = conf_list.argsort()[::-1][:topk] - croplist = np.zeros([topk,1,self.crop_size[0],self.crop_size[1],self.crop_size[2]]).astype('float32') - coordlist = np.zeros([topk,3,self.crop_size[0]/self.stride,self.crop_size[1]/self.stride,self.crop_size[2]/self.stride]).astype('float32') - padmask = np.concatenate([np.ones(len(chosenid)),np.zeros(self.topk-len(chosenid))]) + croplist = np.zeros([topk, 1, self.crop_size[0], self.crop_size[1], self.crop_size[2]]).astype('float32') + coordlist = np.zeros([topk, 3, self.crop_size[0] / self.stride, self.crop_size[1] / self.stride, + self.crop_size[2] / self.stride]).astype('float32') + padmask = np.concatenate([np.ones(len(chosenid)), np.zeros(self.topk - len(chosenid))]) isnodlist = np.zeros([topk]) - - for i,id in enumerate(chosenid): - target = pbb[id,1:] + for i, id in enumerate(chosenid): + target = pbb[id, 1:] isnod = pbb_label[id] - crop,coord = self.crop(img,target) - if self.phase=='train': - crop,coord = augment(crop,coord, - ifflip=self.augtype['flip'],ifrotate=self.augtype['rotate'], - ifswap = self.augtype['swap']) + crop, coord = self.crop(img, target) + if self.phase == 'train': + crop, coord = augment(crop, coord, + ifflip=self.augtype['flip'], ifrotate=self.augtype['rotate'], + ifswap=self.augtype['swap']) crop = crop.astype(np.float32) croplist[i] = crop coordlist[i] = coord isnodlist[i] = isnod - - if self.phase!='test': + + if self.phase != 'test': y = np.array([self.yset[idx]]) - return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy(isnodlist).int(), torch.from_numpy(y) + return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy( + isnodlist).int(), torch.from_numpy(y) else: - return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy(isnodlist).int() + return torch.from_numpy(croplist).float(), torch.from_numpy(coordlist).float(), torch.from_numpy( + isnodlist).int() + def __len__(self): if self.phase != 'test': return len(self.candidate_box) else: return len(self.candidate_box) - - + class simpleCrop(): - def __init__(self,config,phase): + def __init__(self, config, phase): self.crop_size = config['crop_size'] self.scaleLim = config['scaleLim'] self.radiusLim = config['radiusLim'] self.jitter_range = config['jitter_range'] - self.isScale = config['augtype']['scale'] and phase=='train' + self.isScale = config['augtype']['scale'] and phase == 'train' self.stride = config['stride'] self.filling_value = config['filling_value'] self.phase = phase - - def __call__(self,imgs,target): + + def __call__(self, imgs, target): if self.isScale: radiusLim = self.radiusLim scaleLim = self.scaleLim - scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1]) - ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])] - scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0] - crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int') + scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]) + , np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])] + scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0] + crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int') else: crop_size = np.array(self.crop_size).astype('int') - if self.phase=='train': - jitter_range = target[3]*self.jitter_range - jitter = (np.random.rand(3)-0.5)*jitter_range + if self.phase == 'train': + jitter_range = target[3] * self.jitter_range + jitter = (np.random.rand(3) - 0.5) * jitter_range else: jitter = 0 - start = (target[:3]- crop_size/2 + jitter).astype('int') - pad = [[0,0]] + start = (target[:3] - crop_size / 2 + jitter).astype('int') + pad = [[0, 0]] for i in range(3): - if start[i]<0: + if start[i] < 0: leftpad = -start[i] start[i] = 0 else: leftpad = 0 - if start[i]+crop_size[i]>imgs.shape[i+1]: - rightpad = start[i]+crop_size[i]-imgs.shape[i+1] + if start[i] + crop_size[i] > imgs.shape[i + 1]: + rightpad = start[i] + crop_size[i] - imgs.shape[i + 1] else: rightpad = 0 - pad.append([leftpad,rightpad]) - imgs = np.pad(imgs,pad,'constant',constant_values =self.filling_value) - crop = imgs[:,start[0]:start[0]+crop_size[0],start[1]:start[1]+crop_size[1],start[2]:start[2]+crop_size[2]] - - normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5 - normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:]) - xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride), - np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride), - np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + pad.append([leftpad, rightpad]) + imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value) + crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1], + start[2]:start[2] + crop_size[2]] + + normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 + normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) + xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride), + np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride), + np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride), + indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') if self.isScale: with warnings.catch_warnings(): warnings.simplefilter("ignore") - crop = zoom(crop,[1,scale,scale,scale],order=1) - newpad = self.crop_size[0]-crop.shape[1:][0] - if newpad<0: - crop = crop[:,:-newpad,:-newpad,:-newpad] - elif newpad>0: - pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]] - crop = np.pad(crop,pad2,'constant',constant_values =self.filling_value) - - return crop,coord - -def sample(conf,N,T=1): - if len(conf)>N: + crop = zoom(crop, [1, scale, scale, scale], order=1) + newpad = self.crop_size[0] - crop.shape[1:][0] + if newpad < 0: + crop = crop[:, :-newpad, :-newpad, :-newpad] + elif newpad > 0: + pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] + crop = np.pad(crop, pad2, 'constant', constant_values=self.filling_value) + + return crop, coord + + +def sample(conf, N, T=1): + if len(conf) > N: target = range(len(conf)) chosen_list = [] for i in range(N): - chosenidx = sampleone(target,conf,T) + chosenidx = sampleone(target, conf, T) chosen_list.append(target[chosenidx]) target.pop(chosenidx) conf = np.delete(conf, chosenidx) - return chosen_list else: return np.arange(len(conf)) -def sampleone(target,conf,T): - assert len(conf)>1 - p = softmax(conf/T) - p = np.max([np.ones_like(p)*0.00001,p],axis=0) - p = p/np.sum(p) - return np.random.choice(np.arange(len(target)),size=1,replace = False, p=p)[0] + +def sampleone(target, conf, T): + assert len(conf) > 1 + p = softmax(conf / T) + p = np.max([np.ones_like(p) * 0.00001, p], axis=0) + p = p / np.sum(p) + return np.random.choice(np.arange(len(target)), size=1, replace=False, p=p)[0] + def softmax(x): maxx = np.max(x) - return np.exp(x-maxx)/np.sum(np.exp(x-maxx)) + return np.exp(x - maxx) / np.sum(np.exp(x - maxx)) -def augment(sample, coord, ifflip = True, ifrotate=True, ifswap = True): +def augment(sample, coord, ifflip=True, ifrotate=True, ifswap=True): # angle1 = np.random.rand()*180 if ifrotate: validrot = False counter = 0 - angle1 = np.random.rand()*180 + angle1 = np.random.rand() * 180 size = np.array(sample.shape[2:4]).astype('float') - rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) - sample = rotate(sample,angle1,axes=(2,3),reshape=False) + rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)], + [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]]) + sample = rotate(sample, angle1, axes=(2, 3), reshape=False) if ifswap: - if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: + if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]: axisorder = np.random.permutation(3) - sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) - coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) - + sample = np.transpose(sample, np.concatenate([[0], axisorder + 1])) + coord = np.transpose(coord, np.concatenate([[0], axisorder + 1])) + if ifflip: - flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 - sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) - coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) - return sample, coord + flipid = np.array([np.random.randint(2), np.random.randint(2), np.random.randint(2)]) * 2 - 1 + sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + return sample, coord diff --git a/training/classifier/data_detector.py b/training/classifier/data_detector.py index 199c69b..893ef50 100644 --- a/training/classifier/data_detector.py +++ b/training/classifier/data_detector.py @@ -1,105 +1,108 @@ -import numpy as np -import torch -from torch.utils.data import Dataset -import os -import time import collections +import os import random -from layers import iou -from scipy.ndimage import zoom +import time import warnings + +import numpy as np +import torch +from scipy.ndimage import zoom from scipy.ndimage.interpolation import rotate -from scipy.ndimage.morphology import binary_dilation,generate_binary_structure +from scipy.ndimage.morphology import binary_dilation, generate_binary_structure +from torch.utils.data import Dataset + class DataBowl3Detector(Dataset): - def __init__(self, split, config, phase = 'train',split_comber=None): - assert(phase == 'train' or phase == 'val' or phase == 'test') + def __init__(self, split, config, phase='train', split_comber=None): + assert (phase == 'train' or phase == 'val' or phase == 'test') self.phase = phase - self.max_stride = config['max_stride'] - self.stride = config['stride'] - sizelim = config['sizelim']/config['reso'] - sizelim2 = config['sizelim2']/config['reso'] - sizelim3 = config['sizelim3']/config['reso'] + self.max_stride = config['max_stride'] + self.stride = config['stride'] + sizelim = config['sizelim'] / config['reso'] + sizelim2 = config['sizelim2'] / config['reso'] + sizelim3 = config['sizelim3'] / config['reso'] self.blacklist = config['blacklist'] self.isScale = config['aug_scale'] self.r_rand = config['r_rand_crop'] self.augtype = config['augtype'] data_dir = config['datadir'] self.pad_value = config['pad_value'] - + self.split_comber = split_comber idcs = split - if phase!='test': + if phase != 'test': idcs = [f for f in idcs if f not in self.blacklist] self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs] - self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])>20] - self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])<20] - + self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) > 20] + self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) < 20] + labels = [] - + for idx in idcs: - l = np.load(os.path.join(data_dir, '%s_label.npy' %idx)) - if np.all(l==0): - l=np.array([]) + l = np.load(os.path.join(data_dir, '%s_label.npy' % idx)) + if np.all(l == 0): + l = np.array([]) labels.append(l) self.sample_bboxes = labels - + self.bboxes = [] for i, l in enumerate(labels): - if len(l) > 0 : + if len(l) > 0: for t in l: - if t[3]>sizelim: - self.bboxes.append([np.concatenate([[i],t])]) - if t[3]>sizelim2: - self.bboxes+=[[np.concatenate([[i],t])]]*2 - if t[3]>sizelim3: - self.bboxes+=[[np.concatenate([[i],t])]]*4 - if len(self.bboxes)>0: - self.bboxes = np.concatenate(self.bboxes,axis = 0) + if t[3] > sizelim: + self.bboxes.append([np.concatenate([[i], t])]) + if t[3] > sizelim2: + self.bboxes += [[np.concatenate([[i], t])]] * 2 + if t[3] > sizelim3: + self.bboxes += [[np.concatenate([[i], t])]] * 4 + if len(self.bboxes) > 0: + self.bboxes = np.concatenate(self.bboxes, axis=0) else: self.bboxes = np.array(self.bboxes) self.crop = Crop(config) self.label_mapping = LabelMapping(config, self.phase) - def __getitem__(self, idx,split=None): + def __getitem__(self, idx, split=None): t = time.time() - np.random.seed(int(str(t%1)[2:7]))#seed according to time + np.random.seed(int(str(t % 1)[2:7])) # seed according to time - isRandomImg = False - if self.phase !='test': - if idx>=len(self.bboxes): + isRandomImg = False + if self.phase != 'test': + if idx >= len(self.bboxes): isRandom = True - idx = idx%len(self.bboxes) + idx = idx % len(self.bboxes) isRandomImg = np.random.randint(2) else: isRandom = False else: isRandom = False - + if self.phase != 'test': if not isRandomImg: bbox = self.bboxes[idx] filename = self.filenames[int(bbox[0])] imgs = np.load(filename) bboxes = self.sample_bboxes[int(bbox[0])] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes,isScale,isRandom) - if self.phase=='train' and not isRandom: - sample, target, bboxes, coord = augment(sample, target, bboxes, coord, - ifflip = self.augtype['flip'], ifrotate=self.augtype['rotate'], ifswap = self.augtype['swap']) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale, isRandom) + if self.phase == 'train' and not isRandom: + sample, target, bboxes, coord = augment(sample, target, bboxes, coord, + ifflip=self.augtype['flip'], + ifrotate=self.augtype['rotate'], + ifswap=self.augtype['swap']) else: randimid = np.random.randint(len(self.kagglenames)) filename = self.kagglenames[randimid] imgs = np.load(filename) bboxes = self.sample_bboxes[randimid] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, [], bboxes,isScale=False,isRand=True) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True) label = self.label_mapping(sample.shape[1:], target, bboxes) sample = sample.astype(np.float32) - #if filename in self.kagglenames: - # label[label==-1]=0 + # if filename in self.kagglenames: + # label[label==-1]=0 return torch.from_numpy(sample), torch.from_numpy(label), coord else: imgs = np.load(self.filenames[idx]) @@ -108,147 +111,151 @@ def __getitem__(self, idx,split=None): pz = int(np.ceil(float(nz) / self.stride)) * self.stride ph = int(np.ceil(float(nh) / self.stride)) * self.stride pw = int(np.ceil(float(nw) / self.stride)) * self.stride - imgs = np.pad(imgs, [[0,0],[0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',constant_value = self.pad_value) - xx,yy,zz = np.meshgrid(np.linspace(-0.5,0.5,imgs.shape[1]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[2]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[3]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant', + constant_value=self.pad_value) + xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[2] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[3] / self.stride), indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') imgs, nzhw = self.split_comber.split(imgs) coord2, nzhw2 = self.split_comber.split(coord, - side_len = self.split_comber.side_len/self.stride, - max_stride = self.split_comber.max_stride/self.stride, - margin = self.split_comber.margin/self.stride) - assert np.all(nzhw==nzhw2) + side_len=self.split_comber.side_len / self.stride, + max_stride=self.split_comber.max_stride / self.stride, + margin=self.split_comber.margin / self.stride) + assert np.all(nzhw == nzhw2) return torch.from_numpy(imgs), bboxes, torch.from_numpy(coord2), np.array(nzhw) def __len__(self): if self.phase == 'train': - return len(self.bboxes)/(1-self.r_rand) - elif self.phase =='val': + return len(self.bboxes) / (1 - self.r_rand) + elif self.phase == 'val': return len(self.bboxes) else: return len(self.sample_bboxes) - - -def augment(sample, target, bboxes, coord, ifflip = True, ifrotate=True, ifswap = True): + + +def augment(sample, target, bboxes, coord, ifflip=True, ifrotate=True, ifswap=True): # angle1 = np.random.rand()*180 if ifrotate: validrot = False counter = 0 while not validrot: newtarget = np.copy(target) - angle1 = (np.random.rand()-0.5)*20 + angle1 = (np.random.rand() - 0.5) * 20 size = np.array(sample.shape[2:4]).astype('float') - rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) - newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2 - if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]): + rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)], + [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]]) + newtarget[1:3] = np.dot(rotmat, target[1:3] - size / 2) + size / 2 + if np.all(newtarget[:3] > target[3]) and np.all(newtarget[:3] < np.array(sample.shape[1:4]) - newtarget[3]): validrot = True target = newtarget - sample = rotate(sample,angle1,axes=(2,3),reshape=False) - coord = rotate(coord,angle1,axes=(2,3),reshape=False) + sample = rotate(sample, angle1, axes=(2, 3), reshape=False) + coord = rotate(coord, angle1, axes=(2, 3), reshape=False) for box in bboxes: - box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2 + box[1:3] = np.dot(rotmat, box[1:3] - size / 2) + size / 2 else: counter += 1 - if counter ==3: + if counter == 3: break if ifswap: - if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: + if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]: axisorder = np.random.permutation(3) - sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) - coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) + sample = np.transpose(sample, np.concatenate([[0], axisorder + 1])) + coord = np.transpose(coord, np.concatenate([[0], axisorder + 1])) target[:3] = target[:3][axisorder] - bboxes[:,:3] = bboxes[:,:3][:,axisorder] - + bboxes[:, :3] = bboxes[:, :3][:, axisorder] + if ifflip: -# flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 - flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1 - sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) - coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) + # flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 + flipid = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1 + sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]]) for ax in range(3): - if flipid[ax]==-1: - target[ax] = np.array(sample.shape[ax+1])-target[ax] - bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax] - return sample, target, bboxes, coord + if flipid[ax] == -1: + target[ax] = np.array(sample.shape[ax + 1]) - target[ax] + bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax] + return sample, target, bboxes, coord + class Crop(object): def __init__(self, config): self.crop_size = config['crop_size'] self.bound_size = config['bound_size'] self.stride = config['stride'] - self.pad_value = config['pad_value'] + self.pad_value = config['pad_value'] - def __call__(self, imgs, target, bboxes,isScale=False,isRand=False): + def __call__(self, imgs, target, bboxes, isScale=False, isRand=False): if isScale: - radiusLim = [8.,100.] - scaleLim = [0.75,1.25] - scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1]) - ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])] - scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0] - crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int') + radiusLim = [8., 100.] + scaleLim = [0.75, 1.25] + scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]) + , np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])] + scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0] + crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int') else: - crop_size=self.crop_size + crop_size = self.crop_size bound_size = self.bound_size target = np.copy(target) bboxes = np.copy(bboxes) - + start = [] for i in range(3): if not isRand: r = target[3] / 2 - s = np.floor(target[i] - r)+ 1 - bound_size - e = np.ceil (target[i] + r)+ 1 + bound_size - crop_size[i] + s = np.floor(target[i] - r) + 1 - bound_size + e = np.ceil(target[i] + r) + 1 + bound_size - crop_size[i] else: - s = np.max([imgs.shape[i+1]-crop_size[i]/2,imgs.shape[i+1]/2+bound_size]) - e = np.min([crop_size[i]/2, imgs.shape[i+1]/2-bound_size]) - target = np.array([np.nan,np.nan,np.nan,np.nan]) - if s>e: - start.append(np.random.randint(e,s))#! + s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + bound_size]) + e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - bound_size]) + target = np.array([np.nan, np.nan, np.nan, np.nan]) + if s > e: + start.append(np.random.randint(e, s)) # ! else: - start.append(int(target[i])-crop_size[i]/2+np.random.randint(-bound_size/2,bound_size/2)) - - - normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5 - normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:]) - xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride), - np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride), - np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + start.append(int(target[i]) - crop_size[i] / 2 + np.random.randint(-bound_size / 2, bound_size / 2)) + + normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 + normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) + xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride), + np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride), + np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride), + indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') pad = [] - pad.append([0,0]) + pad.append([0, 0]) for i in range(3): - leftpad = max(0,-start[i]) - rightpad = max(0,start[i]+crop_size[i]-imgs.shape[i+1]) - pad.append([leftpad,rightpad]) + leftpad = max(0, -start[i]) + rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1]) + pad.append([leftpad, rightpad]) crop = imgs[:, - max(start[0],0):min(start[0] + crop_size[0],imgs.shape[1]), - max(start[1],0):min(start[1] + crop_size[1],imgs.shape[2]), - max(start[2],0):min(start[2] + crop_size[2],imgs.shape[3])] - crop = np.pad(crop,pad,'constant',constant_values =self.pad_value) + max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]), + max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]), + max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])] + crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value) for i in range(3): - target[i] = target[i] - start[i] + target[i] = target[i] - start[i] for i in range(len(bboxes)): for j in range(3): - bboxes[i][j] = bboxes[i][j] - start[j] - + bboxes[i][j] = bboxes[i][j] - start[j] + if isScale: with warnings.catch_warnings(): warnings.simplefilter("ignore") - crop = zoom(crop,[1,scale,scale,scale],order=1) - newpad = self.crop_size[0]-crop.shape[1:][0] - if newpad<0: - crop = crop[:,:-newpad,:-newpad,:-newpad] - elif newpad>0: - pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]] - crop = np.pad(crop,pad2,'constant',constant_values =self.pad_value) + crop = zoom(crop, [1, scale, scale, scale], order=1) + newpad = self.crop_size[0] - crop.shape[1:][0] + if newpad < 0: + crop = crop[:, :-newpad, :-newpad, :-newpad] + elif newpad > 0: + pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] + crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value) for i in range(4): - target[i] = target[i]*scale + target[i] = target[i] * scale for i in range(len(bboxes)): for j in range(4): - bboxes[i][j] = bboxes[i][j]*scale + bboxes[i][j] = bboxes[i][j] * scale return crop, target, bboxes, coord - + + class LabelMapping(object): def __init__(self, config, phase): self.stride = np.array(config['stride']) @@ -261,20 +268,19 @@ def __init__(self, config, phase): elif phase == 'val': self.th_pos = config['th_pos_val'] - def __call__(self, input_size, target, bboxes): stride = self.stride num_neg = self.num_neg th_neg = self.th_neg anchors = self.anchors th_pos = self.th_pos - struct = generate_binary_structure(3,1) - + struct = generate_binary_structure(3, 1) + output_size = [] for i in range(3): - assert(input_size[i] % stride == 0) + assert (input_size[i] % stride == 0) output_size.append(input_size[i] / stride) - + label = np.zeros(output_size + [len(anchors), 5], np.float32) offset = ((stride.astype('float')) - 1) / 2 oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) @@ -285,10 +291,10 @@ def __call__(self, input_size, target, bboxes): for i, anchor in enumerate(anchors): iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow) label[iz, ih, iw, i, 0] = 1 - label[:,:,:, i, 0] = binary_dilation(label[:,:,:, i, 0].astype('bool'),structure=struct,iterations=1).astype('float32') - - - label = label-1 + label[:, :, :, i, 0] = binary_dilation(label[:, :, :, i, 0].astype('bool'), structure=struct, + iterations=1).astype('float32') + + label = label - 1 if self.phase == 'train' and self.num_neg > 0: neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1) @@ -310,7 +316,7 @@ def __call__(self, input_size, target, bboxes): ih = np.concatenate(ih, 0) iw = np.concatenate(iw, 0) ia = np.concatenate(ia, 0) - flag = True + flag = True if len(iz) == 0: pos = [] for i in range(3): @@ -326,7 +332,8 @@ def __call__(self, input_size, target, bboxes): dw = (target[2] - ow[pos[2]]) / anchors[pos[3]] dd = np.log(target[3] / anchors[pos[3]]) label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd] - return label + return label + def select_samples(bbox, anchor, th, oz, oh, ow): z, h, w, d = bbox @@ -339,12 +346,12 @@ def select_samples(bbox, anchor, th, oz, oh, ow): e = z + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mz = np.logical_and(oz >= s, oz <= e) iz = np.where(mz)[0] - + s = h - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = h + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mh = np.logical_and(oh >= s, oh <= e) ih = np.where(mh)[0] - + s = w - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = w + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mw = np.logical_and(ow >= s, ow <= e) @@ -352,7 +359,7 @@ def select_samples(bbox, anchor, th, oz, oh, ow): if len(iz) == 0 or len(ih) == 0 or len(iw) == 0: return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) - + lz, lh, lw = len(iz), len(ih), len(iw) iz = iz.reshape((-1, 1, 1)) ih = ih.reshape((1, -1, 1)) @@ -363,36 +370,37 @@ def select_samples(bbox, anchor, th, oz, oh, ow): centers = np.concatenate([ oz[iz].reshape((-1, 1)), oh[ih].reshape((-1, 1)), - ow[iw].reshape((-1, 1))], axis = 1) - + ow[iw].reshape((-1, 1))], axis=1) + r0 = anchor / 2 s0 = centers - r0 e0 = centers + r0 - + r1 = d / 2 s1 = bbox[:3] - r1 s1 = s1.reshape((1, -1)) e1 = bbox[:3] + r1 e1 = e1.reshape((1, -1)) - + overlap = np.maximum(0, np.minimum(e0, e1) - np.maximum(s0, s1)) - + intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2] union = anchor * anchor * anchor + d * d * d - intersection iou = intersection / union mask = iou >= th - #if th > 0.4: - # if np.sum(mask) == 0: - # print(['iou not large', iou.max()]) - # else: - # print(['iou large', iou[mask]]) + # if th > 0.4: + # if np.sum(mask) == 0: + # print(['iou not large', iou.max()]) + # else: + # print(['iou large', iou[mask]]) iz = iz[mask] ih = ih[mask] iw = iw[mask] return iz, ih, iw + def collate(batch): if torch.is_tensor(batch[0]): return [b.unsqueeze(0) for b in batch] @@ -403,4 +411,3 @@ def collate(batch): elif isinstance(batch[0], collections.Iterable): transposed = zip(*batch) return [collate(samples) for samples in transposed] - diff --git a/training/classifier/layers.py b/training/classifier/layers.py index 939b7be..aadb473 100644 --- a/training/classifier/layers.py +++ b/training/classifier/layers.py @@ -2,20 +2,20 @@ import torch from torch import nn -import math + class PostRes2d(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes2d, self).__init__() - self.conv1 = nn.Conv2d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv2d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv2d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv2d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv2d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm2d(n_out)) else: self.shortcut = None @@ -29,23 +29,24 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out - + + class PostRes(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes, self).__init__() - self.conv1 = nn.Conv3d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm3d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv3d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm3d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv3d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm3d(n_out)) else: self.shortcut = None @@ -59,72 +60,73 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out + class Rec3(nn.Module): - def __init__(self, n0, n1, n2, n3, p = 0.0, integrate = True): + def __init__(self, n0, n1, n2, n3, p=0.0, integrate=True): super(Rec3, self).__init__() - + self.block01 = nn.Sequential( - nn.Conv3d(n0, n1, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n0, n1, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) self.block11 = nn.Sequential( - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block21 = nn.Sequential( - nn.ConvTranspose3d(n2, n1, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n2, n1, kernel_size=2, stride=2), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block12 = nn.Sequential( - nn.Conv3d(n1, n2, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n1, n2, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block22 = nn.Sequential( - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block32 = nn.Sequential( - nn.ConvTranspose3d(n3, n2, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n3, n2, kernel_size=2, stride=2), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block23 = nn.Sequential( - nn.Conv3d(n2, n3, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n2, n3, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) self.block33 = nn.Sequential( - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) - self.relu = nn.ReLU(inplace = True) + self.relu = nn.ReLU(inplace=True) self.p = p self.integrate = integrate @@ -146,25 +148,27 @@ def forward(self, x0, x1, x2, x3): return x0, self.relu(out1), self.relu(out2), self.relu(out3) + def hard_mining(neg_output, neg_labels, num_hard): _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output))) neg_output = torch.index_select(neg_output, 0, idcs) neg_labels = torch.index_select(neg_labels, 0, idcs) return neg_output, neg_labels + class Loss(nn.Module): - def __init__(self, num_hard = 0): + def __init__(self, num_hard=0): super(Loss, self).__init__() self.sigmoid = nn.Sigmoid() self.classify_loss = nn.BCELoss() self.regress_loss = nn.SmoothL1Loss() self.num_hard = num_hard - def forward(self, output, labels, train = True): + def forward(self, output, labels, train=True): batch_size = labels.size(0) output = output.view(-1, 5) labels = labels.view(-1, 5) - + pos_idcs = labels[:, 0] > 0.5 pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5) pos_output = output[pos_idcs].view(-1, 5) @@ -173,15 +177,15 @@ def forward(self, output, labels, train = True): neg_idcs = labels[:, 0] < -0.5 neg_output = output[:, 0][neg_idcs] neg_labels = labels[:, 0][neg_idcs] - + if self.num_hard > 0 and train: neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.num_hard * batch_size) neg_prob = self.sigmoid(neg_output) - #classify_loss = self.classify_loss( - # torch.cat((pos_prob, neg_prob), 0), - # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) - if len(pos_output)>0: + # classify_loss = self.classify_loss( + # torch.cat((pos_prob, neg_prob), 0), + # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) + if len(pos_output) > 0: pos_prob = self.sigmoid(pos_output[:, 0]) pz, ph, pw, pd = pos_output[:, 1], pos_output[:, 2], pos_output[:, 3], pos_output[:, 4] lz, lh, lw, ld = pos_labels[:, 1], pos_labels[:, 2], pos_labels[:, 3], pos_labels[:, 4] @@ -193,18 +197,18 @@ def forward(self, output, labels, train = True): self.regress_loss(pd, ld)] regress_losses_data = [l.data[0] for l in regress_losses] classify_loss = 0.5 * self.classify_loss( - pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = (pos_prob.data >= 0.5).sum() pos_total = len(pos_prob) else: - regress_losses = [0,0,0,0] - classify_loss = 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + regress_losses = [0, 0, 0, 0] + classify_loss = 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = 0 pos_total = 0 - regress_losses_data = [0,0,0,0] + regress_losses_data = [0, 0, 0, 0] classify_loss_data = classify_loss.data[0] loss = classify_loss @@ -216,12 +220,13 @@ def forward(self, output, labels, train = True): return [loss, classify_loss_data] + regress_losses_data + [pos_correct, pos_total, neg_correct, neg_total] + class GetPBB(object): def __init__(self, config): self.stride = config['stride'] self.anchors = np.asarray(config['anchors']) - def __call__(self, output,thresh = -3, ismask=False): + def __call__(self, output, thresh=-3, ismask=False): stride = self.stride anchors = self.anchors output = np.copy(output) @@ -230,29 +235,31 @@ def __call__(self, output,thresh = -3, ismask=False): oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride) ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride) - + output[:, :, :, :, 1] = oz.reshape((-1, 1, 1, 1)) + output[:, :, :, :, 1] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 2] = oh.reshape((1, -1, 1, 1)) + output[:, :, :, :, 2] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 3] = ow.reshape((1, 1, -1, 1)) + output[:, :, :, :, 3] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1)) mask = output[..., 0] > thresh - xx,yy,zz,aa = np.where(mask) - - output = output[xx,yy,zz,aa] + xx, yy, zz, aa = np.where(mask) + + output = output[xx, yy, zz, aa] if ismask: - return output,[xx,yy,zz,aa] + return output, [xx, yy, zz, aa] else: return output - #output = output[output[:, 0] >= self.conf_th] - #bboxes = nms(output, self.nms_th) + # output = output[output[:, 0] >= self.conf_th] + # bboxes = nms(output, self.nms_th) + + def nms(output, nms_th): if len(output) == 0: return output output = output[np.argsort(-output[:, 0])] bboxes = [output[0]] - + for i in np.arange(1, len(output)): bbox = output[i] flag = 1 @@ -262,12 +269,12 @@ def nms(output, nms_th): break if flag == 1: bboxes.append(bbox) - + bboxes = np.asarray(bboxes, np.float32) return bboxes + def iou(box0, box1): - r0 = box0[3] / 2 s0 = box0[:3] - r0 e0 = box0[:3] + r0 @@ -284,8 +291,9 @@ def iou(box0, box1): union = box0[3] * box0[3] * box0[3] + box1[3] * box1[3] * box1[3] - intersection return intersection / union + def acc(pbb, lbb, conf_th, nms_th, detect_th): - pbb = pbb[pbb[:, 0] >= conf_th] + pbb = pbb[pbb[:, 0] >= conf_th] pbb = nms(pbb, nms_th) tp = [] @@ -297,63 +305,64 @@ def acc(pbb, lbb, conf_th, nms_th, detect_th): bestscore = 0 for i, l in enumerate(lbb): score = iou(p[1:5], l) - if score>bestscore: + if score > bestscore: bestscore = score besti = i if bestscore > detect_th: flag = 1 if l_flag[besti] == 0: l_flag[besti] = 1 - tp.append(np.concatenate([p,[bestscore]],0)) + tp.append(np.concatenate([p, [bestscore]], 0)) else: - fp.append(np.concatenate([p,[bestscore]],0)) + fp.append(np.concatenate([p, [bestscore]], 0)) if flag == 0: - fp.append(np.concatenate([p,[bestscore]],0)) - for i,l in enumerate(lbb): - if l_flag[i]==0: + fp.append(np.concatenate([p, [bestscore]], 0)) + for i, l in enumerate(lbb): + if l_flag[i] == 0: score = [] for p in pbb: - score.append(iou(p[1:5],l)) - if len(score)!=0: + score.append(iou(p[1:5], l)) + if len(score) != 0: bestscore = np.max(score) else: bestscore = 0 - if bestscore0: - fn = np.concatenate([fn,tp[fn_i,:5]]) + if len(fn_i) > 0: + fn = np.concatenate([fn, tp[fn_i, :5]]) else: fn = fn - if len(tp_in_topk)>0: + if len(tp_in_topk) > 0: tp = tp[tp_in_topk] else: tp = [] - if len(fp_in_topk)>0: + if len(fp_in_topk) > 0: fp = newallp[fp_in_topk] else: fp = [] - return tp, fp , fn + return tp, fp, fn diff --git a/training/classifier/main.py b/training/classifier/main.py index 5e58a32..50adf2e 100644 --- a/training/classifier/main.py +++ b/training/classifier/main.py @@ -1,25 +1,17 @@ import argparse -import os +import shutil import time -import numpy as np from importlib import import_module -import shutil -import sys -from split_combine import SplitComb import torch -from torch.nn import DataParallel from torch.backends import cudnn +from torch.nn import DataParallel from torch.utils.data import DataLoader -from torch import optim -from torch.autograd import Variable - -from layers import acc -from trainval_detector import * from trainval_classifier import * -from data_detector import DataBowl3Detector -from data_classifier import DataBowl3Classifier +from trainval_detector import * +from data_classifier import DataBowl3Classifier +from data_detector import DataBowl3Detector from utils import * parser = argparse.ArgumentParser(description='PyTorch DataBowl3 Detector') @@ -66,14 +58,13 @@ parser.add_argument('--freeze_batchnorm', default=0, type=int, metavar='TEST', help='freeze the batchnorm when training') + def main(): global args args = parser.parse_args() - - + torch.manual_seed(0) - - + ################################## nodmodel = import_module(args.model1) @@ -81,32 +72,30 @@ def main(): args.lr_stage = config1['lr_stage'] args.lr_preset = config1['lr'] - save_dir = args.save_dir - ################################## - + casemodel = import_module(args.model2) - + config2 = casemodel.config args.lr_stage2 = config2['lr_stage'] args.lr_preset2 = config2['lr'] topk = config2['topk'] - case_net = casemodel.CaseNet(topk = topk,nodulenet=nod_net) + case_net = casemodel.CaseNet(topk=topk, nodulenet=nod_net) args.miss_ratio = config2['miss_ratio'] args.miss_thresh = config2['miss_thresh'] if args.debug: args.save_dir = 'debug' - + ################################### - - - - - - + + + + + + ################################ start_epoch = args.start_epoch if args.resume: @@ -116,7 +105,7 @@ def main(): if not save_dir: save_dir = checkpoint['save_dir'] else: - save_dir = os.path.join('results',save_dir) + save_dir = os.path.join('results', save_dir) case_net.load_state_dict(checkpoint['state_dict']) else: if start_epoch == 0: @@ -125,7 +114,7 @@ def main(): exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime()) save_dir = os.path.join('results', args.model1 + '-' + exp_id) else: - save_dir = os.path.join('results',save_dir) + save_dir = os.path.join('results', save_dir) if args.epochs == None: end_epoch = args.lr_stage2[-1] else: @@ -133,15 +122,15 @@ def main(): ################################ if not os.path.exists(save_dir): os.makedirs(save_dir) - logfile = os.path.join(save_dir,'log') - if args.test1!=1 and args.test2!=1 : + logfile = os.path.join(save_dir, 'log') + if args.test1 != 1 and args.test2 != 1: sys.stdout = Logger(logfile) pyfiles = [f for f in os.listdir('./') if f.endswith('.py')] for f in pyfiles: - shutil.copy(f,os.path.join(save_dir,f)) + shutil.copy(f, os.path.join(save_dir, f)) ################################ torch.cuda.set_device(0) - #nod_net = nod_net.cuda() + # nod_net = nod_net.cuda() case_net = case_net.cuda() loss = loss.cuda() cudnn.benchmark = True @@ -153,32 +142,31 @@ def main(): if args.test1 == 1: testsplit = np.load('full.npy') - dataset = DataBowl3Classifier(testsplit, config2, phase = 'test') - predlist = test_casenet(case_net,dataset).T - anstable = np.concatenate([[testsplit],predlist],0).T + dataset = DataBowl3Classifier(testsplit, config2, phase='test') + predlist = test_casenet(case_net, dataset).T + anstable = np.concatenate([[testsplit], predlist], 0).T df = pandas.DataFrame(anstable) - df.columns={'id','cancer'} - df.to_csv('allstage1.csv',index=False) + df.columns = {'id', 'cancer'} + df.to_csv('allstage1.csv', index=False) return - if args.test2 ==1: - + if args.test2 == 1: testsplit = np.load('test.npy') - dataset = DataBowl3Classifier(testsplit, config2, phase = 'test') - predlist = test_casenet(case_net,dataset).T - anstable = np.concatenate([[testsplit],predlist],0).T + dataset = DataBowl3Classifier(testsplit, config2, phase='test') + predlist = test_casenet(case_net, dataset).T + anstable = np.concatenate([[testsplit], predlist], 0).T df = pandas.DataFrame(anstable) - df.columns={'id','cancer'} - df.to_csv('quick',index=False) + df.columns = {'id', 'cancer'} + df.to_csv('quick', index=False) return if args.test3 == 1: testsplit3 = np.load('stage2.npy') - dataset = DataBowl3Classifier(testsplit3,config2,phase = 'test') - predlist = test_casenet(case_net,dataset).T - anstable = np.concatenate([[testsplit3],predlist],0).T + dataset = DataBowl3Classifier(testsplit3, config2, phase='test') + predlist = test_casenet(case_net, dataset).T + anstable = np.concatenate([[testsplit3], predlist], 0).T df = pandas.DataFrame(anstable) - df.columns={'id','cancer'} - df.to_csv('stage2_ans.csv',index=False) + df.columns = {'id', 'cancer'} + df.to_csv('stage2_ans.csv', index=False) return print(save_dir) print(args.save_freq) @@ -186,51 +174,51 @@ def main(): valsplit = np.load('valsplit.npy') testsplit = np.load('test.npy') - dataset = DataBowl3Detector(trainsplit,config1,phase = 'train') - train_loader_nod = DataLoader(dataset,batch_size = args.batch_size, - shuffle = True,num_workers = args.workers,pin_memory=True) + dataset = DataBowl3Detector(trainsplit, config1, phase='train') + train_loader_nod = DataLoader(dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.workers, pin_memory=True) - dataset = DataBowl3Detector(valsplit,config1,phase = 'val') - val_loader_nod = DataLoader(dataset,batch_size = args.batch_size, - shuffle = False,num_workers = args.workers,pin_memory=True) + dataset = DataBowl3Detector(valsplit, config1, phase='val') + val_loader_nod = DataLoader(dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.workers, pin_memory=True) optimizer = torch.optim.SGD(nod_net.parameters(), - args.lr,momentum = 0.9,weight_decay = args.weight_decay) - + args.lr, momentum=0.9, weight_decay=args.weight_decay) + trainsplit = np.load('full.npy') - dataset = DataBowl3Classifier(trainsplit,config2,phase = 'train') - train_loader_case = DataLoader(dataset,batch_size = args.batch_size2, - shuffle = True,num_workers = args.workers,pin_memory=True) - - dataset = DataBowl3Classifier(valsplit,config2,phase = 'val') - val_loader_case = DataLoader(dataset,batch_size = max([args.batch_size2,1]), - shuffle = False,num_workers = args.workers,pin_memory=True) - - dataset = DataBowl3Classifier(trainsplit,config2,phase = 'val') - all_loader_case = DataLoader(dataset,batch_size = max([args.batch_size2,1]), - shuffle = False,num_workers = args.workers,pin_memory=True) + dataset = DataBowl3Classifier(trainsplit, config2, phase='train') + train_loader_case = DataLoader(dataset, batch_size=args.batch_size2, + shuffle=True, num_workers=args.workers, pin_memory=True) + + dataset = DataBowl3Classifier(valsplit, config2, phase='val') + val_loader_case = DataLoader(dataset, batch_size=max([args.batch_size2, 1]), + shuffle=False, num_workers=args.workers, pin_memory=True) + + dataset = DataBowl3Classifier(trainsplit, config2, phase='val') + all_loader_case = DataLoader(dataset, batch_size=max([args.batch_size2, 1]), + shuffle=False, num_workers=args.workers, pin_memory=True) optimizer2 = torch.optim.SGD(case_net.parameters(), - args.lr,momentum = 0.9,weight_decay = args.weight_decay) - + args.lr, momentum=0.9, weight_decay=args.weight_decay) + for epoch in range(start_epoch, end_epoch + 1): - if epoch ==start_epoch: + if epoch == start_epoch: lr = args.lr debug = args.debug args.lr = 0.0 args.debug = True - train_casenet(epoch,case_net,train_loader_case,optimizer2,args) + train_casenet(epoch, case_net, train_loader_case, optimizer2, args) args.lr = lr args.debug = debug - if epochconfig2['startepoch']: - train_casenet(epoch,case_net,train_loader_case,optimizer2,args) - val_casenet(epoch,case_net,val_loader_case,args) - val_casenet(epoch,case_net,all_loader_case,args) + if epoch > config2['startepoch']: + train_casenet(epoch, case_net, train_loader_case, optimizer2, args) + val_casenet(epoch, case_net, val_loader_case, args) + val_casenet(epoch, case_net, all_loader_case, args) - if epoch % args.save_freq == 0: + if epoch % args.save_freq == 0: state_dict = case_net.module.state_dict() for key in state_dict.keys(): state_dict[key] = state_dict[key].cpu() @@ -241,6 +229,7 @@ def main(): 'state_dict': state_dict, 'args': args}, os.path.join(save_dir, '%03d.ckpt' % epoch)) + + if __name__ == '__main__': main() - diff --git a/training/classifier/net_classifier_3.py b/training/classifier/net_classifier_3.py index ae30a4e..e075066 100644 --- a/training/classifier/net_classifier_3.py +++ b/training/classifier/net_classifier_3.py @@ -1,16 +1,9 @@ +import sys + +import numpy as np import torch from torch import nn -from layers import * -from torch.nn import DataParallel -from torch.backends import cudnn -from torch.utils.data import DataLoader -from torch import optim -from torch.autograd import Variable -from torch.utils.data import Dataset -from scipy.ndimage.interpolation import rotate -import numpy as np -import os -import sys + sys.path.append('../') from config_training import config as config_training @@ -25,9 +18,9 @@ config['padmask'] = False -config['crop_size'] = [96,96,96] -config['scaleLim'] = [0.85,1.15] -config['radiusLim'] = [6,100] +config['crop_size'] = [96, 96, 96] +config['scaleLim'] = [0.85, 1.15] +config['radiusLim'] = [6, 100] config['jitter_range'] = 0.15 config['isScale'] = True @@ -35,7 +28,7 @@ config['T'] = 1 config['topk'] = 5 config['stride'] = 4 -config['augtype'] = {'flip':True,'swap':False,'rotate':False,'scale':False} +config['augtype'] = {'flip': True, 'swap': False, 'rotate': False, 'scale': False} config['detect_th'] = 0.05 config['conf_th'] = -1 @@ -43,41 +36,43 @@ config['filling_value'] = 160 config['startepoch'] = 20 -config['lr_stage'] = np.array([50,100,140,160]) -config['lr'] = [0.01,0.001,0.0001,0.00001] +config['lr_stage'] = np.array([50, 100, 140, 160]) +config['lr'] = [0.01, 0.001, 0.0001, 0.00001] config['miss_ratio'] = 1 config['miss_thresh'] = 0.03 + class CaseNet(nn.Module): - def __init__(self,topk,nodulenet): - super(CaseNet,self).__init__() - self.NoduleNet = nodulenet - self.fc1 = nn.Linear(128,64) - self.fc2 = nn.Linear(64,1) + def __init__(self, topk, nodulenet): + super(CaseNet, self).__init__() + self.NoduleNet = nodulenet + self.fc1 = nn.Linear(128, 64) + self.fc2 = nn.Linear(64, 1) self.pool = nn.MaxPool3d(kernel_size=2) self.dropout = nn.Dropout(0.5) self.baseline = nn.Parameter(torch.Tensor([-30.0]).float()) self.Relu = nn.ReLU() - def forward(self,xlist,coordlist): -# xlist: n x k x 1x 96 x 96 x 96 -# coordlist: n x k x 3 x 24 x 24 x 24 + + def forward(self, xlist, coordlist): + # xlist: n x k x 1x 96 x 96 x 96 + # coordlist: n x k x 3 x 24 x 24 x 24 xsize = xlist.size() corrdsize = coordlist.size() - xlist = xlist.view(-1,xsize[2],xsize[3],xsize[4],xsize[5]) - coordlist = coordlist.view(-1,corrdsize[2],corrdsize[3],corrdsize[4],corrdsize[5]) - - noduleFeat,nodulePred = self.NoduleNet(xlist,coordlist) - nodulePred = nodulePred.contiguous().view(corrdsize[0],corrdsize[1],-1) - - featshape = noduleFeat.size()#nk x 128 x 24 x 24 x24 - centerFeat = self.pool(noduleFeat[:,:,featshape[2]/2-1:featshape[2]/2+1, - featshape[3]/2-1:featshape[3]/2+1, - featshape[4]/2-1:featshape[4]/2+1]) - centerFeat = centerFeat[:,:,0,0,0] + xlist = xlist.view(-1, xsize[2], xsize[3], xsize[4], xsize[5]) + coordlist = coordlist.view(-1, corrdsize[2], corrdsize[3], corrdsize[4], corrdsize[5]) + + noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist) + nodulePred = nodulePred.contiguous().view(corrdsize[0], corrdsize[1], -1) + + featshape = noduleFeat.size() # nk x 128 x 24 x 24 x24 + centerFeat = self.pool(noduleFeat[:, :, featshape[2] / 2 - 1:featshape[2] / 2 + 1, + featshape[3] / 2 - 1:featshape[3] / 2 + 1, + featshape[4] / 2 - 1:featshape[4] / 2 + 1]) + centerFeat = centerFeat[:, :, 0, 0, 0] out = self.dropout(centerFeat) out = self.Relu(self.fc1(out)) out = torch.sigmoid(self.fc2(out)) - out = out.view(xsize[0],xsize[1]) + out = out.view(xsize[0], xsize[1]) base_prob = torch.sigmoid(self.baseline) - casePred = 1-torch.prod(1-out,dim=1)*(1-base_prob.expand(out.size()[0])) - return nodulePred,casePred,out + casePred = 1 - torch.prod(1 - out, dim=1) * (1 - base_prob.expand(out.size()[0])) + return nodulePred, casePred, out diff --git a/training/classifier/net_classifier_4.py b/training/classifier/net_classifier_4.py index f8f2393..fac87d9 100644 --- a/training/classifier/net_classifier_4.py +++ b/training/classifier/net_classifier_4.py @@ -1,16 +1,9 @@ +import sys + +import numpy as np import torch from torch import nn -from layers import * -from torch.nn import DataParallel -from torch.backends import cudnn -from torch.utils.data import DataLoader -from torch import optim -from torch.autograd import Variable -from torch.utils.data import Dataset -from scipy.ndimage.interpolation import rotate -import numpy as np -import os -import sys + sys.path.append('../') from config_training import config as config_training @@ -25,9 +18,9 @@ config['padmask'] = False -config['crop_size'] = [96,96,96] -config['scaleLim'] = [0.85,1.15] -config['radiusLim'] = [6,100] +config['crop_size'] = [96, 96, 96] +config['scaleLim'] = [0.85, 1.15] +config['radiusLim'] = [6, 100] config['jitter_range'] = 0.15 config['isScale'] = True @@ -35,7 +28,7 @@ config['T'] = 1 config['topk'] = 5 config['stride'] = 4 -config['augtype'] = {'flip':True,'swap':True,'rotate':True,'scale':True} +config['augtype'] = {'flip': True, 'swap': True, 'rotate': True, 'scale': True} config['detect_th'] = 0.05 config['conf_th'] = -1 @@ -43,41 +36,43 @@ config['filling_value'] = 160 config['startepoch'] = 20 -config['lr_stage'] = np.array([50,100,140,160,180]) -config['lr'] = [0.01,0.001,0.0001,0.00001,0.000001] +config['lr_stage'] = np.array([50, 100, 140, 160, 180]) +config['lr'] = [0.01, 0.001, 0.0001, 0.00001, 0.000001] config['miss_ratio'] = 1 config['miss_thresh'] = 0.03 + class CaseNet(nn.Module): - def __init__(self,topk,nodulenet): - super(CaseNet,self).__init__() - self.NoduleNet = nodulenet - self.fc1 = nn.Linear(128,64) - self.fc2 = nn.Linear(64,1) + def __init__(self, topk, nodulenet): + super(CaseNet, self).__init__() + self.NoduleNet = nodulenet + self.fc1 = nn.Linear(128, 64) + self.fc2 = nn.Linear(64, 1) self.pool = nn.MaxPool3d(kernel_size=2) self.dropout = nn.Dropout(0.5) self.baseline = nn.Parameter(torch.Tensor([-30.0]).float()) self.Relu = nn.ReLU() - def forward(self,xlist,coordlist): -# xlist: n x k x 1x 96 x 96 x 96 -# coordlist: n x k x 3 x 24 x 24 x 24 + + def forward(self, xlist, coordlist): + # xlist: n x k x 1x 96 x 96 x 96 + # coordlist: n x k x 3 x 24 x 24 x 24 xsize = xlist.size() corrdsize = coordlist.size() - xlist = xlist.view(-1,xsize[2],xsize[3],xsize[4],xsize[5]) - coordlist = coordlist.view(-1,corrdsize[2],corrdsize[3],corrdsize[4],corrdsize[5]) - - noduleFeat,nodulePred = self.NoduleNet(xlist,coordlist) - nodulePred = nodulePred.contiguous().view(corrdsize[0],corrdsize[1],-1) - - featshape = noduleFeat.size()#nk x 128 x 24 x 24 x24 - centerFeat = self.pool(noduleFeat[:,:,featshape[2]/2-1:featshape[2]/2+1, - featshape[3]/2-1:featshape[3]/2+1, - featshape[4]/2-1:featshape[4]/2+1]) - centerFeat = centerFeat[:,:,0,0,0] + xlist = xlist.view(-1, xsize[2], xsize[3], xsize[4], xsize[5]) + coordlist = coordlist.view(-1, corrdsize[2], corrdsize[3], corrdsize[4], corrdsize[5]) + + noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist) + nodulePred = nodulePred.contiguous().view(corrdsize[0], corrdsize[1], -1) + + featshape = noduleFeat.size() # nk x 128 x 24 x 24 x24 + centerFeat = self.pool(noduleFeat[:, :, featshape[2] / 2 - 1:featshape[2] / 2 + 1, + featshape[3] / 2 - 1:featshape[3] / 2 + 1, + featshape[4] / 2 - 1:featshape[4] / 2 + 1]) + centerFeat = centerFeat[:, :, 0, 0, 0] out = self.dropout(centerFeat) out = self.Relu(self.fc1(out)) out = torch.sigmoid(self.fc2(out)) - out = out.view(xsize[0],xsize[1]) + out = out.view(xsize[0], xsize[1]) base_prob = torch.sigmoid(self.baseline) - casePred = 1-torch.prod(1-out,dim=1)*(1-base_prob.expand(out.size()[0])) - return nodulePred,casePred,out + casePred = 1 - torch.prod(1 - out, dim=1) * (1 - base_prob.expand(out.size()[0])) + return nodulePred, casePred, out diff --git a/training/classifier/net_detector_3.py b/training/classifier/net_detector_3.py index 2277ded..a215551 100644 --- a/training/classifier/net_detector_3.py +++ b/training/classifier/net_detector_3.py @@ -1,13 +1,15 @@ +import sys + import torch from torch import nn + from layers import * -import sys + sys.path.append('../') from config_training import config as config_training - config = {} -config['anchors'] = [ 10.0, 30.0, 60.] +config['anchors'] = [10.0, 30.0, 60.] config['chanel'] = 1 config['crop_size'] = [128, 128, 128] config['stride'] = 4 @@ -21,20 +23,21 @@ config['num_hard'] = 2 config['bound_size'] = 12 config['reso'] = 1 -config['sizelim'] = 6. #mm +config['sizelim'] = 6. # mm config['sizelim2'] = 30 config['sizelim3'] = 40 config['aug_scale'] = True config['r_rand_crop'] = 0.3 config['pad_value'] = 170 -config['augtype'] = {'flip':True,'swap':False,'scale':True,'rotate':False} -config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','990fbe3f0a1b53878669967b9afd1441','adc3bbc63d40f8761c59be10f1e504c3'] +config['augtype'] = {'flip': True, 'swap': False, 'scale': True, 'rotate': False} +config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38', '990fbe3f0a1b53878669967b9afd1441', + 'adc3bbc63d40f8761c59be10f1e504c3'] +config['lr_stage'] = np.array([50, 100, 140, 160]) +config['lr'] = [0.01, 0.001, 0.0001, 0.00001] -config['lr_stage'] = np.array([50,100,140,160]) -config['lr'] = [0.01,0.001,0.0001,0.00001] -#config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', +# config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', # '417','077','188','876','057','087','130','468'] class Net(nn.Module): @@ -43,92 +46,93 @@ def __init__(self): # The first few layers consumes the most memory, so use simple convolution to save memory. # Call these layers preBlock, i.e., before the residual blocks of later layers. self.preBlock = nn.Sequential( - nn.Conv3d(1, 24, kernel_size = 3, padding = 1), + nn.Conv3d(1, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True), - nn.Conv3d(24, 24, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(24, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True)) - + nn.ReLU(inplace=True)) + # 3 poolings, each pooling downsamples the feature map by a factor 2. # 3 groups of blocks. The first block of each group has one pooling. - num_blocks_forw = [2,2,3,3] - num_blocks_back = [3,3] - self.featureNum_forw = [24,32,64,64,64] - self.featureNum_back = [128,64,64] + num_blocks_forw = [2, 2, 3, 3] + num_blocks_back = [3, 3] + self.featureNum_forw = [24, 32, 64, 64, 64] + self.featureNum_back = [128, 64, 64] for i in range(len(num_blocks_forw)): blocks = [] for j in range(num_blocks_forw[i]): if j == 0: - blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1])) else: - blocks.append(PostRes(self.featureNum_forw[i+1], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1])) setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks)) - for i in range(len(num_blocks_back)): blocks = [] for j in range(num_blocks_back[i]): if j == 0: - if i==0: + if i == 0: addition = 3 else: addition = 0 - blocks.append(PostRes(self.featureNum_back[i+1]+self.featureNum_forw[i+2]+addition, self.featureNum_back[i])) + blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition, + self.featureNum_back[i])) else: blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i])) setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks)) - self.maxpool1 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool2 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool3 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool4 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2,stride=2) - self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2,stride=2) + self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2) + self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2) self.path1 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path2 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) - self.drop = nn.Dropout3d(p = 0.2, inplace = False) - self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size = 1), + nn.ReLU(inplace=True)) + self.drop = nn.Dropout3d(p=0.2, inplace=False) + self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1), nn.ReLU(), - #nn.Dropout3d(p = 0.3), - nn.Conv3d(64, 5 * len(config['anchors']), kernel_size = 1)) + # nn.Dropout3d(p = 0.3), + nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1)) def forward(self, x, coord): - #x = (x-128.)/128. - out = self.preBlock(x)#16 - out_pool,indices0 = self.maxpool1(out) - out1 = self.forw1(out_pool)#32 - out1_pool,indices1 = self.maxpool2(out1) - out2 = self.forw2(out1_pool)#64 - #out2 = self.drop(out2) - out2_pool,indices2 = self.maxpool3(out2) - out3 = self.forw3(out2_pool)#96 - out3_pool,indices3 = self.maxpool4(out3) - out4 = self.forw4(out3_pool)#96 - #out4 = self.drop(out4) - + # x = (x-128.)/128. + out = self.preBlock(x) # 16 + out_pool, indices0 = self.maxpool1(out) + out1 = self.forw1(out_pool) # 32 + out1_pool, indices1 = self.maxpool2(out1) + out2 = self.forw2(out1_pool) # 64 + # out2 = self.drop(out2) + out2_pool, indices2 = self.maxpool3(out2) + out3 = self.forw3(out2_pool) # 96 + out3_pool, indices3 = self.maxpool4(out3) + out4 = self.forw4(out3_pool) # 96 + # out4 = self.drop(out4) + rev3 = self.path1(out4) - comb3 = self.back3(torch.cat((rev3, out3), 1))#96+96 - #comb3 = self.drop(comb3) + comb3 = self.back3(torch.cat((rev3, out3), 1)) # 96+96 + # comb3 = self.drop(comb3) rev2 = self.path2(comb3) - - feat = self.back2(torch.cat((rev2, out2,coord), 1))#64+64 + + feat = self.back2(torch.cat((rev2, out2, coord), 1)) # 64+64 comb2 = self.drop(feat) out = self.output(comb2) size = out.size() out = out.view(out.size(0), out.size(1), -1) - #out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() + # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5) - #out = out.view(-1, 5) - return feat,out - + # out = out.view(-1, 5) + return feat, out + + def get_model(): net = Net() loss = Loss(config['num_hard']) diff --git a/training/classifier/split_combine.py b/training/classifier/split_combine.py index 2e399a9..9db34e6 100644 --- a/training/classifier/split_combine.py +++ b/training/classifier/split_combine.py @@ -1,23 +1,24 @@ -import torch import numpy as np + + class SplitComb(): - def __init__(self,side_len,max_stride,stride,margin): + def __init__(self, side_len, max_stride, stride, margin): self.side_len = side_len self.max_stride = max_stride self.stride = stride self.margin = margin - - def split(self, data, side_len = None, max_stride = None, margin = None): - if side_len==None: + + def split(self, data, side_len=None, max_stride=None, margin=None): + if side_len == None: side_len = self.side_len if max_stride == None: max_stride = self.max_stride if margin == None: margin = self.margin - - assert(side_len > margin) - assert(side_len % max_stride == 0) - assert(margin % max_stride == 0) + + assert (side_len > margin) + assert (side_len % max_stride == 0) + assert (margin % max_stride == 0) splits = [] _, z, h, w = data.shape @@ -25,14 +26,14 @@ def split(self, data, side_len = None, max_stride = None, margin = None): nz = int(np.ceil(float(z) / side_len)) nh = int(np.ceil(float(h) / side_len)) nw = int(np.ceil(float(w) / side_len)) - - nzhw = [nz,nh,nw] + + nzhw = [nz, nh, nw] self.nzhw = nzhw - - pad = [ [0, 0], - [margin, nz * side_len - z + margin], - [margin, nh * side_len - h + margin], - [margin, nw * side_len - w + margin]] + + pad = [[0, 0], + [margin, nz * side_len - z + margin], + [margin, nh * side_len - h + margin], + [margin, nw * side_len - w + margin]] data = np.pad(data, pad, 'constant') for iz in range(nz): @@ -49,24 +50,24 @@ def split(self, data, side_len = None, max_stride = None, margin = None): splits.append(split) splits = np.concatenate(splits, 0) - return splits,nzhw + return splits, nzhw + + def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None): - def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): - - if side_len==None: + if side_len == None: side_len = self.side_len if stride == None: stride = self.stride if margin == None: margin = self.margin - if nzhw==None: + if nzhw == None: nz = self.nz nh = self.nh nw = self.nw else: - nz,nh,nw = nzhw - assert(side_len % stride == 0) - assert(margin % stride == 0) + nz, nh, nw = nzhw + assert (side_len % stride == 0) + assert (margin % stride == 0) side_len /= stride margin /= stride @@ -96,4 +97,4 @@ def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): output[sz:ez, sh:eh, sw:ew] = split idx += 1 - return output + return output diff --git a/training/classifier/trainval_classifier.py b/training/classifier/trainval_classifier.py index 88951eb..31da4f3 100644 --- a/training/classifier/trainval_classifier.py +++ b/training/classifier/trainval_classifier.py @@ -1,34 +1,32 @@ -import numpy as np -import os import time -import random -import warnings +import numpy as np import torch from torch import nn -from torch import optim from torch.autograd import Variable -from torch.nn.functional import cross_entropy,sigmoid,binary_cross_entropy +from torch.nn.functional import binary_cross_entropy from torch.utils.data import DataLoader -def get_lr(epoch,args): - assert epoch<=args.lr_stage2[-1] - if args.lr==None: - lrstage = np.sum(epoch>args.lr_stage2) + +def get_lr(epoch, args): + assert epoch <= args.lr_stage2[-1] + if args.lr == None: + lrstage = np.sum(epoch > args.lr_stage2) lr = args.lr_preset2[lrstage] else: lr = args.lr return lr -def train_casenet(epoch,model,data_loader,optimizer,args): + +def train_casenet(epoch, model, data_loader, optimizer, args): model.train() - if args.freeze_batchnorm: + if args.freeze_batchnorm: for m in model.modules(): if isinstance(m, nn.BatchNorm3d): m.eval() starttime = time.time() - lr = get_lr(epoch,args) + lr = get_lr(epoch, args) for param_group in optimizer.param_groups: param_group['lr'] = lr @@ -41,26 +39,26 @@ def train_casenet(epoch,model,data_loader,optimizer,args): tpn = 0 fpn = 0 fnn = 0 -# weight = torch.from_numpy(np.ones_like(y).float().cuda() - for i,(x,coord,isnod,y) in enumerate(data_loader): + # weight = torch.from_numpy(np.ones_like(y).float().cuda() + for i, (x, coord, isnod, y) in enumerate(data_loader): if args.debug: - if i >4: + if i > 4: break coord = Variable(coord).cuda() x = Variable(x).cuda() xsize = x.size() isnod = Variable(isnod).float().cuda() - ydata = y.numpy()[:,0] + ydata = y.numpy()[:, 0] y = Variable(y).float().cuda() -# weight = 3*torch.ones(y.size()).float().cuda() + # weight = 3*torch.ones(y.size()).float().cuda() optimizer.zero_grad() - nodulePred,casePred,casePred_each = model(x,coord) - loss2 = binary_cross_entropy(casePred,y[:,0]) - missMask = (casePred_each0.5 - tpn += np.sum(1==pred[ydata==1]) - fpn += np.sum(1==pred[ydata==0]) - fnn += np.sum(0==pred[ydata==1]) - acc = np.mean(ydata==pred) + pred = outdata > 0.5 + tpn += np.sum(1 == pred[ydata == 1]) + fpn += np.sum(1 == pred[ydata == 0]) + fnn += np.sum(0 == pred[ydata == 1]) + acc = np.mean(ydata == pred) accHist.append(acc) - + endtime = time.time() lenHist = np.array(lenHist) loss2Hist = np.array(loss2Hist) lossHist = np.array(lossHist) accHist = np.array(accHist) - - mean_loss2 = np.sum(loss2Hist*lenHist)/np.sum(lenHist) - mean_missloss = np.sum(missHist*lenHist)/np.sum(lenHist) - mean_acc = np.sum(accHist*lenHist)/np.sum(lenHist) + + mean_loss2 = np.sum(loss2Hist * lenHist) / np.sum(lenHist) + mean_missloss = np.sum(missHist * lenHist) / np.sum(lenHist) + mean_acc = np.sum(accHist * lenHist) / np.sum(lenHist) print('Train, epoch %d, loss2 %.4f, miss loss %.4f, acc %.4f, tpn %d, fpn %d, fnn %d, time %3.2f, lr % .5f ' - %(epoch,mean_loss2,mean_missloss,mean_acc,tpn,fpn, fnn, endtime-starttime,lr)) + % (epoch, mean_loss2, mean_missloss, mean_acc, tpn, fpn, fnn, endtime - starttime, lr)) -def val_casenet(epoch,model,data_loader,args): + +def val_casenet(epoch, model, data_loader, args): model.eval() starttime = time.time() loss1Hist = [] @@ -100,62 +99,60 @@ def val_casenet(epoch,model,data_loader,args): fpn = 0 fnn = 0 - for i,(x,coord,isnod,y) in enumerate(data_loader): - - coord = Variable(coord,volatile=True).cuda() - x = Variable(x,volatile=True).cuda() + for i, (x, coord, isnod, y) in enumerate(data_loader): + coord = Variable(coord, volatile=True).cuda() + x = Variable(x, volatile=True).cuda() xsize = x.size() - ydata = y.numpy()[:,0] + ydata = y.numpy()[:, 0] y = Variable(y).float().cuda() isnod = Variable(isnod).float().cuda() - nodulePred,casePred,casePred_each = model(x,coord) - - loss2 = binary_cross_entropy(casePred,y[:,0]) - missMask = (casePred_each0.5 - tpn += np.sum(1==pred[ydata==1]) - fpn += np.sum(1==pred[ydata==0]) - fnn += np.sum(0==pred[ydata==1]) - acc = np.mean(ydata==pred) + # print([i,data_loader.dataset.split[i,1],sigmoid(casePred).data.cpu().numpy()]) + pred = outdata > 0.5 + tpn += np.sum(1 == pred[ydata == 1]) + fpn += np.sum(1 == pred[ydata == 0]) + fnn += np.sum(0 == pred[ydata == 1]) + acc = np.mean(ydata == pred) accHist.append(acc) endtime = time.time() lenHist = np.array(lenHist) loss2Hist = np.array(loss2Hist) accHist = np.array(accHist) - mean_loss2 = np.sum(loss2Hist*lenHist)/np.sum(lenHist) - mean_missloss = np.sum(missHist*lenHist)/np.sum(lenHist) - mean_acc = np.sum(accHist*lenHist)/np.sum(lenHist) + mean_loss2 = np.sum(loss2Hist * lenHist) / np.sum(lenHist) + mean_missloss = np.sum(missHist * lenHist) / np.sum(lenHist) + mean_acc = np.sum(accHist * lenHist) / np.sum(lenHist) print('Valid, epoch %d, loss2 %.4f, miss loss %.4f, acc %.4f, tpn %d, fpn %d, fnn %d, time %3.2f' - %(epoch,mean_loss2,mean_missloss,mean_acc,tpn,fpn, fnn, endtime-starttime)) + % (epoch, mean_loss2, mean_missloss, mean_acc, tpn, fpn, fnn, endtime - starttime)) - -def test_casenet(model,testset): + +def test_casenet(model, testset): data_loader = DataLoader( testset, - batch_size = 4, - shuffle = False, - num_workers = 32, + batch_size=4, + shuffle=False, + num_workers=32, pin_memory=True) - #model = model.cuda() + # model = model.cuda() model.eval() predlist = [] - - # weight = torch.from_numpy(np.ones_like(y).float().cuda() - for i,(x,coord) in enumerate(data_loader): + # weight = torch.from_numpy(np.ones_like(y).float().cuda() + for i, (x, coord) in enumerate(data_loader): coord = Variable(coord).cuda() x = Variable(x).cuda() - nodulePred,casePred,_ = model(x,coord) + nodulePred, casePred, _ = model(x, coord) predlist.append(casePred.data.cpu().numpy()) - #print([i,data_loader.dataset.split[i,1],casePred.data.cpu().numpy()]) + # print([i,data_loader.dataset.split[i,1],casePred.data.cpu().numpy()]) predlist = np.concatenate(predlist) - return predlist + return predlist diff --git a/training/classifier/trainval_detector.py b/training/classifier/trainval_detector.py index 6e02b13..657ea6c 100644 --- a/training/classifier/trainval_detector.py +++ b/training/classifier/trainval_detector.py @@ -1,26 +1,21 @@ import os import time -import numpy as np -import torch +import numpy as np from torch import nn -from torch.nn import DataParallel -from torch.backends import cudnn -from torch.utils.data import DataLoader -from torch import optim from torch.autograd import Variable -from layers import acc -def get_lr(epoch,args): - assert epoch<=args.lr_stage[-1] - if args.lr==None: - lrstage = np.sum(epoch>args.lr_stage) +def get_lr(epoch, args): + assert epoch <= args.lr_stage[-1] + if args.lr == None: + lrstage = np.sum(epoch > args.lr_stage) lr = args.lr_preset[lrstage] else: lr = args.lr return lr + def train_nodulenet(data_loader, net, loss, epoch, optimizer, args): start_time = time.time() net.train() @@ -29,24 +24,24 @@ def train_nodulenet(data_loader, net, loss, epoch, optimizer, args): if isinstance(m, nn.BatchNorm3d): m.eval() - lr = get_lr(epoch,args) + lr = get_lr(epoch, args) for param_group in optimizer.param_groups: param_group['lr'] = lr metrics = [] for i, (data, target, coord) in enumerate(data_loader): if args.debug: - if i >4: + if i > 4: break - data = Variable(data.cuda(async = True)) - target = Variable(target.cuda(async = True)) - coord = Variable(coord.cuda(async = True)) + data = Variable(data.cuda(async=True)) + target = Variable(target.cuda(async=True)) + coord = Variable(coord.cuda(async=True)) - _,output = net(data, coord) + _, output = net(data, coord) loss_output = loss(output, target) optimizer.zero_grad() loss_output[0].backward() - #torch.nn.utils.clip_grad_norm(net.parameters(), 1) + # torch.nn.utils.clip_grad_norm(net.parameters(), 1) optimizer.step() loss_output[0] = loss_output[0].data[0] @@ -71,22 +66,23 @@ def train_nodulenet(data_loader, net, loss, epoch, optimizer, args): np.mean(metrics[:, 5]))) print + def validate_nodulenet(data_loader, net, loss): start_time = time.time() - + net.eval() metrics = [] for i, (data, target, coord) in enumerate(data_loader): - data = Variable(data.cuda(async = True), volatile = True) - target = Variable(target.cuda(async = True), volatile = True) - coord = Variable(coord.cuda(async = True), volatile = True) + data = Variable(data.cuda(async=True), volatile=True) + target = Variable(target.cuda(async=True), volatile=True) + coord = Variable(coord.cuda(async=True), volatile=True) - _,output = net(data, coord) - loss_output = loss(output, target, train = False) + _, output = net(data, coord) + loss_output = loss(output, target, train=False) loss_output[0] = loss_output[0].data[0] - metrics.append(loss_output) + metrics.append(loss_output) end_time = time.time() metrics = np.asarray(metrics, np.float32) @@ -106,9 +102,10 @@ def validate_nodulenet(data_loader, net, loss): print print + def test_nodulenet(data_loader, net, get_pbb, save_dir, config, n_per_run): start_time = time.time() - save_dir = os.path.join(save_dir,'bbox') + save_dir = os.path.join(save_dir, 'bbox') if not os.path.exists(save_dir): os.makedirs(save_dir) net.eval() @@ -127,31 +124,30 @@ def test_nodulenet(data_loader, net, get_pbb, save_dir, config, n_per_run): if config['output_feature']: isfeat = True print(data.size()) - splitlist = range(0,len(data)+1,n_per_run) - if splitlist[-1]!=len(data): + splitlist = range(0, len(data) + 1, n_per_run) + if splitlist[-1] != len(data): splitlist.append(len(data)) outputlist = [] featurelist = [] - for i in range(len(splitlist)-1): - input = Variable(data[splitlist[i]:splitlist[i+1]], volatile = True).cuda() - inputcoord = Variable(coord[splitlist[i]:splitlist[i+1]], volatile = True).cuda() - _,output = net(input,inputcoord) + for i in range(len(splitlist) - 1): + input = Variable(data[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() + inputcoord = Variable(coord[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() + _, output = net(input, inputcoord) outputlist.append(output.data.cpu().numpy()) - output = np.concatenate(outputlist,0) - output = split_comber.combine(output,nzhw=nzhw) + output = np.concatenate(outputlist, 0) + output = split_comber.combine(output, nzhw=nzhw) thresh = -3 - pbb,mask = get_pbb(output,thresh,ismask=True) - #tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) - #print([len(tp),len(fp),len(fn)]) - print([i_name,name]) + pbb, mask = get_pbb(output, thresh, ismask=True) + # tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) + # print([len(tp),len(fp),len(fn)]) + print([i_name, name]) e = time.time() - np.save(os.path.join(save_dir, name+'_pbb.npy'), pbb) - np.save(os.path.join(save_dir, name+'_lbb.npy'), lbb) + np.save(os.path.join(save_dir, name + '_pbb.npy'), pbb) + np.save(os.path.join(save_dir, name + '_lbb.npy'), lbb) np.save(os.path.join(save_dir, 'namelist.npy'), namelist) end_time = time.time() - print('elapsed time is %3.2f seconds' % (end_time - start_time)) print print diff --git a/training/classifier/utils.py b/training/classifier/utils.py index 8743c82..d0140c8 100644 --- a/training/classifier/utils.py +++ b/training/classifier/utils.py @@ -1,80 +1,87 @@ -import sys import os +import sys + import numpy as np import torch + + def getFreeId(): - import pynvml + import pynvml pynvml.nvmlInit() + def getFreeRatio(id): handle = pynvml.nvmlDeviceGetHandleByIndex(id) use = pynvml.nvmlDeviceGetUtilizationRates(handle) - ratio = 0.5*(float(use.gpu+float(use.memory))) + ratio = 0.5 * (float(use.gpu + float(use.memory))) return ratio deviceCount = pynvml.nvmlDeviceGetCount() available = [] for i in range(deviceCount): - if getFreeRatio(i)<70: + if getFreeRatio(i) < 70: available.append(i) gpus = '' for g in available: - gpus = gpus+str(g)+',' + gpus = gpus + str(g) + ',' gpus = gpus[:-1] return gpus + def setgpu(gpuinput): freeids = getFreeId() - if gpuinput=='all': + if gpuinput == 'all': gpus = freeids else: gpus = gpuinput if any([g not in freeids for g in gpus.split(',')]): - raise ValueError('gpu'+g+'is being used') - print('using gpu '+gpus) - os.environ['CUDA_VISIBLE_DEVICES']=gpus + raise ValueError('gpu' + g + 'is being used') + print('using gpu ' + gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = gpus return len(gpus.split(',')) + class Logger(object): - def __init__(self,logfile): + def __init__(self, logfile): self.terminal = sys.stdout self.log = open(logfile, "a") def write(self, message): self.terminal.write(message) - self.log.write(message) + self.log.write(message) def flush(self): - #this flush method is needed for python 3 compatibility. - #this handles the flush command by doing nothing. - #you might want to specify some extra behavior here. - pass + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass -def split4(data, max_stride, margin): +def split4(data, max_stride, margin): splits = [] data = torch.Tensor.numpy(data) - _,c, z, h, w = data.shape - - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - pad = int(np.ceil(float(z)/max_stride)*max_stride)-z - leftpad = pad/2 - pad = [[0,0],[0,0],[leftpad,pad-leftpad],[0,0],[0,0]] - data = np.pad(data,pad,'constant',constant_values=-1) + _, c, z, h, w = data.shape + + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + pad = int(np.ceil(float(z) / max_stride) * max_stride) - z + leftpad = pad / 2 + pad = [[0, 0], [0, 0], [leftpad, pad - leftpad], [0, 0], [0, 0]] + data = np.pad(data, pad, 'constant', constant_values=-1) data = torch.from_numpy(data) splits.append(data[:, :, :, :h_width, :w_width]) splits.append(data[:, :, :, :h_width, -w_width:]) splits.append(data[:, :, :, -h_width:, :w_width]) splits.append(data[:, :, :, -h_width:, -w_width:]) - + return torch.cat(splits, 0) + def combine4(output, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( splits[0].shape[0], h, @@ -101,37 +108,36 @@ def combine4(output, h, w): return output -def split8(data, max_stride, margin): + +def split8(data, max_stride, margin): splits = [] if isinstance(data, np.ndarray): c, z, h, w = data.shape else: - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: if isinstance(data, np.ndarray): splits.append(data[np.newaxis, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) else: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - if isinstance(data, np.ndarray): return np.concatenate(splits, 0) else: return torch.cat(splits, 0) - def combine8(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -139,41 +145,42 @@ def combine8(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 2 h_width = h / 2 w_width = w / 2 i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz in [[0, z_width], [z_width - z, None]]: + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split16(data, max_stride, margin): +def split16(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine16(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -181,45 +188,47 @@ def combine16(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 4 h_width = h / 2 w_width = w / 2 - splitzstart = splits[0].shape[0]/2-z_width/2 - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split32(data, max_stride, margin): + +def split32(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine32(splits, z, h, w): - output = np.zeros(( z, h, @@ -227,56 +236,58 @@ def combine32(splits, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 2)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - - return output + for zz in [[0, z_width], [z_width - z, None]]: + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 -def split64(data, max_stride, margin): + return output + + +def split64(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine64(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -284,25 +295,28 @@ def combine64(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 4)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splitzstart = splits[0].shape[0]/2-z_width/2 - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 + return output diff --git a/training/config_training.py b/training/config_training.py index 20cd7e2..f8b3014 100644 --- a/training/config_training.py +++ b/training/config_training.py @@ -1,17 +1,17 @@ -config = {'stage1_data_path':'/work/DataBowl3/stage1/stage1/', - 'luna_raw':'/work/DataBowl3/luna/raw/', - 'luna_segment':'/work/DataBowl3/luna/seg-lungs-LUNA16/', - - 'luna_data':'/work/DataBowl3/luna/allset', - 'preprocess_result_path':'/work/DataBowl3/stage1/preprocess/', - - 'luna_abbr':'./detector/labels/shorter.csv', - 'luna_label':'./detector/labels/lunaqualified.csv', - 'stage1_annos_path':['./detector/labels/label_job5.csv', - './detector/labels/label_job4_2.csv', - './detector/labels/label_job4_1.csv', - './detector/labels/label_job0.csv', - './detector/labels/label_qualified.csv'], - 'bbox_path':'../detector/results/res18/bbox/', - 'preprocessing_backend':'python' - } +config = {'stage1_data_path': '/work/DataBowl3/stage1/stage1/', + 'luna_raw': '/work/DataBowl3/luna/raw/', + 'luna_segment': '/work/DataBowl3/luna/seg-lungs-LUNA16/', + + 'luna_data': '/work/DataBowl3/luna/allset', + 'preprocess_result_path': '/work/DataBowl3/stage1/preprocess/', + + 'luna_abbr': './detector/labels/shorter.csv', + 'luna_label': './detector/labels/lunaqualified.csv', + 'stage1_annos_path': ['./detector/labels/label_job5.csv', + './detector/labels/label_job4_2.csv', + './detector/labels/label_job4_1.csv', + './detector/labels/label_job0.csv', + './detector/labels/label_qualified.csv'], + 'bbox_path': '../detector/results/res18/bbox/', + 'preprocessing_backend': 'python' + } diff --git a/training/detector/data.py b/training/detector/data.py index 9cdff4f..d01fa23 100644 --- a/training/detector/data.py +++ b/training/detector/data.py @@ -1,24 +1,25 @@ -import numpy as np -import torch -from torch.utils.data import Dataset -import os -import time import collections +import os import random -from layers import iou -from scipy.ndimage import zoom +import time import warnings + +import numpy as np +import torch +from scipy.ndimage import zoom from scipy.ndimage.interpolation import rotate +from torch.utils.data import Dataset + class DataBowl3Detector(Dataset): - def __init__(self, data_dir, split_path, config, phase = 'train',split_comber=None): - assert(phase == 'train' or phase == 'val' or phase == 'test') + def __init__(self, data_dir, split_path, config, phase='train', split_comber=None): + assert (phase == 'train' or phase == 'val' or phase == 'test') self.phase = phase - self.max_stride = config['max_stride'] - self.stride = config['stride'] - sizelim = config['sizelim']/config['reso'] - sizelim2 = config['sizelim2']/config['reso'] - sizelim3 = config['sizelim3']/config['reso'] + self.max_stride = config['max_stride'] + self.stride = config['stride'] + sizelim = config['sizelim'] / config['reso'] + sizelim2 = config['sizelim2'] / config['reso'] + sizelim3 = config['sizelim3'] / config['reso'] self.blacklist = config['blacklist'] self.isScale = config['aug_scale'] self.r_rand = config['r_rand_crop'] @@ -26,74 +27,76 @@ def __init__(self, data_dir, split_path, config, phase = 'train',split_comber=No self.pad_value = config['pad_value'] self.split_comber = split_comber idcs = np.load(split_path) - if phase!='test': + if phase != 'test': idcs = [f for f in idcs if (f not in self.blacklist)] self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs] - self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])>20] - self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])<20] - + self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) > 20] + self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) < 20] + labels = [] - + for idx in idcs: - l = np.load(os.path.join(data_dir, '%s_label.npy' %idx)) - if np.all(l==0): - l=np.array([]) + l = np.load(os.path.join(data_dir, '%s_label.npy' % idx)) + if np.all(l == 0): + l = np.array([]) labels.append(l) self.sample_bboxes = labels if self.phase != 'test': self.bboxes = [] for i, l in enumerate(labels): - if len(l) > 0 : + if len(l) > 0: for t in l: - if t[3]>sizelim: - self.bboxes.append([np.concatenate([[i],t])]) - if t[3]>sizelim2: - self.bboxes+=[[np.concatenate([[i],t])]]*2 - if t[3]>sizelim3: - self.bboxes+=[[np.concatenate([[i],t])]]*4 - self.bboxes = np.concatenate(self.bboxes,axis = 0) + if t[3] > sizelim: + self.bboxes.append([np.concatenate([[i], t])]) + if t[3] > sizelim2: + self.bboxes += [[np.concatenate([[i], t])]] * 2 + if t[3] > sizelim3: + self.bboxes += [[np.concatenate([[i], t])]] * 4 + self.bboxes = np.concatenate(self.bboxes, axis=0) self.crop = Crop(config) self.label_mapping = LabelMapping(config, self.phase) - def __getitem__(self, idx,split=None): + def __getitem__(self, idx, split=None): t = time.time() - np.random.seed(int(str(t%1)[2:7]))#seed according to time + np.random.seed(int(str(t % 1)[2:7])) # seed according to time - isRandomImg = False - if self.phase !='test': - if idx>=len(self.bboxes): + isRandomImg = False + if self.phase != 'test': + if idx >= len(self.bboxes): isRandom = True - idx = idx%len(self.bboxes) + idx = idx % len(self.bboxes) isRandomImg = np.random.randint(2) else: isRandom = False else: isRandom = False - + if self.phase != 'test': if not isRandomImg: bbox = self.bboxes[idx] filename = self.filenames[int(bbox[0])] imgs = np.load(filename) bboxes = self.sample_bboxes[int(bbox[0])] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes,isScale,isRandom) - if self.phase=='train' and not isRandom: - sample, target, bboxes, coord = augment(sample, target, bboxes, coord, - ifflip = self.augtype['flip'], ifrotate=self.augtype['rotate'], ifswap = self.augtype['swap']) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale, isRandom) + if self.phase == 'train' and not isRandom: + sample, target, bboxes, coord = augment(sample, target, bboxes, coord, + ifflip=self.augtype['flip'], + ifrotate=self.augtype['rotate'], + ifswap=self.augtype['swap']) else: randimid = np.random.randint(len(self.kagglenames)) filename = self.kagglenames[randimid] imgs = np.load(filename) bboxes = self.sample_bboxes[randimid] - isScale = self.augtype['scale'] and (self.phase=='train') - sample, target, bboxes, coord = self.crop(imgs, [], bboxes,isScale=False,isRand=True) + isScale = self.augtype['scale'] and (self.phase == 'train') + sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True) label = self.label_mapping(sample.shape[1:], target, bboxes) - sample = (sample.astype(np.float32)-128)/128 - #if filename in self.kagglenames and self.phase=='train': + sample = (sample.astype(np.float32) - 128) / 128 + # if filename in self.kagglenames and self.phase=='train': # label[label==-1]=0 return torch.from_numpy(sample), torch.from_numpy(label), coord else: @@ -103,70 +106,73 @@ def __getitem__(self, idx,split=None): pz = int(np.ceil(float(nz) / self.stride)) * self.stride ph = int(np.ceil(float(nh) / self.stride)) * self.stride pw = int(np.ceil(float(nw) / self.stride)) * self.stride - imgs = np.pad(imgs, [[0,0],[0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',constant_values = self.pad_value) - - xx,yy,zz = np.meshgrid(np.linspace(-0.5,0.5,imgs.shape[1]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[2]/self.stride), - np.linspace(-0.5,0.5,imgs.shape[3]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant', + constant_values=self.pad_value) + + xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[2] / self.stride), + np.linspace(-0.5, 0.5, imgs.shape[3] / self.stride), indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') imgs, nzhw = self.split_comber.split(imgs) coord2, nzhw2 = self.split_comber.split(coord, - side_len = self.split_comber.side_len/self.stride, - max_stride = self.split_comber.max_stride/self.stride, - margin = self.split_comber.margin/self.stride) - assert np.all(nzhw==nzhw2) - imgs = (imgs.astype(np.float32)-128)/128 + side_len=self.split_comber.side_len / self.stride, + max_stride=self.split_comber.max_stride / self.stride, + margin=self.split_comber.margin / self.stride) + assert np.all(nzhw == nzhw2) + imgs = (imgs.astype(np.float32) - 128) / 128 return torch.from_numpy(imgs), bboxes, torch.from_numpy(coord2), np.array(nzhw) def __len__(self): if self.phase == 'train': - return len(self.bboxes)/(1-self.r_rand) - elif self.phase =='val': + return len(self.bboxes) / (1 - self.r_rand) + elif self.phase == 'val': return len(self.bboxes) else: return len(self.sample_bboxes) - - -def augment(sample, target, bboxes, coord, ifflip = True, ifrotate=True, ifswap = True): + + +def augment(sample, target, bboxes, coord, ifflip=True, ifrotate=True, ifswap=True): # angle1 = np.random.rand()*180 if ifrotate: validrot = False counter = 0 while not validrot: newtarget = np.copy(target) - angle1 = np.random.rand()*180 + angle1 = np.random.rand() * 180 size = np.array(sample.shape[2:4]).astype('float') - rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) - newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2 - if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]): + rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)], + [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]]) + newtarget[1:3] = np.dot(rotmat, target[1:3] - size / 2) + size / 2 + if np.all(newtarget[:3] > target[3]) and np.all(newtarget[:3] < np.array(sample.shape[1:4]) - newtarget[3]): validrot = True target = newtarget - sample = rotate(sample,angle1,axes=(2,3),reshape=False) - coord = rotate(coord,angle1,axes=(2,3),reshape=False) + sample = rotate(sample, angle1, axes=(2, 3), reshape=False) + coord = rotate(coord, angle1, axes=(2, 3), reshape=False) for box in bboxes: - box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2 + box[1:3] = np.dot(rotmat, box[1:3] - size / 2) + size / 2 else: counter += 1 - if counter ==3: + if counter == 3: break if ifswap: - if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: + if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]: axisorder = np.random.permutation(3) - sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) - coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) + sample = np.transpose(sample, np.concatenate([[0], axisorder + 1])) + coord = np.transpose(coord, np.concatenate([[0], axisorder + 1])) target[:3] = target[:3][axisorder] - bboxes[:,:3] = bboxes[:,:3][:,axisorder] - + bboxes[:, :3] = bboxes[:, :3][:, axisorder] + if ifflip: -# flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 - flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1 - sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) - coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) + # flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1 + flipid = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1 + sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]]) + coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]]) for ax in range(3): - if flipid[ax]==-1: - target[ax] = np.array(sample.shape[ax+1])-target[ax] - bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax] - return sample, target, bboxes, coord + if flipid[ax] == -1: + target[ax] = np.array(sample.shape[ax + 1]) - target[ax] + bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax] + return sample, target, bboxes, coord + class Crop(object): def __init__(self, config): @@ -174,77 +180,79 @@ def __init__(self, config): self.bound_size = config['bound_size'] self.stride = config['stride'] self.pad_value = config['pad_value'] - def __call__(self, imgs, target, bboxes,isScale=False,isRand=False): + + def __call__(self, imgs, target, bboxes, isScale=False, isRand=False): if isScale: - radiusLim = [8.,120.] - scaleLim = [0.75,1.25] - scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1]) - ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])] - scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0] - crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int') + radiusLim = [8., 120.] + scaleLim = [0.75, 1.25] + scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]) + , np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])] + scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0] + crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int') else: - crop_size=self.crop_size + crop_size = self.crop_size bound_size = self.bound_size target = np.copy(target) bboxes = np.copy(bboxes) - + start = [] for i in range(3): if not isRand: r = target[3] / 2 - s = np.floor(target[i] - r)+ 1 - bound_size - e = np.ceil (target[i] + r)+ 1 + bound_size - crop_size[i] + s = np.floor(target[i] - r) + 1 - bound_size + e = np.ceil(target[i] + r) + 1 + bound_size - crop_size[i] else: - s = np.max([imgs.shape[i+1]-crop_size[i]/2,imgs.shape[i+1]/2+bound_size]) - e = np.min([crop_size[i]/2, imgs.shape[i+1]/2-bound_size]) - target = np.array([np.nan,np.nan,np.nan,np.nan]) - if s>e: - start.append(np.random.randint(e,s))#! + s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + bound_size]) + e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - bound_size]) + target = np.array([np.nan, np.nan, np.nan, np.nan]) + if s > e: + start.append(np.random.randint(e, s)) # ! else: - start.append(int(target[i])-crop_size[i]/2+np.random.randint(-bound_size/2,bound_size/2)) - - - normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5 - normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:]) - xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride), - np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride), - np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij') - coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32') + start.append(int(target[i]) - crop_size[i] / 2 + np.random.randint(-bound_size / 2, bound_size / 2)) + + normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 + normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) + xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride), + np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride), + np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride), + indexing='ij') + coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') pad = [] - pad.append([0,0]) + pad.append([0, 0]) for i in range(3): - leftpad = max(0,-start[i]) - rightpad = max(0,start[i]+crop_size[i]-imgs.shape[i+1]) - pad.append([leftpad,rightpad]) + leftpad = max(0, -start[i]) + rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1]) + pad.append([leftpad, rightpad]) crop = imgs[:, - max(start[0],0):min(start[0] + crop_size[0],imgs.shape[1]), - max(start[1],0):min(start[1] + crop_size[1],imgs.shape[2]), - max(start[2],0):min(start[2] + crop_size[2],imgs.shape[3])] - crop = np.pad(crop,pad,'constant',constant_values =self.pad_value) + max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]), + max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]), + max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])] + crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value) for i in range(3): - target[i] = target[i] - start[i] + target[i] = target[i] - start[i] for i in range(len(bboxes)): for j in range(3): - bboxes[i][j] = bboxes[i][j] - start[j] - + bboxes[i][j] = bboxes[i][j] - start[j] + if isScale: with warnings.catch_warnings(): warnings.simplefilter("ignore") - crop = zoom(crop,[1,scale,scale,scale],order=1) - newpad = self.crop_size[0]-crop.shape[1:][0] - if newpad<0: - crop = crop[:,:-newpad,:-newpad,:-newpad] - elif newpad>0: - pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]] - crop = np.pad(crop,pad2,'constant',constant_values =self.pad_value) + crop = zoom(crop, [1, scale, scale, scale], order=1) + newpad = self.crop_size[0] - crop.shape[1:][0] + if newpad < 0: + crop = crop[:, :-newpad, :-newpad, :-newpad] + elif newpad > 0: + pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] + crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value) for i in range(4): - target[i] = target[i]*scale + target[i] = target[i] * scale for i in range(len(bboxes)): for j in range(4): - bboxes[i][j] = bboxes[i][j]*scale + bboxes[i][j] = bboxes[i][j] * scale return crop, target, bboxes, coord - + + class LabelMapping(object): def __init__(self, config, phase): self.stride = np.array(config['stride']) @@ -257,19 +265,18 @@ def __init__(self, config, phase): elif phase == 'val': self.th_pos = config['th_pos_val'] - def __call__(self, input_size, target, bboxes): stride = self.stride num_neg = self.num_neg th_neg = self.th_neg anchors = self.anchors th_pos = self.th_pos - + output_size = [] for i in range(3): - assert(input_size[i] % stride == 0) + assert (input_size[i] % stride == 0) output_size.append(input_size[i] / stride) - + label = -1 * np.ones(output_size + [len(anchors), 5], np.float32) offset = ((stride.astype('float')) - 1) / 2 oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) @@ -301,7 +308,7 @@ def __call__(self, input_size, target, bboxes): ih = np.concatenate(ih, 0) iw = np.concatenate(iw, 0) ia = np.concatenate(ia, 0) - flag = True + flag = True if len(iz) == 0: pos = [] for i in range(3): @@ -317,7 +324,8 @@ def __call__(self, input_size, target, bboxes): dw = (target[2] - ow[pos[2]]) / anchors[pos[3]] dd = np.log(target[3] / anchors[pos[3]]) label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd] - return label + return label + def select_samples(bbox, anchor, th, oz, oh, ow): z, h, w, d = bbox @@ -330,12 +338,12 @@ def select_samples(bbox, anchor, th, oz, oh, ow): e = z + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mz = np.logical_and(oz >= s, oz <= e) iz = np.where(mz)[0] - + s = h - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = h + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mh = np.logical_and(oh >= s, oh <= e) ih = np.where(mh)[0] - + s = w - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) e = w + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) mw = np.logical_and(ow >= s, ow <= e) @@ -343,7 +351,7 @@ def select_samples(bbox, anchor, th, oz, oh, ow): if len(iz) == 0 or len(ih) == 0 or len(iw) == 0: return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) - + lz, lh, lw = len(iz), len(ih), len(iw) iz = iz.reshape((-1, 1, 1)) ih = ih.reshape((1, -1, 1)) @@ -354,36 +362,37 @@ def select_samples(bbox, anchor, th, oz, oh, ow): centers = np.concatenate([ oz[iz].reshape((-1, 1)), oh[ih].reshape((-1, 1)), - ow[iw].reshape((-1, 1))], axis = 1) - + ow[iw].reshape((-1, 1))], axis=1) + r0 = anchor / 2 s0 = centers - r0 e0 = centers + r0 - + r1 = d / 2 s1 = bbox[:3] - r1 s1 = s1.reshape((1, -1)) e1 = bbox[:3] + r1 e1 = e1.reshape((1, -1)) - + overlap = np.maximum(0, np.minimum(e0, e1) - np.maximum(s0, s1)) - + intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2] union = anchor * anchor * anchor + d * d * d - intersection iou = intersection / union mask = iou >= th - #if th > 0.4: - # if np.sum(mask) == 0: - # print(['iou not large', iou.max()]) - # else: - # print(['iou large', iou[mask]]) + # if th > 0.4: + # if np.sum(mask) == 0: + # print(['iou not large', iou.max()]) + # else: + # print(['iou large', iou[mask]]) iz = iz[mask] ih = ih[mask] iw = iw[mask] return iz, ih, iw + def collate(batch): if torch.is_tensor(batch[0]): return [b.unsqueeze(0) for b in batch] @@ -394,4 +403,3 @@ def collate(batch): elif isinstance(batch[0], collections.Iterable): transposed = zip(*batch) return [collate(samples) for samples in transposed] - diff --git a/training/detector/detect.py b/training/detector/detect.py index 378d285..99f939e 100644 --- a/training/detector/detect.py +++ b/training/detector/detect.py @@ -1,7 +1,9 @@ -import numpy as np -from layers import nms, iou, acc -import time import multiprocessing as mp +import time + +import numpy as np + +from layers import acc save_dir = 'results/ma_offset40_res_n6_100-1/' pbb = np.load(save_dir + 'pbb.npy') @@ -10,13 +12,15 @@ conf_th = [-1, 0, 1] nms_th = [0.3, 0.5, 0.7] detect_th = [0.2, 0.3] -def mp_get_pr(conf_th, nms_th, detect_th, num_procs = 64): + + +def mp_get_pr(conf_th, nms_th, detect_th, num_procs=64): start_time = time.time() - + num_samples = len(pbb) split_size = int(np.ceil(float(num_samples) / num_procs)) num_procs = int(np.ceil(float(num_samples) / split_size)) - + manager = mp.Manager() tp = manager.list(range(num_procs)) fp = manager.list(range(num_procs)) @@ -24,23 +28,25 @@ def mp_get_pr(conf_th, nms_th, detect_th, num_procs = 64): procs = [] for pid in range(num_procs): proc = mp.Process( - target = get_pr, - args = ( + target=get_pr, + args=( pbb[pid * split_size:min((pid + 1) * split_size, num_samples)], lbb[pid * split_size:min((pid + 1) * split_size, num_samples)], conf_th, nms_th, detect_th, pid, tp, fp, p)) procs.append(proc) proc.start() - + for proc in procs: proc.join() tp = np.sum(tp) fp = np.sum(fp) p = np.sum(p) - + end_time = time.time() - print('conf_th %1.1f, nms_th %1.1f, detect_th %1.1f, tp %d, fp %d, p %d, recall %f, time %3.2f' % (conf_th, nms_th, detect_th, tp, fp, p, float(tp) / p, end_time - start_time)) + print('conf_th %1.1f, nms_th %1.1f, detect_th %1.1f, tp %d, fp %d, p %d, recall %f, time %3.2f' % ( + conf_th, nms_th, detect_th, tp, fp, p, float(tp) / p, end_time - start_time)) + def get_pr(pbb, lbb, conf_th, nms_th, detect_th, pid, tp_list, fp_list, p_list): tp, fp, p = 0, 0, 0 @@ -53,6 +59,7 @@ def get_pr(pbb, lbb, conf_th, nms_th, detect_th, pid, tp_list, fp_list, p_list): fp_list[pid] = fp p_list[pid] = p + if __name__ == '__main__': for ct in conf_th: for nt in nms_th: diff --git a/training/detector/layers.py b/training/detector/layers.py index 939b7be..aadb473 100644 --- a/training/detector/layers.py +++ b/training/detector/layers.py @@ -2,20 +2,20 @@ import torch from torch import nn -import math + class PostRes2d(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes2d, self).__init__() - self.conv1 = nn.Conv2d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv2d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv2d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv2d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv2d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm2d(n_out)) else: self.shortcut = None @@ -29,23 +29,24 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out - + + class PostRes(nn.Module): - def __init__(self, n_in, n_out, stride = 1): + def __init__(self, n_in, n_out, stride=1): super(PostRes, self).__init__() - self.conv1 = nn.Conv3d(n_in, n_out, kernel_size = 3, stride = stride, padding = 1) + self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm3d(n_out) - self.relu = nn.ReLU(inplace = True) - self.conv2 = nn.Conv3d(n_out, n_out, kernel_size = 3, padding = 1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm3d(n_out) if stride != 1 or n_out != n_in: self.shortcut = nn.Sequential( - nn.Conv3d(n_in, n_out, kernel_size = 1, stride = stride), + nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride), nn.BatchNorm3d(n_out)) else: self.shortcut = None @@ -59,72 +60,73 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out += residual out = self.relu(out) return out + class Rec3(nn.Module): - def __init__(self, n0, n1, n2, n3, p = 0.0, integrate = True): + def __init__(self, n0, n1, n2, n3, p=0.0, integrate=True): super(Rec3, self).__init__() - + self.block01 = nn.Sequential( - nn.Conv3d(n0, n1, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n0, n1, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) self.block11 = nn.Sequential( - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block21 = nn.Sequential( - nn.ConvTranspose3d(n2, n1, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n2, n1, kernel_size=2, stride=2), nn.BatchNorm3d(n1), - nn.ReLU(inplace = True), - nn.Conv3d(n1, n1, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n1, n1, kernel_size=3, padding=1), nn.BatchNorm3d(n1)) - + self.block12 = nn.Sequential( - nn.Conv3d(n1, n2, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n1, n2, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block22 = nn.Sequential( - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block32 = nn.Sequential( - nn.ConvTranspose3d(n3, n2, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(n3, n2, kernel_size=2, stride=2), nn.BatchNorm3d(n2), - nn.ReLU(inplace = True), - nn.Conv3d(n2, n2, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n2, n2, kernel_size=3, padding=1), nn.BatchNorm3d(n2)) - + self.block23 = nn.Sequential( - nn.Conv3d(n2, n3, kernel_size = 3, stride = 2, padding = 1), + nn.Conv3d(n2, n3, kernel_size=3, stride=2, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) self.block33 = nn.Sequential( - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3), - nn.ReLU(inplace = True), - nn.Conv3d(n3, n3, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(n3, n3, kernel_size=3, padding=1), nn.BatchNorm3d(n3)) - self.relu = nn.ReLU(inplace = True) + self.relu = nn.ReLU(inplace=True) self.p = p self.integrate = integrate @@ -146,25 +148,27 @@ def forward(self, x0, x1, x2, x3): return x0, self.relu(out1), self.relu(out2), self.relu(out3) + def hard_mining(neg_output, neg_labels, num_hard): _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output))) neg_output = torch.index_select(neg_output, 0, idcs) neg_labels = torch.index_select(neg_labels, 0, idcs) return neg_output, neg_labels + class Loss(nn.Module): - def __init__(self, num_hard = 0): + def __init__(self, num_hard=0): super(Loss, self).__init__() self.sigmoid = nn.Sigmoid() self.classify_loss = nn.BCELoss() self.regress_loss = nn.SmoothL1Loss() self.num_hard = num_hard - def forward(self, output, labels, train = True): + def forward(self, output, labels, train=True): batch_size = labels.size(0) output = output.view(-1, 5) labels = labels.view(-1, 5) - + pos_idcs = labels[:, 0] > 0.5 pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5) pos_output = output[pos_idcs].view(-1, 5) @@ -173,15 +177,15 @@ def forward(self, output, labels, train = True): neg_idcs = labels[:, 0] < -0.5 neg_output = output[:, 0][neg_idcs] neg_labels = labels[:, 0][neg_idcs] - + if self.num_hard > 0 and train: neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.num_hard * batch_size) neg_prob = self.sigmoid(neg_output) - #classify_loss = self.classify_loss( - # torch.cat((pos_prob, neg_prob), 0), - # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) - if len(pos_output)>0: + # classify_loss = self.classify_loss( + # torch.cat((pos_prob, neg_prob), 0), + # torch.cat((pos_labels[:, 0], neg_labels + 1), 0)) + if len(pos_output) > 0: pos_prob = self.sigmoid(pos_output[:, 0]) pz, ph, pw, pd = pos_output[:, 1], pos_output[:, 2], pos_output[:, 3], pos_output[:, 4] lz, lh, lw, ld = pos_labels[:, 1], pos_labels[:, 2], pos_labels[:, 3], pos_labels[:, 4] @@ -193,18 +197,18 @@ def forward(self, output, labels, train = True): self.regress_loss(pd, ld)] regress_losses_data = [l.data[0] for l in regress_losses] classify_loss = 0.5 * self.classify_loss( - pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + pos_prob, pos_labels[:, 0]) + 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = (pos_prob.data >= 0.5).sum() pos_total = len(pos_prob) else: - regress_losses = [0,0,0,0] - classify_loss = 0.5 * self.classify_loss( - neg_prob, neg_labels + 1) + regress_losses = [0, 0, 0, 0] + classify_loss = 0.5 * self.classify_loss( + neg_prob, neg_labels + 1) pos_correct = 0 pos_total = 0 - regress_losses_data = [0,0,0,0] + regress_losses_data = [0, 0, 0, 0] classify_loss_data = classify_loss.data[0] loss = classify_loss @@ -216,12 +220,13 @@ def forward(self, output, labels, train = True): return [loss, classify_loss_data] + regress_losses_data + [pos_correct, pos_total, neg_correct, neg_total] + class GetPBB(object): def __init__(self, config): self.stride = config['stride'] self.anchors = np.asarray(config['anchors']) - def __call__(self, output,thresh = -3, ismask=False): + def __call__(self, output, thresh=-3, ismask=False): stride = self.stride anchors = self.anchors output = np.copy(output) @@ -230,29 +235,31 @@ def __call__(self, output,thresh = -3, ismask=False): oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride) ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride) - + output[:, :, :, :, 1] = oz.reshape((-1, 1, 1, 1)) + output[:, :, :, :, 1] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 2] = oh.reshape((1, -1, 1, 1)) + output[:, :, :, :, 2] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 3] = ow.reshape((1, 1, -1, 1)) + output[:, :, :, :, 3] * anchors.reshape((1, 1, 1, -1)) output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1)) mask = output[..., 0] > thresh - xx,yy,zz,aa = np.where(mask) - - output = output[xx,yy,zz,aa] + xx, yy, zz, aa = np.where(mask) + + output = output[xx, yy, zz, aa] if ismask: - return output,[xx,yy,zz,aa] + return output, [xx, yy, zz, aa] else: return output - #output = output[output[:, 0] >= self.conf_th] - #bboxes = nms(output, self.nms_th) + # output = output[output[:, 0] >= self.conf_th] + # bboxes = nms(output, self.nms_th) + + def nms(output, nms_th): if len(output) == 0: return output output = output[np.argsort(-output[:, 0])] bboxes = [output[0]] - + for i in np.arange(1, len(output)): bbox = output[i] flag = 1 @@ -262,12 +269,12 @@ def nms(output, nms_th): break if flag == 1: bboxes.append(bbox) - + bboxes = np.asarray(bboxes, np.float32) return bboxes + def iou(box0, box1): - r0 = box0[3] / 2 s0 = box0[:3] - r0 e0 = box0[:3] + r0 @@ -284,8 +291,9 @@ def iou(box0, box1): union = box0[3] * box0[3] * box0[3] + box1[3] * box1[3] * box1[3] - intersection return intersection / union + def acc(pbb, lbb, conf_th, nms_th, detect_th): - pbb = pbb[pbb[:, 0] >= conf_th] + pbb = pbb[pbb[:, 0] >= conf_th] pbb = nms(pbb, nms_th) tp = [] @@ -297,63 +305,64 @@ def acc(pbb, lbb, conf_th, nms_th, detect_th): bestscore = 0 for i, l in enumerate(lbb): score = iou(p[1:5], l) - if score>bestscore: + if score > bestscore: bestscore = score besti = i if bestscore > detect_th: flag = 1 if l_flag[besti] == 0: l_flag[besti] = 1 - tp.append(np.concatenate([p,[bestscore]],0)) + tp.append(np.concatenate([p, [bestscore]], 0)) else: - fp.append(np.concatenate([p,[bestscore]],0)) + fp.append(np.concatenate([p, [bestscore]], 0)) if flag == 0: - fp.append(np.concatenate([p,[bestscore]],0)) - for i,l in enumerate(lbb): - if l_flag[i]==0: + fp.append(np.concatenate([p, [bestscore]], 0)) + for i, l in enumerate(lbb): + if l_flag[i] == 0: score = [] for p in pbb: - score.append(iou(p[1:5],l)) - if len(score)!=0: + score.append(iou(p[1:5], l)) + if len(score) != 0: bestscore = np.max(score) else: bestscore = 0 - if bestscore0: - fn = np.concatenate([fn,tp[fn_i,:5]]) + if len(fn_i) > 0: + fn = np.concatenate([fn, tp[fn_i, :5]]) else: fn = fn - if len(tp_in_topk)>0: + if len(tp_in_topk) > 0: tp = tp[tp_in_topk] else: tp = [] - if len(fp_in_topk)>0: + if len(fp_in_topk) > 0: fp = newallp[fp_in_topk] else: fp = [] - return tp, fp , fn + return tp, fp, fn diff --git a/training/detector/main.py b/training/detector/main.py index 81a31c1..3a1e825 100644 --- a/training/detector/main.py +++ b/training/detector/main.py @@ -1,12 +1,13 @@ import argparse -import os +import shutil +import sys import time -import numpy as np -import data from importlib import import_module -import shutil + +import data + from utils import * -import sys + sys.path.append('../') from split_combine import SplitComb @@ -14,12 +15,9 @@ from torch.nn import DataParallel from torch.backends import cudnn from torch.utils.data import DataLoader -from torch import optim from torch.autograd import Variable from config_training import config as config_training -from layers import acc - parser = argparse.ArgumentParser(description='PyTorch DataBowl3 Detector') parser.add_argument('--model', '-m', metavar='MODEL', default='base', help='model') @@ -52,11 +50,11 @@ parser.add_argument('--n_test', default=8, type=int, metavar='N', help='number of gpu for test') + def main(): global args args = parser.parse_args() - - + torch.manual_seed(0) torch.cuda.set_device(0) @@ -64,7 +62,7 @@ def main(): config, net, loss, get_pbb = model.get_model() start_epoch = args.start_epoch save_dir = args.save_dir - + if args.resume: checkpoint = torch.load(args.resume) if start_epoch == 0: @@ -72,7 +70,7 @@ def main(): if not save_dir: save_dir = checkpoint['save_dir'] else: - save_dir = os.path.join('results',save_dir) + save_dir = os.path.join('results', save_dir) net.load_state_dict(checkpoint['state_dict']) else: if start_epoch == 0: @@ -81,16 +79,16 @@ def main(): exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime()) save_dir = os.path.join('results', args.model + '-' + exp_id) else: - save_dir = os.path.join('results',save_dir) - + save_dir = os.path.join('results', save_dir) + if not os.path.exists(save_dir): os.makedirs(save_dir) - logfile = os.path.join(save_dir,'log') - if args.test!=1: + logfile = os.path.join(save_dir, 'log') + if args.test != 1: sys.stdout = Logger(logfile) pyfiles = [f for f in os.listdir('./') if f.endswith('.py')] for f in pyfiles: - shutil.copy(f,os.path.join(save_dir,f)) + shutil.copy(f, os.path.join(save_dir, f)) n_gpu = setgpu(args.gpu) args.n_gpu = n_gpu net = net.cuda() @@ -98,12 +96,12 @@ def main(): cudnn.benchmark = True net = DataParallel(net) datadir = config_training['preprocess_result_path'] - + if args.test == 1: margin = 32 sidelen = 144 - split_comber = SplitComb(sidelen,config['max_stride'],config['stride'],margin,config['pad_value']) + split_comber = SplitComb(sidelen, config['max_stride'], config['stride'], margin, config['pad_value']) dataset = data.DataBowl3Detector( datadir, 'full.npy', @@ -112,47 +110,47 @@ def main(): split_comber=split_comber) test_loader = DataLoader( dataset, - batch_size = 1, - shuffle = False, - num_workers = args.workers, - collate_fn = data.collate, + batch_size=1, + shuffle=False, + num_workers=args.workers, + collate_fn=data.collate, pin_memory=False) - - test(test_loader, net, get_pbb, save_dir,config) + + test(test_loader, net, get_pbb, save_dir, config) return - #net = DataParallel(net) - + # net = DataParallel(net) + dataset = data.DataBowl3Detector( datadir, 'kaggleluna_full.npy', config, - phase = 'train') + phase='train') train_loader = DataLoader( dataset, - batch_size = args.batch_size, - shuffle = True, - num_workers = args.workers, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, pin_memory=True) dataset = data.DataBowl3Detector( datadir, 'valsplit.npy', config, - phase = 'val') + phase='val') val_loader = DataLoader( dataset, - batch_size = args.batch_size, - shuffle = False, - num_workers = args.workers, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, pin_memory=True) optimizer = torch.optim.SGD( net.parameters(), args.lr, - momentum = 0.9, - weight_decay = args.weight_decay) - + momentum=0.9, + weight_decay=args.weight_decay) + def get_lr(epoch): if epoch <= args.epochs * 0.5: lr = args.lr @@ -161,15 +159,15 @@ def get_lr(epoch): else: lr = 0.01 * args.lr return lr - for epoch in range(start_epoch, args.epochs + 1): train(train_loader, net, loss, epoch, optimizer, get_lr, args.save_freq, save_dir) validate(val_loader, net, loss) + def train(data_loader, net, loss, epoch, optimizer, get_lr, save_freq, save_dir): start_time = time.time() - + net.train() lr = get_lr(epoch) for param_group in optimizer.param_groups: @@ -177,9 +175,9 @@ def train(data_loader, net, loss, epoch, optimizer, get_lr, save_freq, save_dir) metrics = [] for i, (data, target, coord) in enumerate(data_loader): - data = Variable(data.cuda(async = True)) - target = Variable(target.cuda(async = True)) - coord = Variable(coord.cuda(async = True)) + data = Variable(data.cuda(async=True)) + target = Variable(target.cuda(async=True)) + coord = Variable(coord.cuda(async=True)) output = net(data, coord) loss_output = loss(output, target) @@ -190,11 +188,11 @@ def train(data_loader, net, loss, epoch, optimizer, get_lr, save_freq, save_dir) loss_output[0] = loss_output[0].data[0] metrics.append(loss_output) - if epoch % args.save_freq == 0: + if epoch % args.save_freq == 0: state_dict = net.module.state_dict() for key in state_dict.keys(): state_dict[key] = state_dict[key].cpu() - + torch.save({ 'epoch': epoch, 'save_dir': save_dir, @@ -221,22 +219,23 @@ def train(data_loader, net, loss, epoch, optimizer, get_lr, save_freq, save_dir) np.mean(metrics[:, 5]))) print + def validate(data_loader, net, loss): start_time = time.time() - + net.eval() metrics = [] for i, (data, target, coord) in enumerate(data_loader): - data = Variable(data.cuda(async = True), volatile = True) - target = Variable(target.cuda(async = True), volatile = True) - coord = Variable(coord.cuda(async = True), volatile = True) + data = Variable(data.cuda(async=True), volatile=True) + target = Variable(target.cuda(async=True), volatile=True) + coord = Variable(coord.cuda(async=True), volatile=True) output = net(data, coord) - loss_output = loss(output, target, train = False) + loss_output = loss(output, target, train=False) loss_output[0] = loss_output[0].data[0] - metrics.append(loss_output) + metrics.append(loss_output) end_time = time.time() metrics = np.asarray(metrics, np.float32) @@ -256,9 +255,10 @@ def validate(data_loader, net, loss): print print + def test(data_loader, net, get_pbb, save_dir, config): start_time = time.time() - save_dir = os.path.join(save_dir,'bbox') + save_dir = os.path.join(save_dir, 'bbox') if not os.path.exists(save_dir): os.makedirs(save_dir) print(save_dir) @@ -279,71 +279,72 @@ def test(data_loader, net, get_pbb, save_dir, config): isfeat = True n_per_run = args.n_test print(data.size()) - splitlist = range(0,len(data)+1,n_per_run) - if splitlist[-1]!=len(data): + splitlist = range(0, len(data) + 1, n_per_run) + if splitlist[-1] != len(data): splitlist.append(len(data)) outputlist = [] featurelist = [] - for i in range(len(splitlist)-1): - input = Variable(data[splitlist[i]:splitlist[i+1]], volatile = True).cuda() - inputcoord = Variable(coord[splitlist[i]:splitlist[i+1]], volatile = True).cuda() + for i in range(len(splitlist) - 1): + input = Variable(data[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() + inputcoord = Variable(coord[splitlist[i]:splitlist[i + 1]], volatile=True).cuda() if isfeat: - output,feature = net(input,inputcoord) + output, feature = net(input, inputcoord) featurelist.append(feature.data.cpu().numpy()) else: - output = net(input,inputcoord) + output = net(input, inputcoord) outputlist.append(output.data.cpu().numpy()) - output = np.concatenate(outputlist,0) - output = split_comber.combine(output,nzhw=nzhw) + output = np.concatenate(outputlist, 0) + output = split_comber.combine(output, nzhw=nzhw) if isfeat: - feature = np.concatenate(featurelist,0).transpose([0,2,3,4,1])[:,:,:,:,:,np.newaxis] - feature = split_comber.combine(feature,sidelen)[...,0] + feature = np.concatenate(featurelist, 0).transpose([0, 2, 3, 4, 1])[:, :, :, :, :, np.newaxis] + feature = split_comber.combine(feature, sidelen)[..., 0] thresh = -3 - pbb,mask = get_pbb(output,thresh,ismask=True) + pbb, mask = get_pbb(output, thresh, ismask=True) if isfeat: - feature_selected = feature[mask[0],mask[1],mask[2]] - np.save(os.path.join(save_dir, name+'_feature.npy'), feature_selected) - #tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) - #print([len(tp),len(fp),len(fn)]) - print([i_name,name]) + feature_selected = feature[mask[0], mask[1], mask[2]] + np.save(os.path.join(save_dir, name + '_feature.npy'), feature_selected) + # tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1) + # print([len(tp),len(fp),len(fn)]) + print([i_name, name]) e = time.time() - np.save(os.path.join(save_dir, name+'_pbb.npy'), pbb) - np.save(os.path.join(save_dir, name+'_lbb.npy'), lbb) + np.save(os.path.join(save_dir, name + '_pbb.npy'), pbb) + np.save(os.path.join(save_dir, name + '_lbb.npy'), lbb) np.save(os.path.join(save_dir, 'namelist.npy'), namelist) end_time = time.time() - print('elapsed time is %3.2f seconds' % (end_time - start_time)) print print -def singletest(data,net,config,splitfun,combinefun,n_per_run,margin = 64,isfeat=False): + +def singletest(data, net, config, splitfun, combinefun, n_per_run, margin=64, isfeat=False): z, h, w = data.size(2), data.size(3), data.size(4) print(data.size()) - data = splitfun(data,config['max_stride'],margin) - data = Variable(data.cuda(async = True), volatile = True,requires_grad=False) - splitlist = range(0,args.split+1,n_per_run) + data = splitfun(data, config['max_stride'], margin) + data = Variable(data.cuda(async=True), volatile=True, requires_grad=False) + splitlist = range(0, args.split + 1, n_per_run) outputlist = [] featurelist = [] - for i in range(len(splitlist)-1): + for i in range(len(splitlist) - 1): if isfeat: - output,feature = net(data[splitlist[i]:splitlist[i+1]]) + output, feature = net(data[splitlist[i]:splitlist[i + 1]]) featurelist.append(feature) else: - output = net(data[splitlist[i]:splitlist[i+1]]) + output = net(data[splitlist[i]:splitlist[i + 1]]) output = output.data.cpu().numpy() outputlist.append(output) - - output = np.concatenate(outputlist,0) + + output = np.concatenate(outputlist, 0) output = combinefun(output, z / config['stride'], h / config['stride'], w / config['stride']) if isfeat: - feature = np.concatenate(featurelist,0).transpose([0,2,3,4,1]) + feature = np.concatenate(featurelist, 0).transpose([0, 2, 3, 4, 1]) feature = combinefun(feature, z / config['stride'], h / config['stride'], w / config['stride']) - return output,feature + return output, feature else: return output + + if __name__ == '__main__': main() - diff --git a/training/detector/res18.py b/training/detector/res18.py index 5e133f5..0438618 100644 --- a/training/detector/res18.py +++ b/training/detector/res18.py @@ -1,9 +1,10 @@ import torch from torch import nn + from layers import * config = {} -config['anchors'] = [ 10.0, 30.0, 60.] +config['anchors'] = [10.0, 30.0, 60.] config['chanel'] = 1 config['crop_size'] = [128, 128, 128] config['stride'] = 4 @@ -15,16 +16,18 @@ config['num_hard'] = 2 config['bound_size'] = 12 config['reso'] = 1 -config['sizelim'] = 6. #mm +config['sizelim'] = 6. # mm config['sizelim2'] = 30 config['sizelim3'] = 40 config['aug_scale'] = True config['r_rand_crop'] = 0.3 config['pad_value'] = 170 -config['augtype'] = {'flip':True,'swap':False,'scale':True,'rotate':False} -config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','990fbe3f0a1b53878669967b9afd1441','adc3bbc63d40f8761c59be10f1e504c3'] +config['augtype'] = {'flip': True, 'swap': False, 'scale': True, 'rotate': False} +config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38', '990fbe3f0a1b53878669967b9afd1441', + 'adc3bbc63d40f8761c59be10f1e504c3'] + -#config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', +# config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb','adc3bbc63d40f8761c59be10f1e504c3', # '417','077','188','876','057','087','130','468'] class Net(nn.Module): @@ -33,92 +36,92 @@ def __init__(self): # The first few layers consumes the most memory, so use simple convolution to save memory. # Call these layers preBlock, i.e., before the residual blocks of later layers. self.preBlock = nn.Sequential( - nn.Conv3d(1, 24, kernel_size = 3, padding = 1), + nn.Conv3d(1, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True), - nn.Conv3d(24, 24, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(24, 24, kernel_size=3, padding=1), nn.BatchNorm3d(24), - nn.ReLU(inplace = True)) - + nn.ReLU(inplace=True)) + # 3 poolings, each pooling downsamples the feature map by a factor 2. # 3 groups of blocks. The first block of each group has one pooling. - num_blocks_forw = [2,2,3,3] - num_blocks_back = [3,3] - self.featureNum_forw = [24,32,64,64,64] - self.featureNum_back = [128,64,64] + num_blocks_forw = [2, 2, 3, 3] + num_blocks_back = [3, 3] + self.featureNum_forw = [24, 32, 64, 64, 64] + self.featureNum_back = [128, 64, 64] for i in range(len(num_blocks_forw)): blocks = [] for j in range(num_blocks_forw[i]): if j == 0: - blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1])) else: - blocks.append(PostRes(self.featureNum_forw[i+1], self.featureNum_forw[i+1])) + blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1])) setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks)) - for i in range(len(num_blocks_back)): blocks = [] for j in range(num_blocks_back[i]): if j == 0: - if i==0: + if i == 0: addition = 3 else: addition = 0 - blocks.append(PostRes(self.featureNum_back[i+1]+self.featureNum_forw[i+2]+addition, self.featureNum_back[i])) + blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition, + self.featureNum_back[i])) else: blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i])) setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks)) - self.maxpool1 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool2 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool3 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.maxpool4 = nn.MaxPool3d(kernel_size=2,stride=2,return_indices =True) - self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2,stride=2) - self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2,stride=2) + self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True) + self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2) + self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2) self.path1 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path2 = nn.Sequential( - nn.ConvTranspose3d(64, 64, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2), nn.BatchNorm3d(64), - nn.ReLU(inplace = True)) - self.drop = nn.Dropout3d(p = 0.5, inplace = False) - self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size = 1), + nn.ReLU(inplace=True)) + self.drop = nn.Dropout3d(p=0.5, inplace=False) + self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1), nn.ReLU(), - #nn.Dropout3d(p = 0.3), - nn.Conv3d(64, 5 * len(config['anchors']), kernel_size = 1)) + # nn.Dropout3d(p = 0.3), + nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1)) def forward(self, x, coord): - out = self.preBlock(x)#16 - out_pool,indices0 = self.maxpool1(out) - out1 = self.forw1(out_pool)#32 - out1_pool,indices1 = self.maxpool2(out1) - out2 = self.forw2(out1_pool)#64 - #out2 = self.drop(out2) - out2_pool,indices2 = self.maxpool3(out2) - out3 = self.forw3(out2_pool)#96 - out3_pool,indices3 = self.maxpool4(out3) - out4 = self.forw4(out3_pool)#96 - #out4 = self.drop(out4) - + out = self.preBlock(x) # 16 + out_pool, indices0 = self.maxpool1(out) + out1 = self.forw1(out_pool) # 32 + out1_pool, indices1 = self.maxpool2(out1) + out2 = self.forw2(out1_pool) # 64 + # out2 = self.drop(out2) + out2_pool, indices2 = self.maxpool3(out2) + out3 = self.forw3(out2_pool) # 96 + out3_pool, indices3 = self.maxpool4(out3) + out4 = self.forw4(out3_pool) # 96 + # out4 = self.drop(out4) + rev3 = self.path1(out4) - comb3 = self.back3(torch.cat((rev3, out3), 1))#96+96 - #comb3 = self.drop(comb3) + comb3 = self.back3(torch.cat((rev3, out3), 1)) # 96+96 + # comb3 = self.drop(comb3) rev2 = self.path2(comb3) - - comb2 = self.back2(torch.cat((rev2, out2,coord), 1))#64+64 + + comb2 = self.back2(torch.cat((rev2, out2, coord), 1)) # 64+64 comb2 = self.drop(comb2) out = self.output(comb2) size = out.size() out = out.view(out.size(0), out.size(1), -1) - #out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() + # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5) - #out = out.view(-1, 5) + # out = out.view(-1, 5) return out - + def get_model(): net = Net() loss = Loss(config['num_hard']) diff --git a/training/detector/res_pool.py b/training/detector/res_pool.py index 42001d6..7ce593b 100644 --- a/training/detector/res_pool.py +++ b/training/detector/res_pool.py @@ -1,22 +1,25 @@ import torch from torch import nn + from layers import * config = {} -config['anchors'] = [ 10.0, 25.0, 40.0] +config['anchors'] = [10.0, 25.0, 40.0] config['chanel'] = 2 config['crop_size'] = [64, 128, 128] -config['stride'] = [2,4,4] +config['stride'] = [2, 4, 4] config['max_stride'] = 16 config['num_neg'] = 10 config['th_neg'] = 0.2 config['th_pos'] = 0.5 config['num_hard'] = 1 config['bound_size'] = 12 -config['reso'] = [1.5,0.75,0.75] -config['sizelim'] = 6. #mm -config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','d92998a73d4654a442e6d6ba15bbb827','990fbe3f0a1b53878669967b9afd1441','820245d8b211808bd18e78ff5be16fdb', - '417','077','188','876','057','087','130','468'] +config['reso'] = [1.5, 0.75, 0.75] +config['sizelim'] = 6. # mm +config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38', 'd92998a73d4654a442e6d6ba15bbb827', + '990fbe3f0a1b53878669967b9afd1441', '820245d8b211808bd18e78ff5be16fdb', + '417', '077', '188', '876', '057', '087', '130', '468'] + class Net(nn.Module): def __init__(self): @@ -24,24 +27,24 @@ def __init__(self): # The first few layers consumes the most memory, so use simple convolution to save memory. # Call these layers preBlock, i.e., before the residual blocks of later layers. self.preBlock = nn.Sequential( - nn.Conv3d(2, 16, kernel_size = 3, padding = 1), + nn.Conv3d(2, 16, kernel_size=3, padding=1), nn.BatchNorm3d(16), - nn.ReLU(inplace = True), - nn.Conv3d(16, 16, kernel_size = 3, padding = 1), + nn.ReLU(inplace=True), + nn.Conv3d(16, 16, kernel_size=3, padding=1), nn.BatchNorm3d(16), - nn.ReLU(inplace = True)) - + nn.ReLU(inplace=True)) + # 3 poolings, each pooling downsamples the feature map by a factor 2. # 3 groups of blocks. The first block of each group has one pooling. - num_blocks = [6,6,6,6] - n_in = [16, 32, 64,96] - n_out = [32, 64, 96,96] + num_blocks = [6, 6, 6, 6] + n_in = [16, 32, 64, 96] + n_out = [32, 64, 96, 96] for i in range(len(num_blocks)): blocks = [] for j in range(num_blocks[i]): if j == 0: - if i ==0: - blocks.append(nn.MaxPool3d(kernel_size=[1,2,2])) + if i == 0: + blocks.append(nn.MaxPool3d(kernel_size=[1, 2, 2])) blocks.append(PostRes(n_in[i], n_out[i])) else: blocks.append(nn.MaxPool3d(kernel_size=2)) @@ -49,35 +52,35 @@ def __init__(self): else: blocks.append(PostRes(n_out[i], n_out[i])) setattr(self, 'group' + str(i + 1), nn.Sequential(*blocks)) - + self.path1 = nn.Sequential( - nn.Conv3d(64, 32, kernel_size = 3, padding = 1), + nn.Conv3d(64, 32, kernel_size=3, padding=1), nn.BatchNorm3d(32), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path2 = nn.Sequential( - nn.ConvTranspose3d(96, 32, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(96, 32, kernel_size=2, stride=2), nn.BatchNorm3d(32), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.path3 = nn.Sequential( - nn.ConvTranspose3d(96, 32, kernel_size = 2, stride = 2), + nn.ConvTranspose3d(96, 32, kernel_size=2, stride=2), nn.BatchNorm3d(32), - nn.ReLU(inplace = True), - nn.ConvTranspose3d(32, 32, kernel_size = 2, stride = 2), + nn.ReLU(inplace=True), + nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2), nn.BatchNorm3d(32), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) self.combine = nn.Sequential( - nn.Conv3d(96, 128, kernel_size = 1), + nn.Conv3d(96, 128, kernel_size=1), nn.BatchNorm3d(128), - nn.ReLU(inplace = True)) + nn.ReLU(inplace=True)) - self.drop = nn.Dropout3d(p = 0.5, inplace = False) - self.output = nn.Conv3d(128, 5 * len(config['anchors']), kernel_size = 1) + self.drop = nn.Dropout3d(p=0.5, inplace=False) + self.output = nn.Conv3d(128, 5 * len(config['anchors']), kernel_size=1) def forward(self, x): - x = x.view(x.size(0), 2,x.size(2), x.size(3), x.size(4)) + x = x.view(x.size(0), 2, x.size(2), x.size(3), x.size(4)) out = self.preBlock(x) out1 = self.group1(out) @@ -95,11 +98,12 @@ def forward(self, x): out = self.output(out) size = out.size() out = out.view(out.size(0), out.size(1), -1) - #out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() + # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous() out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5) - #out = out.view(-1, 5) + # out = out.view(-1, 5) return out + def get_model(): net = Net() loss = Loss(config['num_hard']) diff --git a/training/detector/split_combine.py b/training/detector/split_combine.py index d4b5b49..2e8e3aa 100644 --- a/training/detector/split_combine.py +++ b/training/detector/split_combine.py @@ -1,24 +1,25 @@ -import torch import numpy as np + + class SplitComb(): - def __init__(self,side_len,max_stride,stride,margin,pad_value): + def __init__(self, side_len, max_stride, stride, margin, pad_value): self.side_len = side_len self.max_stride = max_stride self.stride = stride self.margin = margin self.pad_value = pad_value - - def split(self, data, side_len = None, max_stride = None, margin = None): - if side_len==None: + + def split(self, data, side_len=None, max_stride=None, margin=None): + if side_len == None: side_len = self.side_len if max_stride == None: max_stride = self.max_stride if margin == None: margin = self.margin - - assert(side_len > margin) - assert(side_len % max_stride == 0) - assert(margin % max_stride == 0) + + assert (side_len > margin) + assert (side_len % max_stride == 0) + assert (margin % max_stride == 0) splits = [] _, z, h, w = data.shape @@ -26,14 +27,14 @@ def split(self, data, side_len = None, max_stride = None, margin = None): nz = int(np.ceil(float(z) / side_len)) nh = int(np.ceil(float(h) / side_len)) nw = int(np.ceil(float(w) / side_len)) - - nzhw = [nz,nh,nw] + + nzhw = [nz, nh, nw] self.nzhw = nzhw - - pad = [ [0, 0], - [margin, nz * side_len - z + margin], - [margin, nh * side_len - h + margin], - [margin, nw * side_len - w + margin]] + + pad = [[0, 0], + [margin, nz * side_len - z + margin], + [margin, nh * side_len - h + margin], + [margin, nw * side_len - w + margin]] data = np.pad(data, pad, 'edge') for iz in range(nz): @@ -50,24 +51,24 @@ def split(self, data, side_len = None, max_stride = None, margin = None): splits.append(split) splits = np.concatenate(splits, 0) - return splits,nzhw + return splits, nzhw + + def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None): - def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): - - if side_len==None: + if side_len == None: side_len = self.side_len if stride == None: stride = self.stride if margin == None: margin = self.margin - if nzhw==None: + if nzhw == None: nz = self.nz nh = self.nh nw = self.nw else: - nz,nh,nw = nzhw - assert(side_len % stride == 0) - assert(margin % stride == 0) + nz, nh, nw = nzhw + assert (side_len % stride == 0) + assert (margin % stride == 0) side_len /= stride margin /= stride @@ -97,4 +98,4 @@ def combine(self, output, nzhw = None, side_len=None, stride=None, margin=None): output[sz:ez, sh:eh, sw:ew] = split idx += 1 - return output + return output diff --git a/training/detector/utils.py b/training/detector/utils.py index 8743c82..d0140c8 100644 --- a/training/detector/utils.py +++ b/training/detector/utils.py @@ -1,80 +1,87 @@ -import sys import os +import sys + import numpy as np import torch + + def getFreeId(): - import pynvml + import pynvml pynvml.nvmlInit() + def getFreeRatio(id): handle = pynvml.nvmlDeviceGetHandleByIndex(id) use = pynvml.nvmlDeviceGetUtilizationRates(handle) - ratio = 0.5*(float(use.gpu+float(use.memory))) + ratio = 0.5 * (float(use.gpu + float(use.memory))) return ratio deviceCount = pynvml.nvmlDeviceGetCount() available = [] for i in range(deviceCount): - if getFreeRatio(i)<70: + if getFreeRatio(i) < 70: available.append(i) gpus = '' for g in available: - gpus = gpus+str(g)+',' + gpus = gpus + str(g) + ',' gpus = gpus[:-1] return gpus + def setgpu(gpuinput): freeids = getFreeId() - if gpuinput=='all': + if gpuinput == 'all': gpus = freeids else: gpus = gpuinput if any([g not in freeids for g in gpus.split(',')]): - raise ValueError('gpu'+g+'is being used') - print('using gpu '+gpus) - os.environ['CUDA_VISIBLE_DEVICES']=gpus + raise ValueError('gpu' + g + 'is being used') + print('using gpu ' + gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = gpus return len(gpus.split(',')) + class Logger(object): - def __init__(self,logfile): + def __init__(self, logfile): self.terminal = sys.stdout self.log = open(logfile, "a") def write(self, message): self.terminal.write(message) - self.log.write(message) + self.log.write(message) def flush(self): - #this flush method is needed for python 3 compatibility. - #this handles the flush command by doing nothing. - #you might want to specify some extra behavior here. - pass + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass -def split4(data, max_stride, margin): +def split4(data, max_stride, margin): splits = [] data = torch.Tensor.numpy(data) - _,c, z, h, w = data.shape - - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - pad = int(np.ceil(float(z)/max_stride)*max_stride)-z - leftpad = pad/2 - pad = [[0,0],[0,0],[leftpad,pad-leftpad],[0,0],[0,0]] - data = np.pad(data,pad,'constant',constant_values=-1) + _, c, z, h, w = data.shape + + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + pad = int(np.ceil(float(z) / max_stride) * max_stride) - z + leftpad = pad / 2 + pad = [[0, 0], [0, 0], [leftpad, pad - leftpad], [0, 0], [0, 0]] + data = np.pad(data, pad, 'constant', constant_values=-1) data = torch.from_numpy(data) splits.append(data[:, :, :, :h_width, :w_width]) splits.append(data[:, :, :, :h_width, -w_width:]) splits.append(data[:, :, :, -h_width:, :w_width]) splits.append(data[:, :, :, -h_width:, -w_width:]) - + return torch.cat(splits, 0) + def combine4(output, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( splits[0].shape[0], h, @@ -101,37 +108,36 @@ def combine4(output, h, w): return output -def split8(data, max_stride, margin): + +def split8(data, max_stride, margin): splits = [] if isinstance(data, np.ndarray): c, z, h, w = data.shape else: - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: if isinstance(data, np.ndarray): splits.append(data[np.newaxis, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) else: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - if isinstance(data, np.ndarray): return np.concatenate(splits, 0) else: return torch.cat(splits, 0) - def combine8(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -139,41 +145,42 @@ def combine8(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 2 h_width = h / 2 w_width = w / 2 i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz in [[0, z_width], [z_width - z, None]]: + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split16(data, max_stride, margin): +def split16(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine16(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -181,45 +188,47 @@ def combine16(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 4 h_width = h / 2 w_width = w / 2 - splitzstart = splits[0].shape[0]/2-z_width/2 - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split32(data, max_stride, margin): + +def split32(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine32(splits, z, h, w): - output = np.zeros(( z, h, @@ -227,56 +236,58 @@ def combine32(splits, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 2)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - - return output + for zz in [[0, z_width], [z_width - z, None]]: + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 -def split64(data, max_stride, margin): + return output + + +def split64(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine64(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -284,25 +295,28 @@ def combine64(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 4)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splitzstart = splits[0].shape[0]/2-z_width/2 - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 + return output diff --git a/training/prepare.py b/training/prepare.py index 90f1227..3dc4645 100644 --- a/training/prepare.py +++ b/training/prepare.py @@ -1,265 +1,265 @@ import os import shutil -import numpy as np -from config_training import config - +import sys +from functools import partial +from multiprocessing import Pool -from scipy.io import loadmat +import SimpleITK as sitk import numpy as np -import h5py import pandas -import scipy +from config_training import config from scipy.ndimage.interpolation import zoom -from skimage import measure -import SimpleITK as sitk -from scipy.ndimage.morphology import binary_dilation,generate_binary_structure +from scipy.ndimage.morphology import binary_dilation, generate_binary_structure from skimage.morphology import convex_hull_image -import pandas -from multiprocessing import Pool -from functools import partial -import sys + sys.path.append('../preprocessing') from step1 import step1_python import warnings -def resample(imgs, spacing, new_spacing,order=2): - if len(imgs.shape)==3: + +def resample(imgs, spacing, new_spacing, order=2): + if len(imgs.shape) == 3: new_shape = np.round(imgs.shape * spacing / new_spacing) true_spacing = spacing * imgs.shape / new_shape resize_factor = new_shape / imgs.shape - imgs = zoom(imgs, resize_factor, mode = 'nearest',order=order) + imgs = zoom(imgs, resize_factor, mode='nearest', order=order) return imgs, true_spacing - elif len(imgs.shape)==4: + elif len(imgs.shape) == 4: n = imgs.shape[-1] newimg = [] for i in range(n): - slice = imgs[:,:,:,i] - newslice,true_spacing = resample(slice,spacing,new_spacing) + slice = imgs[:, :, :, i] + newslice, true_spacing = resample(slice, spacing, new_spacing) newimg.append(newslice) - newimg=np.transpose(np.array(newimg),[1,2,3,0]) - return newimg,true_spacing + newimg = np.transpose(np.array(newimg), [1, 2, 3, 0]) + return newimg, true_spacing else: raise ValueError('wrong shape') + + def worldToVoxelCoord(worldCoord, origin, spacing): - stretchedVoxelCoord = np.absolute(worldCoord - origin) voxelCoord = stretchedVoxelCoord / spacing return voxelCoord + def load_itk_image(filename): with open(filename) as f: contents = f.readlines() line = [k for k in contents if k.startswith('TransformMatrix')][0] transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') transformM = np.round(transformM) - if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])): + if np.any(transformM != np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])): isflip = True else: isflip = False itkimage = sitk.ReadImage(filename) numpyImage = sitk.GetArrayFromImage(itkimage) - + numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) - - return numpyImage, numpyOrigin, numpySpacing,isflip + + return numpyImage, numpyOrigin, numpySpacing, isflip + def process_mask(mask): convex_mask = np.copy(mask) for i_layer in range(convex_mask.shape[0]): - mask1 = np.ascontiguousarray(mask[i_layer]) - if np.sum(mask1)>0: + mask1 = np.ascontiguousarray(mask[i_layer]) + if np.sum(mask1) > 0: mask2 = convex_hull_image(mask1) - if np.sum(mask2)>1.5*np.sum(mask1): + if np.sum(mask2) > 1.5 * np.sum(mask1): mask2 = mask1 else: mask2 = mask1 convex_mask[i_layer] = mask2 - struct = generate_binary_structure(3,1) - dilatedMask = binary_dilation(convex_mask,structure=struct,iterations=10) + struct = generate_binary_structure(3, 1) + dilatedMask = binary_dilation(convex_mask, structure=struct, iterations=10) return dilatedMask def lumTrans(img): - lungwin = np.array([-1200.,600.]) - newimg = (img-lungwin[0])/(lungwin[1]-lungwin[0]) - newimg[newimg<0]=0 - newimg[newimg>1]=1 - newimg = (newimg*255).astype('uint8') + lungwin = np.array([-1200., 600.]) + newimg = (img - lungwin[0]) / (lungwin[1] - lungwin[0]) + newimg[newimg < 0] = 0 + newimg[newimg > 1] = 1 + newimg = (newimg * 255).astype('uint8') return newimg -def savenpy(id,annos,filelist,data_path,prep_folder): - resolution = np.array([1,1,1]) +def savenpy(id, annos, filelist, data_path, prep_folder): + resolution = np.array([1, 1, 1]) name = filelist[id] - label = annos[annos[:,0]==name] - label = label[:,[3,1,2,4]].astype('float') - - im, m1, m2, spacing = step1_python(os.path.join(data_path,name)) - Mask = m1+m2 - - newshape = np.round(np.array(Mask.shape)*spacing/resolution) - xx,yy,zz= np.where(Mask) - box = np.array([[np.min(xx),np.max(xx)],[np.min(yy),np.max(yy)],[np.min(zz),np.max(zz)]]) - box = box*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) + label = annos[annos[:, 0] == name] + label = label[:, [3, 1, 2, 4]].astype('float') + + im, m1, m2, spacing = step1_python(os.path.join(data_path, name)) + Mask = m1 + m2 + + newshape = np.round(np.array(Mask.shape) * spacing / resolution) + xx, yy, zz = np.where(Mask) + box = np.array([[np.min(xx), np.max(xx)], [np.min(yy), np.max(yy)], [np.min(zz), np.max(zz)]]) + box = box * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1) box = np.floor(box).astype('int') margin = 5 - extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T + extendbox = np.vstack( + [np.max([[0, 0, 0], box[:, 0] - margin], 0), np.min([newshape, box[:, 1] + 2 * margin], axis=0).T]).T extendbox = extendbox.astype('int') - - convex_mask = m1 dm1 = process_mask(m1) dm2 = process_mask(m2) - dilatedMask = dm1+dm2 - Mask = m1+m2 + dilatedMask = dm1 + dm2 + Mask = m1 + m2 extramask = dilatedMask - Mask bone_thresh = 210 pad_value = 170 - im[np.isnan(im)]=-2000 + im[np.isnan(im)] = -2000 sliceim = lumTrans(im) - sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8') - bones = sliceim*extramask>bone_thresh + sliceim = sliceim * dilatedMask + pad_value * (1 - dilatedMask).astype('uint8') + bones = sliceim * extramask > bone_thresh sliceim[bones] = pad_value - sliceim1,_ = resample(sliceim,spacing,resolution,order=1) - sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1], - extendbox[1,0]:extendbox[1,1], - extendbox[2,0]:extendbox[2,1]] - sliceim = sliceim2[np.newaxis,...] - np.save(os.path.join(prep_folder,name+'_clean.npy'),sliceim) - - - if len(label)==0: - label2 = np.array([[0,0,0,0]]) - elif len(label[0])==0: - label2 = np.array([[0,0,0,0]]) - elif label[0][0]==0: - label2 = np.array([[0,0,0,0]]) + sliceim1, _ = resample(sliceim, spacing, resolution, order=1) + sliceim2 = sliceim1[extendbox[0, 0]:extendbox[0, 1], + extendbox[1, 0]:extendbox[1, 1], + extendbox[2, 0]:extendbox[2, 1]] + sliceim = sliceim2[np.newaxis, ...] + np.save(os.path.join(prep_folder, name + '_clean.npy'), sliceim) + + if len(label) == 0: + label2 = np.array([[0, 0, 0, 0]]) + elif len(label[0]) == 0: + label2 = np.array([[0, 0, 0, 0]]) + elif label[0][0] == 0: + label2 = np.array([[0, 0, 0, 0]]) else: haslabel = 1 label2 = np.copy(label).T - label2[:3] = label2[:3][[0,2,1]] - label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) - label2[3] = label2[3]*spacing[1]/resolution[1] - label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1) + label2[:3] = label2[:3][[0, 2, 1]] + label2[:3] = label2[:3] * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1) + label2[3] = label2[3] * spacing[1] / resolution[1] + label2[:3] = label2[:3] - np.expand_dims(extendbox[:, 0], 1) label2 = label2[:4].T - np.save(os.path.join(prep_folder,name+'_label.npy'),label2) + np.save(os.path.join(prep_folder, name + '_label.npy'), label2) print(name) -def full_prep(step1=True,step2 = True): + +def full_prep(step1=True, step2=True): warnings.filterwarnings("ignore") - #preprocess_result_path = './prep_result' + # preprocess_result_path = './prep_result' prep_folder = config['preprocess_result_path'] data_path = config['stage1_data_path'] finished_flag = '.flag_prepkaggle' - + if not os.path.exists(finished_flag): alllabelfiles = config['stage1_annos_path'] tmp = [] for f in alllabelfiles: content = np.array(pandas.read_csv(f)) - content = content[content[:,0]!=np.nan] - tmp.append(content[:,:5]) - alllabel = np.concatenate(tmp,0) + content = content[content[:, 0] != np.nan] + tmp.append(content[:, :5]) + alllabel = np.concatenate(tmp, 0) filelist = os.listdir(config['stage1_data_path']) if not os.path.exists(prep_folder): os.mkdir(prep_folder) - #eng.addpath('preprocessing/',nargout=0) + # eng.addpath('preprocessing/',nargout=0) print('starting preprocessing') pool = Pool() filelist = [f for f in os.listdir(data_path)] - partial_savenpy = partial(savenpy,annos= alllabel,filelist=filelist,data_path=data_path,prep_folder=prep_folder ) + partial_savenpy = partial(savenpy, annos=alllabel, filelist=filelist, data_path=data_path, + prep_folder=prep_folder) N = len(filelist) - #savenpy(1) - _=pool.map(partial_savenpy,range(N)) + # savenpy(1) + _ = pool.map(partial_savenpy, range(N)) pool.close() pool.join() print('end preprocessing') - f= open(finished_flag,"w+") + f = open(finished_flag, "w+") + -def savenpy_luna(id,annos,filelist,luna_segment,luna_data,savepath): +def savenpy_luna(id, annos, filelist, luna_segment, luna_data, savepath): islabel = True isClean = True - resolution = np.array([1,1,1]) -# resolution = np.array([2,2,2]) + resolution = np.array([1, 1, 1]) + # resolution = np.array([2,2,2]) name = filelist[id] - - Mask,origin,spacing,isflip = load_itk_image(os.path.join(luna_segment,name+'.mhd')) + + Mask, origin, spacing, isflip = load_itk_image(os.path.join(luna_segment, name + '.mhd')) if isflip: - Mask = Mask[:,::-1,::-1] - newshape = np.round(np.array(Mask.shape)*spacing/resolution).astype('int') - m1 = Mask==3 - m2 = Mask==4 - Mask = m1+m2 - - xx,yy,zz= np.where(Mask) - box = np.array([[np.min(xx),np.max(xx)],[np.min(yy),np.max(yy)],[np.min(zz),np.max(zz)]]) - box = box*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) + Mask = Mask[:, ::-1, ::-1] + newshape = np.round(np.array(Mask.shape) * spacing / resolution).astype('int') + m1 = Mask == 3 + m2 = Mask == 4 + Mask = m1 + m2 + + xx, yy, zz = np.where(Mask) + box = np.array([[np.min(xx), np.max(xx)], [np.min(yy), np.max(yy)], [np.min(zz), np.max(zz)]]) + box = box * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1) box = np.floor(box).astype('int') margin = 5 - extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T + extendbox = np.vstack( + [np.max([[0, 0, 0], box[:, 0] - margin], 0), np.min([newshape, box[:, 1] + 2 * margin], axis=0).T]).T - this_annos = np.copy(annos[annos[:,0]==int(name)]) + this_annos = np.copy(annos[annos[:, 0] == int(name)]) if isClean: convex_mask = m1 dm1 = process_mask(m1) dm2 = process_mask(m2) - dilatedMask = dm1+dm2 - Mask = m1+m2 + dilatedMask = dm1 + dm2 + Mask = m1 + m2 extramask = dilatedMask ^ Mask bone_thresh = 210 pad_value = 170 - sliceim,origin,spacing,isflip = load_itk_image(os.path.join(luna_data,name+'.mhd')) + sliceim, origin, spacing, isflip = load_itk_image(os.path.join(luna_data, name + '.mhd')) if isflip: - sliceim = sliceim[:,::-1,::-1] + sliceim = sliceim[:, ::-1, ::-1] print('flip!') sliceim = lumTrans(sliceim) - sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8') - bones = (sliceim*extramask)>bone_thresh + sliceim = sliceim * dilatedMask + pad_value * (1 - dilatedMask).astype('uint8') + bones = (sliceim * extramask) > bone_thresh sliceim[bones] = pad_value - - sliceim1,_ = resample(sliceim,spacing,resolution,order=1) - sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1], - extendbox[1,0]:extendbox[1,1], - extendbox[2,0]:extendbox[2,1]] - sliceim = sliceim2[np.newaxis,...] - np.save(os.path.join(savepath,name+'_clean.npy'),sliceim) + sliceim1, _ = resample(sliceim, spacing, resolution, order=1) + sliceim2 = sliceim1[extendbox[0, 0]:extendbox[0, 1], + extendbox[1, 0]:extendbox[1, 1], + extendbox[2, 0]:extendbox[2, 1]] + sliceim = sliceim2[np.newaxis, ...] + np.save(os.path.join(savepath, name + '_clean.npy'), sliceim) if islabel: - this_annos = np.copy(annos[annos[:,0]==int(name)]) + this_annos = np.copy(annos[annos[:, 0] == int(name)]) label = [] - if len(this_annos)>0: - + if len(this_annos) > 0: + for c in this_annos: - pos = worldToVoxelCoord(c[1:4][::-1],origin=origin,spacing=spacing) + pos = worldToVoxelCoord(c[1:4][::-1], origin=origin, spacing=spacing) if isflip: - pos[1:] = Mask.shape[1:3]-pos[1:] - label.append(np.concatenate([pos,[c[4]/spacing[1]]])) - + pos[1:] = Mask.shape[1:3] - pos[1:] + label.append(np.concatenate([pos, [c[4] / spacing[1]]])) + label = np.array(label) - if len(label)==0: - label2 = np.array([[0,0,0,0]]) + if len(label) == 0: + label2 = np.array([[0, 0, 0, 0]]) else: label2 = np.copy(label).T - label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) - label2[3] = label2[3]*spacing[1]/resolution[1] - label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1) + label2[:3] = label2[:3] * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1) + label2[3] = label2[3] * spacing[1] / resolution[1] + label2[:3] = label2[:3] - np.expand_dims(extendbox[:, 0], 1) label2 = label2[:4].T - np.save(os.path.join(savepath,name+'_label.npy'),label2) - + np.save(os.path.join(savepath, name + '_label.npy'), label2) + print(name) + def preprocess_luna(): luna_segment = config['luna_segment'] savepath = config['preprocess_result_path'] @@ -268,25 +268,25 @@ def preprocess_luna(): finished_flag = '.flag_preprocessluna' print('starting preprocessing luna') if not os.path.exists(finished_flag): - filelist = [f.split('.mhd')[0] for f in os.listdir(luna_data) if f.endswith('.mhd') ] + filelist = [f.split('.mhd')[0] for f in os.listdir(luna_data) if f.endswith('.mhd')] annos = np.array(pandas.read_csv(luna_label)) if not os.path.exists(savepath): os.mkdir(savepath) - pool = Pool() - partial_savenpy_luna = partial(savenpy_luna,annos=annos,filelist=filelist, - luna_segment=luna_segment,luna_data=luna_data,savepath=savepath) + partial_savenpy_luna = partial(savenpy_luna, annos=annos, filelist=filelist, + luna_segment=luna_segment, luna_data=luna_data, savepath=savepath) N = len(filelist) - #savenpy(1) - _=pool.map(partial_savenpy_luna,range(N)) + # savenpy(1) + _ = pool.map(partial_savenpy_luna, range(N)) pool.close() pool.join() print('end preprocessing luna') - f= open(finished_flag,"w+") - + f = open(finished_flag, "w+") + + def prepare_luna(): print('start changing luna name') luna_raw = config['luna_raw'] @@ -294,52 +294,52 @@ def prepare_luna(): luna_data = config['luna_data'] luna_segment = config['luna_segment'] finished_flag = '.flag_prepareluna' - + if not os.path.exists(finished_flag): - subsetdirs = [os.path.join(luna_raw,f) for f in os.listdir(luna_raw) if f.startswith('subset') and os.path.isdir(os.path.join(luna_raw,f))] + subsetdirs = [os.path.join(luna_raw, f) for f in os.listdir(luna_raw) if + f.startswith('subset') and os.path.isdir(os.path.join(luna_raw, f))] if not os.path.exists(luna_data): os.mkdir(luna_data) -# allnames = [] -# for d in subsetdirs: -# files = os.listdir(d) -# names = [f[:-4] for f in files if f.endswith('mhd')] -# allnames = allnames + names -# allnames = np.array(allnames) -# allnames = np.sort(allnames) - -# ids = np.arange(len(allnames)).astype('str') -# ids = np.array(['0'*(3-len(n))+n for n in ids]) -# pds = pandas.DataFrame(np.array([ids,allnames]).T) -# namelist = list(allnames) - - abbrevs = np.array(pandas.read_csv(config['luna_abbr'],header=None)) - namelist = list(abbrevs[:,1]) - ids = abbrevs[:,0] - + # allnames = [] + # for d in subsetdirs: + # files = os.listdir(d) + # names = [f[:-4] for f in files if f.endswith('mhd')] + # allnames = allnames + names + # allnames = np.array(allnames) + # allnames = np.sort(allnames) + + # ids = np.arange(len(allnames)).astype('str') + # ids = np.array(['0'*(3-len(n))+n for n in ids]) + # pds = pandas.DataFrame(np.array([ids,allnames]).T) + # namelist = list(allnames) + + abbrevs = np.array(pandas.read_csv(config['luna_abbr'], header=None)) + namelist = list(abbrevs[:, 1]) + ids = abbrevs[:, 0] + for d in subsetdirs: files = os.listdir(d) files.sort() for f in files: name = f[:-4] id = ids[namelist.index(name)] - filename = '0'*(3-len(str(id)))+str(id) - shutil.move(os.path.join(d,f),os.path.join(luna_data,filename+f[-4:])) - print(os.path.join(luna_data,str(id)+f[-4:])) + filename = '0' * (3 - len(str(id))) + str(id) + shutil.move(os.path.join(d, f), os.path.join(luna_data, filename + f[-4:])) + print(os.path.join(luna_data, str(id) + f[-4:])) files = [f for f in os.listdir(luna_data) if f.endswith('mhd')] for file in files: - with open(os.path.join(luna_data,file),'r') as f: + with open(os.path.join(luna_data, file), 'r') as f: content = f.readlines() id = file.split('.mhd')[0] - filename = '0'*(3-len(str(id)))+str(id) - content[-1]='ElementDataFile = '+filename+'.raw\n' + filename = '0' * (3 - len(str(id))) + str(id) + content[-1] = 'ElementDataFile = ' + filename + '.raw\n' print(content[-1]) - with open(os.path.join(luna_data,file),'w') as f: + with open(os.path.join(luna_data, file), 'w') as f: f.writelines(content) - seglist = os.listdir(luna_segment) for f in seglist: if f.endswith('.mhd'): @@ -351,27 +351,26 @@ def prepare_luna(): lastfix = f[-5:] if name in namelist: id = ids[namelist.index(name)] - filename = '0'*(3-len(str(id)))+str(id) - - shutil.move(os.path.join(luna_segment,f),os.path.join(luna_segment,filename+lastfix)) - print(os.path.join(luna_segment,filename+lastfix)) + filename = '0' * (3 - len(str(id))) + str(id) + shutil.move(os.path.join(luna_segment, f), os.path.join(luna_segment, filename + lastfix)) + print(os.path.join(luna_segment, filename + lastfix)) files = [f for f in os.listdir(luna_segment) if f.endswith('mhd')] for file in files: - with open(os.path.join(luna_segment,file),'r') as f: + with open(os.path.join(luna_segment, file), 'r') as f: content = f.readlines() - id = file.split('.mhd')[0] - filename = '0'*(3-len(str(id)))+str(id) - content[-1]='ElementDataFile = '+filename+'.zraw\n' + id = file.split('.mhd')[0] + filename = '0' * (3 - len(str(id))) + str(id) + content[-1] = 'ElementDataFile = ' + filename + '.zraw\n' print(content[-1]) - with open(os.path.join(luna_segment,file),'w') as f: + with open(os.path.join(luna_segment, file), 'w') as f: f.writelines(content) print('end changing luna name') - f= open(finished_flag,"w+") - -if __name__=='__main__': - full_prep(step1=True,step2=True) + f = open(finished_flag, "w+") + + +if __name__ == '__main__': + full_prep(step1=True, step2=True) prepare_luna() preprocess_luna() - diff --git a/utils.py b/utils.py index 8743c82..d0140c8 100644 --- a/utils.py +++ b/utils.py @@ -1,80 +1,87 @@ -import sys import os +import sys + import numpy as np import torch + + def getFreeId(): - import pynvml + import pynvml pynvml.nvmlInit() + def getFreeRatio(id): handle = pynvml.nvmlDeviceGetHandleByIndex(id) use = pynvml.nvmlDeviceGetUtilizationRates(handle) - ratio = 0.5*(float(use.gpu+float(use.memory))) + ratio = 0.5 * (float(use.gpu + float(use.memory))) return ratio deviceCount = pynvml.nvmlDeviceGetCount() available = [] for i in range(deviceCount): - if getFreeRatio(i)<70: + if getFreeRatio(i) < 70: available.append(i) gpus = '' for g in available: - gpus = gpus+str(g)+',' + gpus = gpus + str(g) + ',' gpus = gpus[:-1] return gpus + def setgpu(gpuinput): freeids = getFreeId() - if gpuinput=='all': + if gpuinput == 'all': gpus = freeids else: gpus = gpuinput if any([g not in freeids for g in gpus.split(',')]): - raise ValueError('gpu'+g+'is being used') - print('using gpu '+gpus) - os.environ['CUDA_VISIBLE_DEVICES']=gpus + raise ValueError('gpu' + g + 'is being used') + print('using gpu ' + gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = gpus return len(gpus.split(',')) + class Logger(object): - def __init__(self,logfile): + def __init__(self, logfile): self.terminal = sys.stdout self.log = open(logfile, "a") def write(self, message): self.terminal.write(message) - self.log.write(message) + self.log.write(message) def flush(self): - #this flush method is needed for python 3 compatibility. - #this handles the flush command by doing nothing. - #you might want to specify some extra behavior here. - pass + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass -def split4(data, max_stride, margin): +def split4(data, max_stride, margin): splits = [] data = torch.Tensor.numpy(data) - _,c, z, h, w = data.shape - - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - pad = int(np.ceil(float(z)/max_stride)*max_stride)-z - leftpad = pad/2 - pad = [[0,0],[0,0],[leftpad,pad-leftpad],[0,0],[0,0]] - data = np.pad(data,pad,'constant',constant_values=-1) + _, c, z, h, w = data.shape + + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + pad = int(np.ceil(float(z) / max_stride) * max_stride) - z + leftpad = pad / 2 + pad = [[0, 0], [0, 0], [leftpad, pad - leftpad], [0, 0], [0, 0]] + data = np.pad(data, pad, 'constant', constant_values=-1) data = torch.from_numpy(data) splits.append(data[:, :, :, :h_width, :w_width]) splits.append(data[:, :, :, :h_width, -w_width:]) splits.append(data[:, :, :, -h_width:, :w_width]) splits.append(data[:, :, :, -h_width:, -w_width:]) - + return torch.cat(splits, 0) + def combine4(output, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( splits[0].shape[0], h, @@ -101,37 +108,36 @@ def combine4(output, h, w): return output -def split8(data, max_stride, margin): + +def split8(data, max_stride, margin): splits = [] if isinstance(data, np.ndarray): c, z, h, w = data.shape else: - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: if isinstance(data, np.ndarray): splits.append(data[np.newaxis, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) else: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - if isinstance(data, np.ndarray): return np.concatenate(splits, 0) else: return torch.cat(splits, 0) - def combine8(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -139,41 +145,42 @@ def combine8(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 2 h_width = h / 2 w_width = w / 2 i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz in [[0, z_width], [z_width - z, None]]: + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split16(data, max_stride, margin): +def split16(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - h_width = np.ceil(float(h / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 2 + margin)/max_stride).astype('int')*max_stride - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[-h_width,None]]: - for ww in [[0,w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + h_width = np.ceil(float(h / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 2 + margin) / max_stride).astype('int') * max_stride + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [-h_width, None]]: + for ww in [[0, w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine16(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -181,45 +188,47 @@ def combine16(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = z / 4 h_width = h / 2 w_width = w / 2 - splitzstart = splits[0].shape[0]/2-z_width/2 - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - for hh in [[0,h_width],[h_width-h,None]]: - for ww in [[0,w_width],[w_width-w,None]]: - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + for hh in [[0, h_width], [h_width - h, None]]: + for ww in [[0, w_width], [w_width - w, None]]: + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], + :, :] + i = i + 1 + return output -def split32(data, max_stride, margin): + +def split32(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 2 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 2 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine32(splits, z, h, w): - output = np.zeros(( z, h, @@ -227,56 +236,58 @@ def combine32(splits, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 2)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz in [[0,z_width],[z_width-z,None]]: - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - - return output + for zz in [[0, z_width], [z_width - z, None]]: + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz[0]:zz[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 -def split64(data, max_stride, margin): + return output + + +def split64(data, max_stride, margin): splits = [] - _,c, z, h, w = data.size() - - z_width = np.ceil(float(z / 4 + margin)/max_stride).astype('int')*max_stride - w_width = np.ceil(float(w / 4 + margin)/max_stride).astype('int')*max_stride - h_width = np.ceil(float(h / 4 + margin)/max_stride).astype('int')*max_stride - - z_pos = [z*3/8-z_width/2, - z*5/8-z_width/2] - w_pos = [w*3/8-w_width/2, - w*5/8-w_width/2] - h_pos = [h*3/8-h_width/2, - h*5/8-h_width/2] - - for zz in [[0,z_width],[z_pos[0],z_pos[0]+z_width],[z_pos[1],z_pos[1]+z_width],[-z_width,None]]: - for hh in [[0,h_width],[h_pos[0],h_pos[0]+h_width],[h_pos[1],h_pos[1]+h_width],[-h_width,None]]: - for ww in [[0,w_width],[w_pos[0],w_pos[0]+w_width],[w_pos[1],w_pos[1]+w_width],[-w_width,None]]: + _, c, z, h, w = data.size() + + z_width = np.ceil(float(z / 4 + margin) / max_stride).astype('int') * max_stride + w_width = np.ceil(float(w / 4 + margin) / max_stride).astype('int') * max_stride + h_width = np.ceil(float(h / 4 + margin) / max_stride).astype('int') * max_stride + + z_pos = [z * 3 / 8 - z_width / 2, + z * 5 / 8 - z_width / 2] + w_pos = [w * 3 / 8 - w_width / 2, + w * 5 / 8 - w_width / 2] + h_pos = [h * 3 / 8 - h_width / 2, + h * 5 / 8 - h_width / 2] + + for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width], [z_pos[1], z_pos[1] + z_width], [-z_width, None]]: + for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width], [h_pos[1], h_pos[1] + h_width], [-h_width, None]]: + for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width], [w_pos[1], w_pos[1] + w_width], [-w_width, None]]: splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1]]) - + return torch.cat(splits, 0) + def combine64(output, z, h, w): splits = [] for i in range(len(output)): splits.append(output[i]) - + output = np.zeros(( z, h, @@ -284,25 +295,28 @@ def combine64(output, z, h, w): splits[0].shape[3], splits[0].shape[4]), np.float32) - z_width = int(np.ceil(float(z) / 4)) h_width = int(np.ceil(float(h) / 4)) w_width = int(np.ceil(float(w) / 4)) - splitzstart = splits[0].shape[0]/2-z_width/2 - splithstart = splits[0].shape[1]/2-h_width/2 - splitwstart = splits[0].shape[2]/2-w_width/2 - + splitzstart = splits[0].shape[0] / 2 - z_width / 2 + splithstart = splits[0].shape[1] / 2 - h_width / 2 + splitwstart = splits[0].shape[2] / 2 - w_width / 2 + i = 0 - for zz,zz2 in zip([[0,z_width],[z_width,z_width*2],[z_width*2,z_width*3],[z_width*3-z,None]], - [[0,z_width],[splitzstart,z_width+splitzstart],[splitzstart,z_width+splitzstart],[z_width*3-z,None]]): - - for hh,hh2 in zip([[0,h_width],[h_width,h_width*2],[h_width*2,h_width*3],[h_width*3-h,None]], - [[0,h_width],[splithstart,h_width+splithstart],[splithstart,h_width+splithstart],[h_width*3-h,None]]): - - for ww,ww2 in zip([[0,w_width],[w_width,w_width*2],[w_width*2,w_width*3],[w_width*3-w,None]], - [[0,w_width],[splitwstart,w_width+splitwstart],[splitwstart,w_width+splitwstart],[w_width*3-w,None]]): - - output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :] - i = i+1 - + for zz, zz2 in zip([[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3], [z_width * 3 - z, None]], + [[0, z_width], [splitzstart, z_width + splitzstart], [splitzstart, z_width + splitzstart], + [z_width * 3 - z, None]]): + + for hh, hh2 in zip([[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3], [h_width * 3 - h, None]], + [[0, h_width], [splithstart, h_width + splithstart], [splithstart, h_width + splithstart], + [h_width * 3 - h, None]]): + + for ww, ww2 in zip( + [[0, w_width], [w_width, w_width * 2], [w_width * 2, w_width * 3], [w_width * 3 - w, None]], + [[0, w_width], [splitwstart, w_width + splitwstart], [splitwstart, w_width + splitwstart], + [w_width * 3 - w, None]]): + output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[i][zz2[0]:zz2[1], hh2[0]:hh2[1], + ww2[0]:ww2[1], :, :] + i = i + 1 + return output