-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimulation.py
More file actions
474 lines (397 loc) · 21.2 KB
/
simulation.py
File metadata and controls
474 lines (397 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
import os
import sys
import datetime
import numpy as np
import h5py
# from mpi4py import MPI
import xraylib as xlib
import xraylib_np as xlib_np
import torch as tc
tc.set_default_dtype(tc.float32) # Set the default tensor dtype
import time
from util import (
prepare_fl_lines,
intersecting_length_fl_detectorlet,
ATOMIC_NUMBERS,
rotate_3d,
rotate_3d_inplane,
)
from misc import print_flush_root, create_summary
from forward_model import PPM
import warnings
from nodeology.node import as_node
import json
from tqdm import tqdm
import tifffile
# Silence third-party warning chatter (e.g. xraylib / torch deprecations).
warnings.filterwarnings("ignore")
# Check environment variable
# When 'true' (the default), simulate_XRF_maps is wrapped as a nodeology
# workflow node; set USE_SIMULATION_WORKFLOW=false to call it as a plain
# function outside the workflow framework.
USE_WORKFLOW = os.getenv('USE_SIMULATION_WORKFLOW', 'true').lower() == 'true'
def conditional_simulation_decorator(func):
    """Wrap *func* as a nodeology node when the simulation workflow is enabled.

    When ``USE_WORKFLOW`` is False the function is returned untouched so it can
    be invoked directly, outside the workflow framework.
    """
    if not USE_WORKFLOW:
        return func
    node_factory = as_node(sink=["sim_XRF_file", "sim_XRT_file"])
    return node_factory(func)
@conditional_simulation_decorator
def simulate_XRF_maps(params):
    """Simulate XRF (fluorescence) and XRT (transmission) maps of a 3D object.

    Loads a ground-truth elemental concentration volume from ``.npy``,
    optionally rotates it, and pushes it through the PPM forward model batch
    by batch. Results are written to two HDF5 files (plus TIFF previews).

    Parameters
    ----------
    params : str
        JSON-encoded dict. Required keys: ``ground_truth_file``,
        ``probe_energy``, ``incident_probe_intensity``,
        ``model_probe_attenuation``, ``model_self_absorption``, ``elements``,
        ``sample_size_cm``, ``det_size_cm``, ``det_from_sample_cm``,
        ``det_ds_spacing_cm``. Optional keys: ``suffix`` (str, default ''),
        ``debug`` (bool, default False), ``rotation_angles`` ([x, y, z] in
        degrees, default no rotation), ``gpu_id`` (int, default 3; a negative
        value forces CPU).

    Returns
    -------
    tuple[str, str]
        ``(sim_XRF_file, sim_XRT_file)``: paths of the written HDF5 files.
    """
    params_dict = json.loads(params)
    ground_truth_file = params_dict["ground_truth_file"]
    probe_energy = np.array(params_dict["probe_energy"])
    incident_probe_intensity = params_dict["incident_probe_intensity"]
    model_probe_attenuation = params_dict["model_probe_attenuation"]
    model_self_absorption = params_dict["model_self_absorption"]
    elements = params_dict["elements"]
    sample_size_cm = params_dict["sample_size_cm"]
    det_size_cm = params_dict["det_size_cm"]
    det_from_sample_cm = params_dict["det_from_sample_cm"]
    det_ds_spacing_cm = params_dict["det_ds_spacing_cm"]
    debug = params_dict.get("debug", False)

    overwrite_P = False  # set True to force recomputation of the P array
    # gpu_id used to be hard-coded to 3; keep 3 as the default for backward
    # compatibility, but allow callers to override it (negative -> CPU).
    gpu_id = params_dict.get("gpu_id", 3)
    if tc.cuda.is_available() and gpu_id >= 0:
        dev = tc.device("cuda:{}".format(gpu_id))
    else:
        dev = "cpu"
    if debug:
        print(f"Device: {dev}")

    ####----------------------------------------------------------------####
    #### load true 3D objects                                           ####
    X = np.load(ground_truth_file)
    if debug:
        print(f"Test objects size: {X.shape}")
    X = tc.from_numpy(X).float().to(dev)
    # sample_size_n: lateral size (x and y axes); sample_height_n: z axis
    # (axis naming per Figure 5.1 of the accompanying write-up).
    sample_size_n = X.shape[1]
    sample_height_n = X.shape[3]
    batch_size = sample_height_n
    if debug:
        print(f"batch_size: {batch_size}")
    # Number of voxels along the diagonal (0,0,0) -> (nx, ny, nz), +20% margin.
    dia_len_n = int(
        1.2 * (sample_height_n**2 + sample_size_n**2 + sample_size_n**2) ** 0.5
    )
    n_voxel_batch = batch_size * sample_size_n  # voxels in each batch
    n_voxel = sample_height_n * sample_size_n**2  # total number of voxels

    ####----------------------------------------------------------------####
    #### parallelization (single rank; MPI support currently disabled)  ####
    n_ranks = 1
    rank = 0
    minibatch_ls_0 = tc.arange(n_ranks).to(dev)
    n_batch = (sample_height_n * sample_size_n) // (n_ranks * batch_size)
    if debug:
        print(f'Number of batches: {n_batch}')

    ####----------------------------------------------------------------####
    #### physical constants for the input elements and energy           ####
    if debug:
        print(f'Elements: {elements}')
    n_element = len(elements)
    if debug:
        print(f'n_element: {n_element}')
    # Local lookup table instead of the mendeleev package.
    atomic_numbers = {name: ATOMIC_NUMBERS[name] for name in elements}
    if debug:
        print(f'Atomic_numbers: {atomic_numbers}')
    # Only K emission lines are simulated.
    element_lines_roi = np.array([[element, 'K'] for element in elements])
    if debug:
        print(f'Element_lines_roi: {element_lines_roi}')
    # Number of FL line groups contributed by each element.
    n_line_group_each_element = np.array(
        [np.sum(element_lines_roi[:, 0] == name) for name in elements]
    )
    if debug:
        print(f'n_line_group_each_element: {n_line_group_each_element}')
    # Lookup table of the fluorescence lines of interest (built on CPU).
    fl_all_lines_dic = prepare_fl_lines(
        element_lines_roi,
        n_line_group_each_element,
        probe_energy,
        sample_size_n,
        sample_size_cm,
    )
    if debug:
        print(f'fl_all_lines_dic done')
    detected_fl_unit_concentration = tc.as_tensor(
        fl_all_lines_dic["detected_fl_unit_concentration"]
    ).float().to(dev)
    # Mass attenuation cross section for each XRF line (3rd row, Table 5.3.1).
    mass_attenuation_cross_section_FL = tc.as_tensor(
        xlib_np.CS_Total(
            np.array(list(atomic_numbers.values())),
            fl_all_lines_dic["fl_energy"]
        )
    ).float().to(dev)
    n_line_group_each_element = tc.IntTensor(
        fl_all_lines_dic["n_line_group_each_element"]
    ).to(dev)
    if debug:
        print(n_line_group_each_element)
    n_lines = fl_all_lines_dic["n_lines"]  # scalar
    if debug:
        print(f'Total number of energy lines (n_lines)={n_lines}')
    # Channel names for the output HDF5 (K lines keep the bare element name).
    channel_name_roi_ls = np.array([
        element_line_roi[0] if element_line_roi[1] == "K"
        else f"{element_line_roi[0]}_{element_line_roi[1]}"
        for element_line_roi in element_lines_roi
    ]).astype("S5")
    # Mass attenuation cross section of the probe (2nd row, Table 5.3.1).
    probe_attCS_ls = tc.as_tensor(
        xlib_np.CS_Total(
            np.array(list(atomic_numbers.values())), np.array([probe_energy])
        ).flatten()
    ).to(dev)
    # det_solid_angle_ratio scales the generated fluorescence photon count to
    # account for the detector's limited solid angle and efficiency when the
    # incident intensity is not calibrated with an axo_std file. Estimated
    # from the flat detector surface (curved-cap estimate was the alternative).
    det_solid_angle_ratio = (np.pi * (det_size_cm / 2) ** 2) / (
        4 * np.pi * det_from_sample_cm**2
    )
    # Extra XRF attenuation beyond solid angle and self-absorption.
    signal_attenuation_factor = 1.0

    ####----------------------------------------------------------------####
    #### get P array (voxel -> detectorlet intersecting lengths)        ####
    output_dir = os.path.join(os.path.dirname(ground_truth_file),
        f'det_size{det_size_cm}_spacing_{det_ds_spacing_cm}_dist{det_from_sample_cm}_sample_size{sample_size_cm}_nxy{sample_size_n}_nz{sample_height_n}')
    os.makedirs(output_dir, exist_ok=True)
    # Compute P once and cache it as an .h5 file next to the input volume.
    if not os.path.isfile(f'{output_dir}/P_array.h5') or overwrite_P:
        print(f'Calculating the intersecting length array P. This will take quite some time...')
        intersecting_length_fl_detectorlet(n_ranks, rank,
                                           det_size_cm, det_from_sample_cm, det_ds_spacing_cm,
                                           sample_size_n, sample_size_cm, sample_height_n,
                                           output_dir, 'P_array')  # has to use CPU for this step
        if debug:
            print(f'Completed. P is saved at {output_dir}/P_array.h5')
    else:
        if debug:
            print(f'Loading an existing intersecting length array P from {output_dir}/P_array.h5')
    P_handle = h5py.File(f'{output_dir}/P_array.h5', 'r')

    ####----------------------------------------------------------------####
    #### I/O                                                            ####
    stdout_options = {'root': 0, 'output_folder': output_dir, 'save_stdout': False, 'print_terminal': True}
    timestr = str(datetime.datetime.today())
    if debug:
        print_flush_root(0, val=f"time: {timestr}", output_file='', **stdout_options)
    # Rotation angles parameter; default [0, 0, 0] means no rotation.
    rotation_angles = params_dict.get('rotation_angles', [0.0, 0.0, 0.0])
    # Encode rotation angles into the file-name suffix ('.' -> 'p').
    rotation_str = '_rot' + '_'.join(
        [f"{angle:.1f}".replace('.', 'p') for angle in rotation_angles]
    )
    # Determine suffix based on model options.
    suffix = params_dict.get('suffix', '')
    suffix += rotation_str
    if model_probe_attenuation:
        suffix += '_pa'
    if model_self_absorption:
        suffix += '_sa'
    # rotation_angles as a tensor on the target device for later comparisons.
    if not isinstance(rotation_angles, tc.Tensor):
        rotation_angles = tc.tensor(rotation_angles, device=dev)
    # Output file names with the assembled suffix.
    sim_XRF_file = f'{output_dir}/sim_xrf_E{probe_energy}{suffix}.h5'
    sim_XRT_file = f'{output_dir}/sim_xrt_E{probe_energy}{suffix}.h5'
    # initialize h5 files for saving simulated signals
    with h5py.File(sim_XRF_file, 'w') as d:
        grp = d.create_group("exchange")
        grp.create_dataset("data", shape=(n_lines, sample_height_n, sample_size_n), dtype="f4")
        grp.create_dataset("elements", data=channel_name_roi_ls)
    with h5py.File(sim_XRT_file, 'w') as d:
        grp = d.create_group("exchange")
        grp.create_dataset("data", shape=(4, sample_height_n, sample_size_n), dtype="f4")

    ####----------------------------------------------------------------####
    #### simulation                                                     ####
    start_time_total = datetime.datetime.now()
    # Rotate the 3D object if any non-zero angle was requested.
    if not tc.all(rotation_angles == 0):
        print(f"Rotating object with rotation angles (degrees): {rotation_angles.cpu().numpy()}")
        X_rot = X.clone()
        if rotation_angles[0] != 0:  # X-axis rotation (YZ plane)
            X_rot = rotate_3d_inplane(X_rot, rotation_angles[0], dev, use_degrees=True, axis='x')
        if rotation_angles[1] != 0:  # Y-axis rotation
            X_rot = rotate_3d_inplane(X_rot, rotation_angles[1], dev, use_degrees=True, axis='y')
        if rotation_angles[2] != 0:  # Z-axis rotation
            X_rot = rotate_3d_inplane(X_rot, rotation_angles[2], dev, use_degrees=True, axis='z')
        X = X_rot.clone()
        print(f'X shape: {X.shape}')
        if debug:
            for i in range(n_element):
                tifffile.imwrite(
                    f'{output_dir}/X_{elements[i]}_rot'
                    + '_'.join([f"{angle:.1f}".replace('.', 'p') for angle in rotation_angles])
                    + '.tiff',
                    X[i].cpu().numpy())
    # lac (linear attenuation coefficient), Eq. 5.9, computed from the
    # (already rotated) X; dimension [n_element, n_lines, n_voxel_minibatch, n_voxel].
    if model_self_absorption == True:
        lac = X.view(n_element, 1, 1, n_voxel) * mass_attenuation_cross_section_FL.view(n_element, n_lines, 1, 1)
        lac = lac.expand(-1, -1, n_voxel_batch, -1).float()
    else:
        lac = 0.
    # Progress bar unless debugging (debug prints would break the bar).
    batch_iterator = range(n_batch)
    if not debug:
        batch_iterator = tqdm(batch_iterator, desc="Processing Batches")
    for m in batch_iterator:
        start_time = datetime.datetime.now()
        minibatch_ls = n_ranks * m + minibatch_ls_0  # e.g. [5,6,7,8]
        p = minibatch_ls[rank]
        if model_self_absorption == True:
            if debug:
                print(f"\nDebug info for batch {m+1}:")
                print(f"X shape: {X.shape}")
                print(f"minibatch_ls: {minibatch_ls}")
                print(f"p: {p}")
            # Slice of P for this minibatch; verify indices are within bounds.
            max_index = dia_len_n * batch_size * sample_size_n
            start_idx = p * max_index
            end_idx = (p + 1) * max_index
            if debug:
                print(f"Accessing P_array indices: [{start_idx}:{end_idx}]")
                print(f"P_array shape: {P_handle['P_array'].shape}")
            if start_idx >= P_handle['P_array'].shape[2] or end_idx > P_handle['P_array'].shape[2]:
                if debug:
                    print(f"WARNING: Index out of bounds!")
                raise ValueError(f"Batch {m+1}: Index out of bounds in P_array access")
            P_minibatch = tc.from_numpy(P_handle['P_array'][:, :, start_idx:end_idx]).to(dev)
            n_det = P_minibatch.shape[0]
            if debug:
                print(f"P_minibatch shape: {P_minibatch.shape}")
                print(f"n_det: {n_det}")
        else:
            P_minibatch = 0
            n_det = 0
        try:
            if debug:
                print(f"\nModel input debug for batch {m+1}:")
                if not tc.all(rotation_angles == 0):
                    print(f"Using 3D rotation with angles: {rotation_angles.cpu().numpy()}")
                print(f"lac shape: {lac.shape if isinstance(lac, tc.Tensor) else 'scalar 0'}")
                print(f"X shape: {X.shape}")
                print(f"n_element: {n_element}")
                print(f"n_lines: {n_lines}")
                print(f"mass_attenuation_cross_section_FL shape: {mass_attenuation_cross_section_FL.shape}")
                print(f"detected_fl_unit_concentration shape: {detected_fl_unit_concentration.shape}")
                print(f"n_line_group_each_element: {n_line_group_each_element}")
                print(f"sample_height_n: {sample_height_n}")
                print(f"batch_size: {batch_size}")
                print(f"sample_size_n: {sample_size_n}")
            # Verify tensor device consistency before running the model.
            devices = set()
            for tensor in [X, mass_attenuation_cross_section_FL, detected_fl_unit_concentration]:
                if isinstance(tensor, tc.Tensor):
                    devices.add(tensor.device)
            if debug:
                print(f"Tensor devices: {devices}")
            if len(devices) > 1:
                raise ValueError(f"Inconsistent tensor devices found: {devices}")
            if debug and model_self_absorption:
                print(f"Attempting with self-absorption for batch {m+1}")
            # The PPM call is identical with and without self-absorption; the
            # flag/lac/P_minibatch arguments already encode the difference.
            model = PPM(dev, model_self_absorption, lac, X, p, n_element, n_lines, mass_attenuation_cross_section_FL,
                        detected_fl_unit_concentration, n_line_group_each_element,
                        sample_height_n, batch_size, sample_size_n, sample_size_cm,
                        probe_energy, incident_probe_intensity, model_probe_attenuation, probe_attCS_ls,
                        0, signal_attenuation_factor,
                        n_det, P_minibatch, det_size_cm, det_from_sample_cm, det_solid_angle_ratio)
            # Synchronize (CUDA only — calling tc.cuda.synchronize() on a CPU
            # run raises) to surface async launch errors at a known point.
            if model_self_absorption and dev != "cpu":
                tc.cuda.synchronize()
            y1_hat, y2_hat = model()
            if model_self_absorption and dev != "cpu":
                tc.cuda.synchronize()
        except Exception as e:
            # BUGFIX: the original inner handler swallowed model errors and
            # fell through to use y1_hat/y2_hat, which are undefined on the
            # first batch (NameError) or stale from a previous batch (silently
            # wrong output). Skip the failed batch instead.
            if debug:
                print(f"\nError in batch {m+1}: {str(e)}")
                print("Skipping this batch and continuing...")
            try:
                tc.cuda.empty_cache()
            except Exception:
                pass
            continue
        xrf_data = np.clip(y1_hat.detach().cpu().numpy(), 0, np.inf)
        xrt_data = np.exp(-y2_hat.detach().cpu().numpy())
        # Write this minibatch's rows into the output HDF5 files.
        with h5py.File(sim_XRF_file, 'r+') as d:
            # shape of d["exchange/data"] = (n_lines, sample_height_n, sample_size_n)
            d["exchange/data"][:, batch_size * p // sample_size_n: batch_size * (p + 1) // sample_size_n, :] = \
                np.reshape(xrf_data, (n_lines, batch_size // sample_size_n, -1))
        with h5py.File(sim_XRT_file, 'r+') as d:
            # shape of d["exchange/data"] = (4, sample_height_n, sample_size_n);
            # row 3 holds the transmission signal.
            d["exchange/data"][3, batch_size * p // sample_size_n: batch_size * (p + 1) // sample_size_n, :] = \
                np.reshape(xrt_data, (batch_size // sample_size_n, -1))
        iteration_time = datetime.datetime.now() - start_time
        if debug:
            print(f"Batch {m + 1}/{n_batch} time cost: {iteration_time}")
    total_time = datetime.datetime.now() - start_time_total
    if debug:
        print(f"Total forward simulation time cost: {total_time}")
    # Free the (potentially huge) lac tensor before post-processing.
    del lac
    tc.cuda.empty_cache()
    # Important: close the hdf5 file handle at the end of the simulation.
    P_handle.close()
    # Save xrf and xrt data as TIFF previews.
    with h5py.File(sim_XRF_file, 'r') as XRF_data_handle:
        xrf_data = XRF_data_handle['exchange/data'][:]
    for i in range(n_element):
        tifffile.imwrite(f'{output_dir}/sim_xrf_E{probe_energy}{suffix}_{elements[i]}.tif', xrf_data[i])
    with h5py.File(sim_XRT_file, 'r') as XRT_data_handle:
        xrt_data = XRT_data_handle['exchange/data'][:]
    tifffile.imwrite(f'{output_dir}/sim_xrt_E{probe_energy}{suffix}.tif', xrt_data[-1])
    if debug:
        print("simulation done")
        print(f"sim_XRF_file: {sim_XRF_file}")
        print(f"sim_XRT_file: {sim_XRT_file}")
    return sim_XRF_file, sim_XRT_file