Skip to content

Commit ee4e977

Browse files
committed
[Python][ML] Adapt tutorials to the new RDataLoader interface
1 parent ed0aca3 commit ee4e977

4 files changed

Lines changed: 14 additions & 28 deletions

File tree

tutorials/machine_learning/ml_dataloader_NumPy.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,16 @@
1414
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
1515

1616
batch_size = 128
17-
chunk_size = 5000
18-
block_size = 400
1917

2018
rdataframe = ROOT.RDataFrame(tree_name, file_name)
2119

2220
target = "Type"
2321

2422
num_of_epochs = 2
2523

26-
gen_train, gen_validation = ROOT.Experimental.ML.CreateNumPyGenerators(
24+
gen_train, gen_validation = ROOT.Experimental.ML.RDataLoader(
2725
rdataframe,
2826
batch_size,
29-
chunk_size,
30-
block_size,
3127
target=target,
3228
validation_split=0.3,
3329
shuffle=True,
@@ -36,9 +32,9 @@
3632

3733
for i in range(num_of_epochs):
3834
# Loop through training set
39-
for i, (x_train, y_train) in enumerate(gen_train):
35+
for i, (x_train, y_train) in enumerate(gen_train.AsNumpy()):
4036
print(f"Training batch {i + 1} => x: {x_train.shape}, y: {y_train.shape}")
4137

4238
# Loop through Validation set
43-
for i, (x_validation, y_validation) in enumerate(gen_validation):
39+
for i, (x_validation, y_validation) in enumerate(gen_validation.AsNumpy()):
4440
print(f"Validation batch {i + 1} => x: {x_validation.shape}, y: {y_validation.shape}")

tutorials/machine_learning/ml_dataloader_PyTorch.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,16 @@
1515
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
1616

1717
batch_size = 128
18-
chunk_size = 5000
19-
block_size = 300
2018

2119
rdataframe = ROOT.RDataFrame(tree_name, file_name)
2220

2321
target = "Type"
2422

2523
# Returns two generators that return training and validation batches
2624
# as PyTorch tensors.
27-
gen_train, gen_validation = ROOT.Experimental.ML.CreatePyTorchGenerators(
25+
gen_train, gen_validation = ROOT.Experimental.ML.RDataLoader(
2826
rdataframe,
2927
batch_size,
30-
chunk_size,
31-
block_size,
3228
target=target,
3329
validation_split=0.3,
3430
shuffle=True,
@@ -64,7 +60,7 @@ def calc_accuracy(targets, pred):
6460
print("Epoch ", i)
6561
model.train()
6662
# Loop through the training set and train model
67-
for i, (x_train, y_train) in enumerate(gen_train):
63+
for i, (x_train, y_train) in enumerate(gen_train.AsTorch()):
6864
# Make prediction and calculate loss
6965
pred = model(x_train)
7066
loss = loss_fn(pred, y_train)
@@ -85,7 +81,7 @@ def calc_accuracy(targets, pred):
8581

8682
model.eval()
8783
# Evaluate the model on the validation set
88-
for i, (x_val, y_val) in enumerate(gen_validation):
84+
for i, (x_val, y_val) in enumerate(gen_validation.AsTorch()):
8985
# Make prediction and calculate accuracy
9086
pred = model(x_val)
9187
accuracy = calc_accuracy(y_val, pred)

tutorials/machine_learning/ml_dataloader_TensorFlow.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
1919

2020
batch_size = 128
21-
chunk_size = 5000
22-
block_size = 300
21+
approx_batches_in_memory = 50
2322

2423
rdataframe = ROOT.RDataFrame(tree_name, file_name)
2524
target = ["Type"]
2625

2726
# Returns two TF.Dataset for training and validation batches.
28-
ds_train, ds_valid = ROOT.Experimental.ML.CreateTFDatasets(
27+
ds_train, ds_valid = ROOT.Experimental.ML.RDataLoader(
2928
rdataframe,
3029
batch_size,
31-
chunk_size,
32-
block_size,
30+
approx_batches_in_memory,
3331
target=target,
3432
validation_split=0.3,
3533
shuffle=True,
@@ -39,8 +37,8 @@
3937
num_of_epochs = 2
4038

4139
# Datasets have to be repeated as many times as there are epochs
42-
ds_train_repeated = ds_train.repeat(num_of_epochs)
43-
ds_valid_repeated = ds_valid.repeat(num_of_epochs)
40+
ds_train_repeated = ds_train.AsTensorFlow().repeat(num_of_epochs)
41+
ds_valid_repeated = ds_valid.AsTensorFlow().repeat(num_of_epochs)
4442

4543
# Number of batches per epoch must be given for model.fit
4644
train_batches_per_epoch = ds_train.number_of_batches

tutorials/machine_learning/ml_dataloader_filters_vectors.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
tree_name = "test_tree"
1717
file_name = ROOT.gROOT.GetTutorialDir().Data() + "/machine_learning/ml_dataloader_filters_vectors_hvector.root"
1818

19-
chunk_size = 50 # Defines the size of the chunks
2019
batch_size = 5 # Defines the size of the returned batches
21-
block_size = 10 # Defines the size of the blocks that builds up a chunk
2220

2321
rdataframe = ROOT.RDataFrame(tree_name, file_name)
2422

@@ -29,20 +27,18 @@
2927

3028
max_vec_sizes = {"f4": 3, "f5": 2, "f6": 1}
3129

32-
ds_train, ds_validation = ROOT.Experimental.ML.CreateNumPyGenerators(
30+
ds_train, ds_validation = ROOT.Experimental.ML.RDataLoader(
3331
filteredrdf,
3432
batch_size,
35-
chunk_size,
36-
block_size,
3733
validation_split=0.3,
3834
max_vec_sizes=max_vec_sizes,
3935
shuffle=False,
4036
)
4137

4238
print(f"Columns: {ds_train.columns}")
4339

44-
for i, b in enumerate(ds_train):
40+
for i, b in enumerate(ds_train.AsNumpy()):
4541
print(f"Training batch {i} => {b.shape}")
4642

47-
for i, b in enumerate(ds_validation):
43+
for i, b in enumerate(ds_validation.AsNumpy()):
4844
print(f"Validation batch {i} => {b.shape}")

0 commit comments

Comments
 (0)