-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest.py
More file actions
123 lines (113 loc) · 5.57 KB
/
Copy pathtest.py
File metadata and controls
123 lines (113 loc) · 5.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from utils import *
import yaml
import os
import logging
from models.HybridSN import *
import csv
from PIL import Image
# ------------------------------------------Logging Function-----------------------------------------
# Log every run to logs/test.log. filemode="w" truncates the file on each
# run, so it only holds the latest run (replaces the old "delete if exists"
# step). makedirs(exist_ok=True) avoids the exists()/mkdir() race.
# NOTE: logging.basicConfig is a no-op if the root logger already has
# handlers (e.g. configured by one of the star-imported modules).
os.makedirs("logs", exist_ok=True)
logging.basicConfig(filename="logs/test.log", filemode="w",
                    format='%(asctime)s %(levelname)-8s %(message)s',
                    level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
# ------------------------------------------Read Configuration----------------------------------------
# All run settings come from ./configs/test.yml (relative to the current
# working directory). safe_load prevents the YAML from constructing arbitrary
# Python objects. The encoding is explicit so parsing does not depend on the
# OS locale (a UTF-8 file with non-ASCII comments fails under e.g. GBK).
with open('./configs/test.yml', 'r', encoding='utf-8') as stream:
    configs = yaml.safe_load(stream)

# Directory containing the saved model parameters (passed to load_model).
parameter_dir = configs["parameter_dir"]
# Directory of input images; the script collects its 'hdr' files.
input_dir = configs["input_dir"]
# Directory of ground-truth label images; the script collects its 'png' files.
label_dir = configs["label_dir"]
# Output directory; one sub-directory per tested model is created inside it.
prediction_dir = configs["prediction_dir"]
# Mapping model name -> settings; the script reads each entry's
# 'test' (enable flag), 'param' and 'net' keys.
test_models = configs["test_models"]
# Number of components passed to read_process_hdr_image; must match the
# value used in the training config.
pca_components = configs["pca_components"]
# Patch side lengths; inputs are zero-padded by (size - 1) // 2 per side.
patch_size = configs["patch_size"]
batch_size = configs["batch_size"]
# Separate patch size used only for the 'RBF_SVM' model.
SVM_patch_size = configs["SVM_patch_size"]
# Maximum number of images to evaluate.
test_num = configs["test_num"]

os.makedirs(prediction_dir, exist_ok=True)
# -------------------------------------------Read Data and Label-----------------------------------------
# TODO: consider denoising and SG smoothing (presumably Savitzky-Golay) of the input.
# Input image n is paired with label file n purely by list position, so
# generate_file_list presumably returns both lists in a matching (e.g. sorted)
# order — TODO confirm in utils.
data_files = generate_file_list(input_dir, 'hdr')
label_files = generate_file_list(label_dir, "png")
print(len(data_files))
print(len(label_files))
# Explicit check instead of `assert`, which is stripped under `python -O`.
if len(data_files) != len(label_files):
    raise ValueError(
        f"found {len(data_files)} input files in {input_dir!r} but "
        f"{len(label_files)} label files in {label_dir!r}; they must pair one-to-one"
    )
N = len(data_files)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using gpu number ", torch.cuda.device_count())
# Load every model whose config entry has its 'test' flag set, keyed by model
# name. Models are loaded in the order they appear in configs["test_models"].
models = {
    name: load_model(parameter_dir, settings['param'], name, logger, device)
    for name, settings in configs["test_models"].items()
    if settings['test']
}
# Header and row buffer for a per-(file, model) metrics table.
# NOTE(review): nothing in this script appends to csv_rows or writes a CSV
# file yet — the metric computation is still missing.
csv_fields = [
    'File', 'Model',
    'TN', 'FP', 'FN', 'TP',
    'Accuracy', 'Specificity', 'Sensitivity',
]
csv_rows = list()
# Start testing
# For at most `test_num` (input, label) pairs: read and zero-pad the input,
# run every model enabled in the config, and save each raw prediction array
# to <prediction_dir>/<model_name>/output/<label file stem>.npy.
for n, (data_file, label_file) in enumerate(zip(data_files, label_files)):
    if n >= test_num:
        break
    logger.info("Now processing image %s", data_file)
    logger.info("Now processing image %s", label_file)
    # print() does not apply %-formatting, so build the message explicitly.
    print(f"Now processing image {data_file}")
    print(f"Now processing image {label_file}")

    # Read the input (reduced to `pca_components` channels by the helper —
    # presumably via PCA, TODO confirm) and its label, then zero-pad the
    # border by half a patch so edge pixels still get full-size patches.
    data = read_process_hdr_image(data_file, pca_components)
    label = read_process_tif_img(label_file)
    padded_data = padWithZeros(data, (patch_size - 1) // 2)
    logger.info("padded data has shape %s", str(padded_data.shape))
    # The SVM uses its own patch size, so it needs its own padding.
    SVM_padded_data = padWithZeros(data, (SVM_patch_size - 1) // 2)
    logger.info("svm padded data has shape %s", str(SVM_padded_data.shape))

    # Every model's output for this image is saved under the label file's stem.
    input_file_without_ext = os.path.splitext(os.path.basename(label_file))[0]

    for model_name, settings in configs["test_models"].items():
        # Only models whose 'test' flag is set are evaluated.
        if not settings['test']:
            continue
        logger.info("Testing %s", model_name)
        model = models[model_name]

        if model_name == 'RBF_SVM':
            prediction_img_red = padded_img_predict(SVM_padded_data, label, SVM_patch_size, model,
                                                    settings['net'], device, logger, batch_size,
                                                    prob=0)
        else:
            prediction_img_red = padded_img_predict(padded_data, label, patch_size, model,
                                                    settings['net'], device, logger, batch_size)

        # NOTE(review): the statistics named in csv_fields (TN/FP/FN/TP,
        # accuracy, ...) are never computed here and csv_rows is never
        # written — TODO.

        # makedirs(exist_ok=True) also creates "output" when the model
        # directory already exists without it; the old exists()/mkdir()
        # pair skipped that case and np.save then failed.
        output_dir = os.path.join(prediction_dir, model_name, "output")
        os.makedirs(output_dir, exist_ok=True)
        # np.save appends the ".npy" extension to the file name.
        np.save(os.path.join(output_dir, input_file_without_ext), prediction_img_red)