import os
import gc
import math
import glob
import logging
import numpy as np
import pandas as pd
import json
from typing import Union, List, Tuple, Dict, Any
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import tensorflow as tf
from qkeras import quantized_bits
from utils import * # provides safe_remove_directory, among other helpers
# @tf.function
def QKeras_data_prep_quantizer(data, bits=4, int_bits=0, alpha=1):
"""
Applies QKeras quantization.
Args:
data (tf.Tensor): Input data (tf.Tensor).
bits (int): Number of bits for quantization.
int_bits (int): Number of integer bits.
alpha (float): (don't change)
Returns::
tf.Tensor: Quantized data (tf.Tensor).
"""
quantizer = quantized_bits(bits, int_bits, alpha=alpha)
return quantizer(data)
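# Illustrative sketch (the exact grid depends on the installed QKeras
# version and its rounding mode): quantized_bits(4, 0, alpha=1) maps values
# onto a 4-bit fixed-point grid in [-1, 0.875] with step 2**-3 = 0.125, e.g.
#   QKeras_data_prep_quantizer(tf.constant([0.10, 0.49, -0.30]))
#   # -> approximately [0.125, 0.5, -0.25]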
# Custom tensorflow bucketize function
def tf_bucketize(x, boundaries, side='right'):
"""
TensorFlow analogue of np.digitize / torch.bucketize, built on tf.searchsorted.
Args:
x: tf.Tensor of any shape
boundaries: 1D tf.Tensor or list of boundaries
side: 'left' or 'right' (like np.searchsorted)
Returns:
tf.Tensor of same shape as x with integer bucket indices
"""
boundaries = tf.cast(tf.convert_to_tensor(boundaries), dtype=x.dtype)
original_shape = tf.shape(x)
x_flat = tf.reshape(x, [-1])
bucket_indices = tf.searchsorted(boundaries, x_flat, side=side)
bucket_indices = tf.clip_by_value(bucket_indices, 0, tf.size(boundaries))
return tf.reshape(bucket_indices, original_shape)
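# Example of the bucket semantics: with boundaries [400., 1000., 2000.] and
# side='right', a value equal to a boundary falls into the upper bucket:
#   tf_bucketize(tf.constant([100., 400., 1500., 2500.]), [400., 1000., 2000.])
#   # -> [0, 1, 2, 3]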
class OptimizedDataGenerator(tf.keras.utils.Sequence):
def __init__(self,
dataset_base_dir: str = "./",
batch_size: int = 32,
optimize_batch_size: bool = False,
file_count = None,
labels_list: Union[List,str] = ['x-midplane','y-midplane','cotAlpha','cotBeta'],
to_standardize: bool = False,
log_compression: bool = False,
input_shape: Tuple = (20, 13, 21), # (n_time, height, width); creator mode unpacks three dims
transpose = None,
files_from_end = False,
shuffle=False,
# Added in the optimized data generator
load_from_tfrecords_dir: str = None,
tfrecords_dir: str = None,
use_time_stamps = -1,
select_contained = False, # If True, keep only clusters with chargeOriginal_atEdge < 50
noise = -1, # Gaussian noise as a (mu, sigma) tuple; set to -1 to disable
seed: int = None,
quantize: bool = False,
digitize: bool = False, # for the manual 2bit mapping
digitize_levels: Union[List[float], np.ndarray] = None,
digitize_thresholds: Union[List[float], np.ndarray] = None,
max_workers: int = 1,
label_scale_pctl: float = 99,
norm_pos_pctl: float = 99.7,
norm_neg_pctl: float = 99.7,
tail_tol: float = 0.75,
labels_scale = None,
**kwargs,
):
super().__init__()
self.shuffle = shuffle
self.seed = seed if seed is not None else 13
if shuffle:
self.rng = np.random.default_rng(seed = self.seed)
if load_from_tfrecords_dir is None:
# CREATOR MODE
n_time, height, width = input_shape
if use_time_stamps == -1:
use_time_stamps = list(np.arange(0,20))
assert len(use_time_stamps) == n_time, f"Expected {n_time} time steps, got {len(use_time_stamps)}"
len_xy = height * width
col_indices = [
np.arange(t * len_xy, (t + 1) * len_xy).astype(str)
for t in use_time_stamps
]
self.recon_cols = np.concatenate(col_indices).tolist()
self.max_workers = max_workers
self.label_scale_pctl = label_scale_pctl
self.norm_pos_pctl = norm_pos_pctl
self.norm_neg_pctl = norm_neg_pctl
self.files = sorted(glob.glob(os.path.join(dataset_base_dir, "part.*.parquet"), recursive=False))
if file_count is not None:
if not files_from_end:
self.files = self.files[:file_count]
else:
self.files = self.files[-file_count:]
self.file_offsets = [0]
self.dataset_mean = None
self.dataset_std = None
self.norm_factor_pos = None
self.norm_factor_neg = None
self.labels_scale = labels_scale
self.labels_list = labels_list
self.input_shape = input_shape
self.transpose = transpose
self.to_standardize = to_standardize
self.log_compression = log_compression
self.select_contained = select_contained
self.process_file_parallel()
if optimize_batch_size:
original_bs = batch_size
new_bs, residual = self.get_best_batch_size(self.file_offsets, original_bs)
if new_bs != original_bs:
print(f"Batch size optimized from {original_bs} to {new_bs} "
f"to minimize final batch (residual: {residual} rows).")
self.batch_size = new_bs
else:
self.batch_size = batch_size
self.tail_tol = tail_tol
self.batch_metadata = self.build_batch_metadata(
batch_size=self.batch_size,
file_offsets=self.file_offsets,
tail_tol=self.tail_tol
)
self.current_file_index = None
self.current_dataframes = None
if tfrecords_dir is None:
raise ValueError(f"tfrecords_dir is None")
safe_remove_directory(tfrecords_dir)
self.tfrecords_dir = tfrecords_dir
os.makedirs(self.tfrecords_dir, exist_ok=True)
self.save_batches_sequentially()
del self.current_dataframes
metadata_file_path = os.path.join(self.tfrecords_dir, "metadata.json")
self.save_metadata(metadata_file_path)
load_from_tfrecords_dir = self.tfrecords_dir
# LOADER MODE (also reached after creator mode, via the freshly written TFRecords)
self.file_offsets = [None]
if not os.path.isdir(load_from_tfrecords_dir):
raise ValueError(f"Directory {load_from_tfrecords_dir} does not exist.")
self.tfrecords_dir = load_from_tfrecords_dir
metadata_file_path = os.path.join(self.tfrecords_dir, "metadata.json")
self.load_metadata(metadata_file_path)
# NOTE: np.sort is lexicographic, so batch_10 sorts before batch_2; each
# TFRecord holds one self-contained batch, so only the index -> file
# mapping (not the data itself) depends on this ordering.
self.tfrecord_filenames = np.sort(np.array(tf.io.gfile.glob(os.path.join(self.tfrecords_dir, "*.tfrecord"))))
self.quantize = quantize
self.noise = noise
# manual 2-bit digitization (FOR LOADING TFRecords ONLY!)
self.digitize = digitize
self.digitize_levels = np.array(digitize_levels if digitize_levels is not None
else [0.0, 1.0, 2.0, 3.0], dtype=np.float32)
self.digitize_thresholds = np.array(digitize_thresholds if digitize_thresholds is not None
else [400, 1000, 2000], dtype=np.float32)
# Ensure that if digitize is True, we must load from TFRecords
assert not (self.digitize and load_from_tfrecords_dir is None), \
"digitize=True requires load_from_tfrecords_dir to be specified"
# boundaries length = levels-1
assert len(self.digitize_thresholds) == len(self.digitize_levels)-1, \
"Number of boundaries must be one less than number of levels"
self.epoch_count = 0
self.on_epoch_end()
def save_metadata(self, metadata_file_path:str):
"""
Saves the metadata of the dataset to a JSON file.
Args:
metadata_file_path (str): Path to save the metadata file.
"""
metadata = {
# Key configurations
"batch_size": self.batch_size,
"input_shape": self.input_shape,
"recon_cols": self.recon_cols,
"labels_list": self.labels_list,
"to_standardize": self.to_standardize,
"log_compression": self.log_compression,
"transpose": self.transpose,
"shuffle": self.shuffle,
"select_contained": self.select_contained,
"seed": self.seed,
"label_scale_pctl": self.label_scale_pctl,
"norm_pos_pctl": self.norm_pos_pctl,
"norm_neg_pctl": self.norm_neg_pctl,
"tail_tol": self.tail_tol,
# Calculated statistics
"dataset_mean": self.dataset_mean.tolist() if self.dataset_mean is not None else None,
"dataset_std": self.dataset_std.tolist() if self.dataset_std is not None else None,
"dataset_min": np.float64(self.dataset_min) if self.dataset_min is not None else None,
"dataset_max": np.float64(self.dataset_max) if self.dataset_max is not None else None,
"norm_factor_pos": self.norm_factor_pos,
"norm_factor_neg": self.norm_factor_neg,
"labels_scale": self.labels_scale.tolist() if self.labels_scale is not None else None,
# Full batch plan
"batch_metadata": self.batch_metadata
}
with open(metadata_file_path, "w") as f:
json.dump(metadata, f, indent=4)
print(f"Metadata saved successfully ast {metadata_file_path}")
def load_metadata(self, metadata_file_path:str):
"""
Loads the metadata of the dataset from a JSON file.
Args:
metadata_file_path (str): Path to the metadata file.
"""
if not os.path.exists(metadata_file_path):
raise FileNotFoundError(f"Metadata file {metadata_file_path} does not exist.\n"
"Cannot initialize generator in load mode.")
print(f"Loading metadata from {metadata_file_path}")
with open(metadata_file_path, "r") as f:
metadata = json.load(f)
# Key configurations
self.batch_size = metadata['batch_size']
self.input_shape = tuple(metadata['input_shape'])
self.recon_cols = metadata['recon_cols']
self.labels_list = metadata['labels_list']
self.to_standardize = metadata['to_standardize']
self.log_compression = metadata['log_compression']
self.select_contained = metadata['select_contained']
self.label_scale_pctl = metadata['label_scale_pctl']
self.norm_pos_pctl = metadata['norm_pos_pctl']
self.norm_neg_pctl = metadata['norm_neg_pctl']
self.tail_tol = metadata['tail_tol']
# Calculated statistics
self.dataset_mean = np.array(metadata['dataset_mean'])
self.dataset_std = np.array(metadata['dataset_std'])
self.dataset_min = np.float64(metadata['dataset_min'])
self.dataset_max = np.float64(metadata['dataset_max'])
self.norm_factor_pos = metadata['norm_factor_pos']
self.norm_factor_neg = metadata['norm_factor_neg']
self.labels_scale = np.array(metadata['labels_scale'])
# Full batch plan
self.batch_metadata = metadata['batch_metadata']
# Optional parameters
self.shuffle = metadata.get('shuffle', False)
self.seed = metadata.get('seed', 13)
self.transpose = metadata.get('transpose', None)
if self.shuffle:
self.rng = np.random.default_rng(seed=self.seed)
def process_file_parallel(self):
file_infos = [(afile,
self.recon_cols, self.labels_list, self.select_contained,
self.log_compression, self.label_scale_pctl, self.norm_pos_pctl, self.norm_neg_pctl, self.labels_scale)
for afile in self.files
]
results = []
with ProcessPoolExecutor(self.max_workers) as executor:
futures = [executor.submit(self._process_file_single, file_info) for file_info in file_infos]
for future in tqdm(as_completed(futures), total=len(file_infos), desc="Processing Files..."):
results.append(future.result())
manual_labels_scale = False
if self.labels_scale is not None:
manual_labels_scale = True
for amean, avariance, amin, amax, num_rows, labels_scale, pos_scale, neg_scale in results:
self.file_offsets.append(self.file_offsets[-1] + num_rows)
if self.dataset_mean is None:
self.dataset_max = amax
self.dataset_min = amin
self.dataset_mean = amean
self.dataset_std = avariance
else:
self.dataset_max = max(self.dataset_max, amax)
self.dataset_min = min(self.dataset_min, amin)
self.dataset_mean += amean
self.dataset_std += avariance
if self.labels_scale is None:
self.labels_scale = labels_scale
elif not manual_labels_scale:
self.labels_scale = np.maximum(self.labels_scale, labels_scale)
self.norm_factor_pos = (pos_scale if self.norm_factor_pos is None
else max(self.norm_factor_pos, pos_scale))
self.norm_factor_neg = (neg_scale if self.norm_factor_neg is None
else max(self.norm_factor_neg, neg_scale))
# Average the per-file statistics. Note: averaging per-file variances
# is an approximation; it ignores the spread of the per-file means.
self.dataset_mean = self.dataset_mean / len(self.files)
self.dataset_std = np.sqrt(self.dataset_std / len(self.files))
self.file_offsets = np.array(self.file_offsets)
@staticmethod
def _process_file_single(file_info):
afile, recon_cols, labels_list, select_contained, log_compression, label_scale_pctl, norm_pos_pctl, norm_neg_pctl, custom_labels_scale = file_info
if select_contained:
df = (pd.read_parquet(afile,
columns=recon_cols + labels_list +['chargeOriginal_atEdge'])
.reset_index(drop=True))
df = df.loc[df['chargeOriginal_atEdge'] < 50]
else:
df = (pd.read_parquet(afile,
columns=recon_cols + labels_list)
.reset_index(drop=True))
x = df[recon_cols].values
manual_labels_scale = False
if custom_labels_scale is not None:
manual_labels_scale = True
nonzeros = abs(x) > 0
if log_compression:
x[nonzeros] = np.sign(x[nonzeros]) * np.log1p(abs(x[nonzeros])) / math.log(2)
amean, avariance = np.mean(x[nonzeros], keepdims=True), np.var(x[nonzeros], keepdims=True) + 1e-10
centered = np.zeros_like(x)
centered[nonzeros] = (x[nonzeros] - amean) / np.sqrt(avariance)
amin, amax = np.min(centered), np.max(centered)
pos_vals = np.abs(centered[centered > 0])
neg_vals = np.abs(centered[centered < 0])
pos_scale = (np.percentile(pos_vals, norm_pos_pctl)
if pos_vals.size else 1.0)
neg_scale = (np.percentile(neg_vals, norm_neg_pctl)
if neg_vals.size else 1.0)
len_adf = len(df)
labels_values = df[labels_list].values
if not manual_labels_scale:
labels_scale = np.percentile(np.abs(labels_values), label_scale_pctl, axis=0)
else:
labels_scale = custom_labels_scale
del df
gc.collect()
return amean, avariance, amin, amax, len_adf, labels_scale, pos_scale, neg_scale
def standardize(self, x):
"""
Applies the normalization configuration in-place to a batch of inputs.
`x` is changed in-place since the function is mainly used internally
to standardize images and feed them to your network.
Args:
x: Batch of inputs to be normalized.
Returns:
The inputs, normalized.
"""
out = (x - self.dataset_mean)/self.dataset_std
out[out > 0] = out[out > 0]/self.norm_factor_pos
out[out < 0] = out[out < 0]/self.norm_factor_neg
out = np.clip(out, self.dataset_min, self.dataset_max)
return out
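# Numeric sketch of the asymmetric scaling: with dataset_mean=0,
# dataset_std=1, norm_factor_pos=4 and norm_factor_neg=2, an input of +8
# becomes 2.0 and an input of -8 becomes -4.0, before the final clip to
# [dataset_min, dataset_max].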
def save_batches_sequentially(self):
num_batches = self.__len__()
errors_found = []
for i in tqdm(range(num_batches), desc="Saving batches as TFRecords"):
result = self.save_single_batch(i)
if "Error" in result:
print(result)
errors_found.append(result)
if errors_found:
logging.warning(f"Encountered {len(errors_found)} errors during sequential saving of TFRecords.")
else:
logging.info("All batches saved successfully in sequential mode.")
def save_single_batch(self, batch_index):
"""
Serializes and saves a single batch to a TFRecord file.
Args:
batch_index (int): Index of the batch to save.
Returns:
str: Path to the saved TFRecord file or an error message.
"""
try:
filename = f"batch_{batch_index}.tfrecord"
TFRfile_path = os.path.join(self.tfrecords_dir, filename)
X, y = self.prepare_batch_data(batch_index)
serialized_example = self.serialize_example(X, y)
with tf.io.TFRecordWriter(TFRfile_path) as writer:
writer.write(serialized_example)
return TFRfile_path
except Exception as e:
return f"Error saving batch {batch_index}: {e}"
@staticmethod
def get_best_batch_size(file_offsets, target_bs=5000):
"""
Find the best batch size that minimizes the residual when dividing the total number of rows.
Args:
file_offsets (np.ndarray): Array of file offsets.
target_bs (int): Target batch size.
tol (float): Tolerance for batch size deviation.
Returns:
int: Best batch size.
"""
last_offset = file_offsets[-1]
d_bs = int(0.5 * target_bs)
batch_sizes = np.arange(target_bs - d_bs, target_bs + d_bs + 1)
residuals = last_offset % batch_sizes
min_res = residuals.min()
# All bs giving the minimal residual
candidates = batch_sizes[residuals == min_res]
# Prefer the one closest to the target
idx = np.argmin(np.abs(candidates - target_bs))
return int(candidates[idx]), min_res
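# Worked example: with 10050 total rows and target_bs=100, batch sizes
# 50..150 are scanned; 50, 67, 75, 134 and 150 all divide 10050 exactly,
# and 75 is the zero-residual candidate closest to the target:
#   get_best_batch_size(np.array([0, 10050]), target_bs=100) # -> (75, 0)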
@staticmethod
def _build_batching_plan(file_offsets, batch_size, tol = 0.75):
"""
Pre-compute (row_start, row_end) for every batch.
If the last batch < 0.5xbatch_size, merge the last two
and split them evenly, so both new batches are within
0.5x...1.0xbatch_size.
"""
total = file_offsets[-1]
b = batch_size
plan = []
start = 0
while start < total:
end = min(start + b, total)
plan.append((start, end))
start = end
# Re-balance if the tail is too short
if len(plan) >= 2:
last_len = plan[-1][1] - plan[-1][0]
if last_len < tol * b:
sec_start = plan[-2][0]
comb_len = plan[-1][1] - sec_start
half = math.ceil(comb_len / 2)
plan[-2] = (sec_start, sec_start + half)
plan[-1] = (sec_start + half, sec_start + comb_len)
return plan
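# Example: 1000 rows with batch_size=300 and tol=0.75 first yields
# [(0, 300), (300, 600), (600, 900), (900, 1000)]; the 100-row tail is
# below 225 = 0.75 * 300, so the last two batches are merged and re-split
# into (600, 800) and (800, 1000).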
@classmethod
def build_batch_metadata(cls, batch_size: int, file_offsets: np.ndarray, tail_tol: float = 0.75) -> List[Dict[str, Any]]:
"""
Builds optimized batch metadata using a pre-computed batch plan.
This ensures that the final batch is not excessively small.
"""
batching_plan = cls._build_batching_plan(file_offsets, batch_size, tail_tol)
batch_metadata = []
# Loop through the pre-computed plan instead of a simple range
for batch_index, (start_evt, end_evt) in enumerate(batching_plan):
# Create a new dictionary for the current batch
current_batch_meta = {
"batch_idx": batch_index,
"target_batch_size": int(batch_size),
# The actual size is now simply the difference from the plan
"actual_batch_size": int(end_evt - start_evt),
"segments": []
}
# Find the file segments covering the [start_evt, end_evt) range
file_idx = np.searchsorted(file_offsets, start_evt, side="right") - 1
evt_cursor = start_evt
while evt_cursor < end_evt:
file_start = file_offsets[file_idx]
file_end = file_offsets[file_idx + 1]
rel_start = evt_cursor - file_start
rel_end = min(end_evt, file_end) - file_start
# Append segment info to the current batch's metadata
current_batch_meta["segments"].append({
"file_idx": int(file_idx),
"row_start": int(rel_start),
"row_end": int(rel_end - 1)
})
evt_cursor += (rel_end - rel_start)
file_idx += 1
batch_metadata.append(current_batch_meta)
return batch_metadata
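# Segment-layout sketch: with file_offsets [0, 700, 1500] and batch_size
# 500, batch 1 covers global rows 500..999 and spans two files, recorded as
#   {"file_idx": 0, "row_start": 500, "row_end": 699}
#   {"file_idx": 1, "row_start": 0, "row_end": 299}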
def prepare_batch_data(self, batch_index):
batch_plan = self.batch_metadata[batch_index]
X_chunks = []
y_chunks = []
for segment in batch_plan["segments"]:
file_idx = segment["file_idx"]
rel_start = segment["row_start"]
rel_end = segment["row_end"] + 1 # inclusive end
if file_idx != self.current_file_index:
parquet_file = self.files[file_idx]
if self.select_contained:
all_columns_to_read = self.recon_cols + self.labels_list + ['chargeOriginal_atEdge']
df = (pd.read_parquet(parquet_file,
columns = all_columns_to_read)
.dropna(subset=self.recon_cols)
.reset_index(drop=True))
df = df.loc[df['chargeOriginal_atEdge'] < 50]
else:
all_columns_to_read = self.recon_cols + self.labels_list
df =(pd.read_parquet(parquet_file,
columns = all_columns_to_read)
.dropna(subset=self.recon_cols)
.reset_index(drop=True))
if self.shuffle:
df = df.sample(frac=1, random_state=self.seed).reset_index(drop=True)
recon_df = df[self.recon_cols]
labels_df = df[self.labels_list]
recon_values = recon_df.values
nonzeros = abs(recon_values) > 0
if self.log_compression:
recon_values[nonzeros] = np.sign(recon_values[nonzeros]) * np.log1p(abs(recon_values[nonzeros])) / np.log(2)
if self.to_standardize:
recon_values[nonzeros] = self.standardize(recon_values[nonzeros])
recon_values = recon_values.reshape((-1, *self.input_shape))
if self.transpose is not None:
recon_values = recon_values.transpose(self.transpose)
self.current_dataframes = (
recon_values,
labels_df.values,
)
self.current_file_index = file_idx
del df
gc.collect()
recon_arr, labels_arr = self.current_dataframes # cached NumPy arrays (despite the attribute name)
X_chunk = recon_arr[rel_start:rel_end]
y_chunk = labels_arr[rel_start:rel_end] / self.labels_scale
X_chunks.append(X_chunk)
y_chunks.append(y_chunk)
X = np.concatenate(X_chunks, axis=0)
y = np.concatenate(y_chunks, axis=0)
return X, y
def serialize_example(self, X, y):
"""
Serializes a single example (featuresand labels) to TFRecord format.
Args:
- X: Training data
- y: labelled data
Returns:
- string (serialized TFRecord example).
"""
# X and y are float32 (maybe we can reduce this)
X = tf.cast(X, tf.float32)
y = tf.cast(y, tf.float32)
feature = {
'X': self._bytes_feature(tf.io.serialize_tensor(X)),
'y': self._bytes_feature(tf.io.serialize_tensor(y)),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
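# Round-trip sketch (shapes are illustrative; `gen` is any instance):
#   raw = gen.serialize_example(np.zeros((8, 20, 13, 21)), np.zeros((8, 4)))
#   X, y = OptimizedDataGenerator._parse_tfrecord_fn(raw)
# recovers X with shape (8, 20, 13, 21) and y with shape (8, 4) as float32.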
@staticmethod
def _bytes_feature(value):
"""
Converts a string/byte value into a Tf feature of bytes_list
Args:
- string/byte value
Returns:
- tf.train.Feature object as a bytes_list containing the input value.
"""
if isinstance(value, type(tf.constant(0))): # check if Tf tensor
value = value.numpy()
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def map_to_levels(self, x: tf.Tensor) -> tf.Tensor:
"""
Maps input tensor x to discrete levels using specified charge thresholds and output levels.
"""
boundaries = tf.constant(self.digitize_thresholds, dtype=tf.float32)
levels = tf.constant(self.digitize_levels, dtype=tf.float32)
# Use the custom bucketize
bucket_indices = tf_bucketize(x, boundaries, side='right')
# Clip to valid level indices
bucket_indices = tf.clip_by_value(bucket_indices, 0, len(self.digitize_levels) - 1)
# Map indices to levels
x_quant = tf.gather(levels, bucket_indices)
return x_quant
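# Example with the defaults (thresholds [400, 1000, 2000], levels 0..3):
# charges 350 -> 0.0, 400 -> 1.0, 1500 -> 2.0, 9999 -> 3.0, i.e. a manual
# 2-bit "ADC" applied to the stored charge values.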
def __getitem__(self, batch_index):
"""
Load the batch from a pre-saved TFRecord file instead of processing raw data.
Each file contains exactly one batch.
quantization is done here: Helpful for pretraining without the quantization and the later training with quantized data.
shuffling is also done here.
TODO: prefetching (un-done)
"""
tfrecord_path = self.tfrecord_filenames[batch_index]
raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
parsed_dataset = raw_dataset.map(self._parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
# Get the first (and only) batch from the dataset
try:
X_batch, y_batch = next(iter(parsed_dataset))
except StopIteration:
raise ValueError(f"No data found in TFRecord file: {tfrecord_path}")
X_batch = tf.reshape(X_batch, [-1, *X_batch.shape[1:]])
y_batch = tf.reshape(y_batch, [-1, *y_batch.shape[1:]])
if self.noise != -1: # add noise first, before quantization/digitization
mu, sigma = self.noise
noise_array = np.random.normal(loc=mu, scale=sigma, size=X_batch.shape).astype(np.float32) # match X_batch dtype
X_batch = X_batch + noise_array
if self.quantize:
X_batch = QKeras_data_prep_quantizer(X_batch, bits=4, int_bits=0, alpha=1)
if self.digitize:
X_batch = self.map_to_levels(X_batch)
if self.shuffle:
indices = tf.range(start=0, limit=tf.shape(X_batch)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices, seed=self.seed)
X_batch = tf.gather(X_batch, shuffled_indices)
y_batch = tf.gather(y_batch, shuffled_indices)
del raw_dataset, parsed_dataset
return X_batch, y_batch
@staticmethod
def _parse_tfrecord_fn(example):
"""
Parses a single TFRecord example.
Returns:
- X: as a float32 tensor.
- y: as a float32 tensor.
"""
feature_description = {
'X': tf.io.FixedLenFeature([], tf.string),
'y': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, feature_description)
X = tf.io.parse_tensor(example['X'], out_type=tf.float32)
y = tf.io.parse_tensor(example['y'], out_type=tf.float32)
return X, y
def __len__(self):
"""
Phase-aware length:
during initial TFRecord creation: math on file_offsets
after creation in same process: len(batch_metadata)
when loading existing TFRecords: len(tfrecord_filenames)
"""
# Already have metadata? Fastest answer.
if getattr(self, "batch_metadata", None):
return len(self.batch_metadata)
# still building batches, so compute from source rows.
if len(self.file_offsets) > 1: # have real offsets
total_rows = self.file_offsets[-1]
return math.ceil(total_rows / self.batch_size)
# running in "load" mode.
self.tfrecord_filenames = np.sort(
np.array(tf.io.gfile.glob(
os.path.join(self.tfrecords_dir, "*.tfrecord"))))
return len(self.tfrecord_filenames)
def on_epoch_end(self):
'''
Shuffles the TFRecord file ordering so that batches are loaded in a
different order in each training epoch.
'''
gc.collect()
self.epoch_count += 1
# Log quantization status once
if self.epoch_count == 1:
logging.warning(f"Quantization is {self.quantize} in data generator. This may affect model performance.")
if self.shuffle:
self.rng.shuffle(self.tfrecord_filenames)
self.seed += 1 # so that each epoch shuffles with a different, but deterministic, seed
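

if __name__ == "__main__":
    # Minimal usage sketch; the paths below are hypothetical placeholders.
    # Creator mode: scans part.*.parquet files, computes normalization and
    # label-scale statistics, and writes one TFRecord per batch plus a
    # metadata.json into tfrecords_dir.
    gen = OptimizedDataGenerator(
        dataset_base_dir="./data/train",    # hypothetical parquet directory
        tfrecords_dir="./tfrecords/train",  # hypothetical output directory
        batch_size=5000,
        optimize_batch_size=True,
        input_shape=(20, 13, 21),           # (n_time, height, width)
        labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
        to_standardize=True,
        shuffle=True,
        max_workers=4,
    )
    # Loader mode: later runs reuse the saved TFRecords and metadata,
    # optionally applying 4-bit quantization on the fly.
    loaded = OptimizedDataGenerator(
        load_from_tfrecords_dir="./tfrecords/train",
        quantize=True,
    )
    X, y = loaded[0]
    print(X.shape, y.shape)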