[ML] Refactor the RDataLoader: new cluster-aligned reading pattern, improved shuffling strategy and new Python API #21888
Conversation
Test Results: 21 files, 21 suites, 2d 23h 40m 0s ⏱️. Results for commit 57e9f5e.
vepadulano left a comment:
Thanks for this extensive work! I'm not finished with the full review yet, but here are a couple of first fly-by thoughts.
```diff
  ROOT.EnableThreadSafety()

- self.generator = ROOT.Experimental.Internal.ML.RBatchGenerator(template)(
+ self.data_loader = ROOT.Experimental.Internal.ML.RDataLoader(template)(
```
It's a bit odd to see both ROOT.Experimental.ML.RDataLoader and ROOT.Experimental.Internal.RDataLoader with the same name. While I'm fine with renaming the RBatchGenerator class, since the new functionality goes a bit beyond the initial implementation, I wonder if we should give the internal one a different name?
```diff
  set_seed (int):
      For reproducibility: Set the seed for the random number generator used
-     to split the dataset into training and validation and shuffling of the chunks
+     to split the dataset into training and validation and shuffling of the clusters
```
"Clusters" is a technical term, and a bit of an implementation detail of the ROOT data formats. While I agree it's more correct, I think it might be hard for a general reader to parse. Maybe consider rephrasing to add some context? Or keep "chunks" but also explain that there is an I/O optimization going on? Not sure.
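As background for the `set_seed` docstring being discussed: a single seeded generator driving both the train/validation split and the shuffle is what makes runs reproducible. Here is a generic, self-contained Python sketch of that idea; `split_and_shuffle` is an illustrative stand-in, not the actual RDataLoader implementation.

```python
import random

def split_and_shuffle(n_entries, validation_split, seed):
    """Illustrative only: a seeded shuffle followed by a train/validation
    split, mimicking what a set_seed-style parameter provides.
    Not the real RDataLoader code."""
    rng = random.Random(seed)       # one generator drives everything
    entries = list(range(n_entries))
    rng.shuffle(entries)            # shuffle entry order
    n_val = int(n_entries * validation_split)
    return entries[n_val:], entries[:n_val]  # (train, validation)

# Same seed -> identical split and ordering on every run
run_a = split_and_shuffle(100, 0.2, seed=42)
run_b = split_and_shuffle(100, 0.2, seed=42)
assert run_a == run_b
```

With a fixed seed the split and ordering are deterministic across runs; with different seeds they generally differ.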
```diff
- def CreateTFDatasets(
+ def RDataLoader(
```
This is a nitpick, but I wonder if at this point it wouldn't be better to be consistent with Python naming conventions and make RDataLoader an actual Python class; so far we have been abusing the nomenclature a bit. I can imagine a number of ways to go:

- If we keep this as a function, it shouldn't have the `R` prefix and it should probably start with a verb.
- If we convert this to a class, it could have a class method `Create`, and user code could then be `ds_train, ds_validation = ROOT.ML.RDataLoader.Create(...)`.
- It could be a class created with the same parameters as the current function but with a `__call__` method, and user code could then look like:

```python
# Example without validation_split
dl = ROOT.ML.RDataLoader(...)
# By default calling __call__ has no validation split, returns one value
for x, y in dl().AsTorch():
    ...

# Example with validation_split
dl = ROOT.ML.RDataLoader()
ds_train, ds_validation = dl(validation_split=0.2)
# Continue as now
```

```
    … time. Larger values improve shuffle quality across cluster
    boundaries at the cost of higher memory usage. Acts as a soft
    cap: the buffer may temporarily exceed this by up to one
    cluster's worth of rows. Defaults to 10.
```
Note, not necessarily for this PR: for consistency with other parts of ROOT, we should probably refer to entries of the dataset rather than rows.
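Returning to the API-shape question raised earlier in the review, the `__call__`-based option could be prototyped in plain Python roughly as follows. This is a hypothetical sketch: the class here is a stand-in holding a list of entries, not the real ROOT implementation, and the real constructor parameters are omitted.

```python
class RDataLoader:
    """Hypothetical stand-in illustrating the __call__-based API shape."""

    def __init__(self, entries):
        self.entries = list(entries)

    def __call__(self, validation_split=0.0):
        # By default there is no validation split and one value is returned.
        if validation_split == 0.0:
            return self.entries
        # With a split, return a (train, validation) pair.
        n_val = int(len(self.entries) * validation_split)
        return self.entries[n_val:], self.entries[:n_val]

dl = RDataLoader(range(10))
full = dl()                              # one value without a split
train, val = dl(validation_split=0.2)    # two values with a split
```

The appeal of this shape is that the expensive construction happens once, while cheap calls can produce differently-split views of the same loader.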
```cpp
 */
struct RClusterRange {
   std::size_t rdfIdx; // which rdf this cluster belongs to
   ULong64_t start;    // first raw entry (incl)
```

Given we are already TTree-agnostic, we can use standard integer types here:

```diff
- ULong64_t start; // first raw entry (incl)
+ std::uint64_t start; // first raw entry (incl)
```
This pull request:
This PR refactors the ML DataLoader infrastructure to implement a cluster-aware reading strategy with improved shuffling semantics and a new Python interface.
Data flow & cluster-aligned reading strategy
```mermaid
flowchart TB
    RDF["RDataFrame(s)"] --> SCAN["ScanClusters()<br/>extract cluster boundaries"]
    SCAN --> SPLIT["SplitDataset()<br/>Each cluster split into →<br/>(1-val_split) train + val_split validation"]
    SPLIT --> ACTIVATE["ActivateEpoch()"]
    ACTIVATE --> SHUFFLE{"shuffle?"}
    SHUFFLE -->|Yes| SHUFFLE_ORDER["Shuffle cluster order"] --> LOAD
    SHUFFLE -->|No| LOAD
    subgraph LOAD["Loading Loop (Background Thread)"]
        ACC["Accumulate clusters<br/>until buffer ≈ full"]
        READ["Load into staging buffer<br/>(filters evaluated here)"]
        SHUF["Shuffle rows in buffer"]
        BATCH["Create batches → queue"]
        ACC --> READ --> SHUF --> BATCH
    end
    LOAD --> CONSUMER["Consumer: GetTrainBatch() / GetValidationBatch()"]
    subgraph FILTER["Lazy Filter Discovery"]
        E1["First epoch: discover actual counts"]
        E2["Subsequent: use discovered counts"]
        E1 -.-> E2
    end
    READ -.-> FILTER
```

New Python Interface
The new `batches_in_memory` parameter controls how much data is held in memory at once, affecting both memory usage and shuffle quality.

How it works:

- The staging buffer holds up to `batches_in_memory × batch_size` rows
- A higher `batches_in_memory` gives better cross-cluster mixing, but more memory usage

Note
This implementation supports passing multiple RDataFrames; fixes #21782.
Support for multiple filtered RDataFrames is yet to be added.
TODO
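To make the loading loop from the flowchart and the `batches_in_memory` soft cap concrete, here is a minimal Python sketch under invented assumptions: plain lists of integers stand in for clusters, the generator name is hypothetical, and a final incomplete batch is simply dropped in this sketch. It is not the actual C++ implementation.

```python
import random

def batches(clusters, batch_size, batches_in_memory, rng):
    """Toy loading loop: accumulate whole clusters into a staging buffer
    until it holds roughly batches_in_memory * batch_size entries (a soft
    cap: the last cluster added may overshoot it), shuffle the buffer,
    then emit fixed-size batches."""
    cap = batches_in_memory * batch_size
    buffer = []
    for cluster in clusters:          # clusters are read whole, in order
        buffer.extend(cluster)
        if len(buffer) >= cap:        # buffer ~ full -> shuffle and drain
            rng.shuffle(buffer)
            while len(buffer) >= batch_size:
                yield buffer[:batch_size]
                del buffer[:batch_size]
    rng.shuffle(buffer)               # flush what remains at epoch end
    while len(buffer) >= batch_size:
        yield buffer[:batch_size]
        del buffer[:batch_size]

# Five fake clusters of 100 entries each
clusters = [list(range(i * 100, (i + 1) * 100)) for i in range(5)]
out = list(batches(clusters, batch_size=32, batches_in_memory=3,
                   rng=random.Random(0)))
```

Since the buffer is bounded by `batches_in_memory × batch_size` entries, peak memory scales linearly with that product, while a larger buffer lets entries from more clusters mix within one shuffle.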