-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsegment.py
More file actions
executable file
·38 lines (35 loc) · 2.08 KB
/
segment.py
File metadata and controls
executable file
·38 lines (35 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python
import torch
import torch.nn as nn
from function import predict_volumes
from model import UNet2d
import os, sys
import argparse
def build_parser():
    """Build the command-line parser for the segmentation script.

    Required arguments are collected in a "required arguments" group that is
    listed in ``--help`` output before the remaining optional arguments.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description='Segmentation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # argparse has no public API for ordering groups: temporarily remove the
    # default optional-arguments group so the required group is listed first,
    # then re-append it after all arguments are registered.
    optional = parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')

    # Required options
    required.add_argument('-in', '--input_t1w', type=str, required=True,
                          help='Input T1w Image for Skull Stripping')
    required.add_argument('-model', '--predict_model', required=True, type=str,
                          help='Predict Model')

    # Optional options
    optional.add_argument('-out', '--out_dir', type=str, help='Output Dir')
    optional.add_argument('-suffix', '--mask_suffix', type=str, default="pre_mask",
                          help='Suffix of Mask')
    optional.add_argument('-class', '--num_class', type=int, default=7,
                          help='Number of Tissue Class for Model Input')
    optional.add_argument('-slice', '--input_slice', type=int, default=3,
                          help='Number of Slice for Model Input')
    optional.add_argument('-conv', '--conv_block', type=int, default=5,
                          help='Number of UNet Block')
    optional.add_argument('-kernel', '--kernel_root', type=int, default=16,
                          help='Number of the Root of Kernel')
    # This help text previously duplicated the '-kernel' description.
    optional.add_argument('-rescale', '--rescale_dim', type=int, default=256,
                          help='Rescale Dimension for Model Input')
    parser._action_groups.append(optional)
    return parser


def main():
    """Parse CLI arguments, load the trained UNet2d checkpoint and run prediction.

    Prints the help text and exits with status 1 when invoked with no
    arguments. Otherwise it builds a ``UNet2d`` from the CLI hyper-parameters
    and loads ``checkpoint['state_dict']`` from ``--predict_model``. It then
    passes the model, wrapped with a ``Softmax2d`` head, to
    ``predict_volumes`` with NIfTI saving enabled.
    """
    parser = build_parser()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    train_model = UNet2d(dim_in=args.input_slice, num_class=args.num_class,
                         num_conv_block=args.conv_block,
                         kernel_root=args.kernel_root)
    # Map every stored tensor to CPU, not only those saved on 'cuda:0', so a
    # checkpoint written on any GPU also loads on a CPU-only/single-GPU host.
    checkpoint = torch.load(args.predict_model, map_location='cpu')
    train_model.load_state_dict(checkpoint['state_dict'])
    # Softmax2d applies softmax across the channel dimension at each spatial
    # location, turning the raw network outputs into normalized scores.
    model = nn.Sequential(train_model, nn.Softmax2d())
    predict_volumes(model, cimg_in=args.input_t1w, bmsk_in=None,
                    rescale_dim=args.rescale_dim, num_class=args.num_class,
                    save_dice=False, save_nii=True, nii_outdir=args.out_dir,
                    suffix=args.mask_suffix)


if __name__ == '__main__':
    main()