
[ML] Refactor the RDataLoader: new cluster-aligned reading pattern, improved shuffling strategy and new Python API #21888

Open
siliataider wants to merge 8 commits into root-project:master from siliataider:rdataloader

Conversation

@siliataider
Contributor

@siliataider siliataider commented Apr 10, 2026

This Pull request:

This PR refactors the ML DataLoader infrastructure to implement a cluster-aware reading strategy with improved shuffling semantics and a new Python interface.

Data flow & cluster-aligned reading strategy

flowchart TB
    RDF["RDataFrame(s)"] --> SCAN["ScanClusters()<br/>extract cluster boundaries"]
    SCAN --> SPLIT["SplitDataset()<br/>Each cluster split into → <br/> (1-val_split) train + val_split validation"]
    
    SPLIT --> ACTIVATE["ActivateEpoch()"]
    ACTIVATE --> SHUFFLE{"shuffle?"}
    SHUFFLE -->|Yes| SHUFFLE_ORDER["Shuffle cluster order"] --> LOAD
    SHUFFLE -->|No| LOAD
    
    subgraph LOAD["Loading Loop (Background Thread)"]
        ACC["Accumulate clusters<br/>until buffer ≈ full"]
        READ["Load into staging buffer<br/>(filters evaluated here)"]
        SHUF["Shuffle rows in buffer"]
        BATCH["Create batches → queue"]
        
        ACC --> READ --> SHUF --> BATCH
    end
    
    LOAD --> CONSUMER["Consumer: GetTrainBatch() / GetValidationBatch()"]
    
    subgraph FILTER["Lazy Filter Discovery"]
        E1["First epoch: discover actual counts"]
        E2["Subsequent: use discovered counts"]
        E1 -.-> E2
    end
    
    READ -.-> FILTER
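
The split and epoch-activation steps in the flowchart can be illustrated with a minimal pure-Python sketch. It does not use ROOT: `split_cluster` and `epoch_order` are hypothetical helper names, the cluster boundaries are invented, and placing the validation entries at the tail of each cluster is an assumption rather than something the PR states.

```python
import random
from typing import List, Tuple

ClusterRange = Tuple[int, int]  # (start, end) entry range, end exclusive


def split_cluster(start: int, end: int, val_split: float) -> Tuple[ClusterRange, ClusterRange]:
    """Split one cluster into a train range and a validation range.

    Assumption: the validation part is taken from the tail of the cluster.
    """
    n_val = int(round((end - start) * val_split))
    cut = end - n_val
    return (start, cut), (cut, end)


def epoch_order(clusters: List[ClusterRange], shuffle: bool, seed: int) -> List[ClusterRange]:
    """Return the cluster visiting order for one epoch.

    Only the order of clusters changes here; rows are shuffled later, per buffer.
    """
    order = list(clusters)
    if shuffle:
        random.Random(seed).shuffle(order)
    return order


# Hypothetical cluster boundaries, as a scan over the input might report them.
clusters = [(0, 100), (100, 250), (250, 300)]
train, val = zip(*(split_cluster(s, e, 0.2) for s, e in clusters))
```

Because every cluster contributes to both sets, the train and validation readers each still touch whole, contiguous entry ranges, which is what keeps the reading pattern cluster-aligned.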

New Python Interface

train, val = ROOT.Experimental.ML.RDataLoader(
    df,
    batch_size=1024,
    batches_in_memory=10,    # Controls buffer size (replaces chunk_size/block_size)
    target="label",
    validation_split=0.2,
)

# 3 possible output formats (like before)
for x, y in train.AsNumpy():
    ...  # train model

for x, y in train.AsTorch():
    ...  # train model

for x, y in train.AsTensorFlow():
    ...  # train model

The new batches_in_memory parameter

The batches_in_memory parameter controls how much data is held in memory at once, affecting both memory usage and shuffle quality.

Buffer Capacity = batch_size × batches_in_memory

How it works:

  • The loading thread accumulates clusters until the buffer reaches approximately batches_in_memory × batch_size rows
  • Rows within this buffer are shuffled together before batching
  • Higher batches_in_memory = better cross-cluster mixing, but more memory usage
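
As a rough illustration of the three points above, here is a pure-Python sketch of such a loading loop. It is not the code in this PR: `iter_batches` is a hypothetical name, it works on entry indices instead of real data, and carrying leftover rows into the next buffer fill is an assumption.

```python
import random
from typing import Iterator, List, Tuple


def iter_batches(
    clusters: List[Tuple[int, int]],
    batch_size: int,
    batches_in_memory: int,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield batches of entry indices from a cluster-aligned, shuffled buffer."""
    capacity = batch_size * batches_in_memory
    rng = random.Random(seed)
    buffer: List[int] = []
    for i, (start, end) in enumerate(clusters):
        # Whole clusters only, so the buffer can overshoot the capacity
        # by up to one cluster's worth of rows (soft cap).
        buffer.extend(range(start, end))
        last = i == len(clusters) - 1
        if len(buffer) < capacity and not last:
            continue  # keep accumulating clusters
        rng.shuffle(buffer)  # mix rows across all buffered clusters
        n_full = len(buffer) // batch_size * batch_size
        stop = len(buffer) if last else n_full
        for j in range(0, stop, batch_size):
            yield buffer[j:j + batch_size]
        buffer = buffer[stop:]  # assumption: leftovers join the next fill


batches = list(iter_batches([(0, 7), (7, 12), (12, 20)], batch_size=4, batches_in_memory=2))
```

With `batches_in_memory=1` and small clusters, rows mostly mix only with rows of their own cluster; raising it lets rows from several clusters land in the same batch, at the price of a larger buffer.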

Note

This implementation supports passing multiple RDataFrames, which fixes #21782.
Support for multiple filtered RDataFrames is yet to be added.

TODO

  • Docs and release notes in a secondary PR

@siliataider siliataider changed the title from "Implement the new Rdataloader (new interface + reading patterns + shuffling strategy)" to "Implement the new RDataLoader (new interface + reading patterns + shuffling strategy)" Apr 10, 2026
@siliataider siliataider self-assigned this Apr 10, 2026
@siliataider siliataider added the in:Python Interface and in:ML (Everything under ROOT/ML) labels Apr 10, 2026
@github-actions

github-actions bot commented Apr 10, 2026

Test Results

    21 files      21 suites   2d 23h 40m 0s ⏱️
 3 833 tests  3 831 ✅  1 💤 1 ❌
73 083 runs  73 064 ✅ 18 💤 1 ❌

For more details on these failures, see this check.

Results for commit 57e9f5e.

♻️ This comment has been updated with latest results.

@siliataider siliataider force-pushed the rdataloader branch 2 times, most recently from ee4e977 to fb3df72 April 13, 2026 10:51
@siliataider siliataider changed the title from "Implement the new RDataLoader (new interface + reading patterns + shuffling strategy)" to "[ML] Refactor the RDataLoader: new cluster-aligned reading pattern, improved shuffling strategy and new Python API" Apr 13, 2026
@siliataider siliataider marked this pull request as ready for review April 13, 2026 11:30
@siliataider siliataider force-pushed the rdataloader branch 2 times, most recently from eecb1b6 to c7699f2 April 14, 2026 10:43
Member

@vepadulano vepadulano left a comment


Thanks for this extensive work! I haven't finished the full review yet, but I'll leave a couple of first fly-by thoughts.

ROOT.EnableThreadSafety()

- self.generator = ROOT.Experimental.Internal.ML.RBatchGenerator(template)(
+ self.data_loader = ROOT.Experimental.Internal.ML.RDataLoader(template)(
Member

It's a bit weird to see both ROOT.Experimental.ML.RDataLoader and ROOT.Experimental.Internal.ML.RDataLoader with the same name. While I'm fine with changing the name of the RBatchGenerator class, since the new functionality goes a bit beyond the initial implementation, I wonder if we should give it a different name?

set_seed (int):
For reproducibility: Set the seed for the random number generator used
- to split the dataset into training and validation and shuffling of the chunks
+ to split the dataset into training and validation and shuffling of the clusters
Member

clusters is a technical term, a bit of an implementation detail of ROOT data formats. While I agree it's more correct, I think it might be hard for a general reader to parse. Maybe consider rephrasing to add some context? Or keep chunks but also explain that there is an I/O optimization going on? Not sure



- def CreateTFDatasets(
+ def RDataLoader(
Member

This is a nitpick, but I wonder if at this point it wouldn't be better to be consistent with Python naming conventions and have RDataLoader be an actual Python class. So far we have been abusing the nomenclature a bit. I can imagine we could go in a number of different ways:

  • If we keep this as a function, it shouldn't have the R prefix and it should probably start with a verb.
  • If we convert this to a class, it could have a class method Create and then a user code could be
ds_train, ds_validation = ROOT.ML.RDataLoader.Create(...)
  • It could be a class created with the same parameters as the current function but have a __call__ method and then user code could look like
# Example without validation_split
dl = ROOT.ML.RDataLoader(...)
# by default calling __call__ has no validation split, returns one value
for x, y in dl().AsTorch():
    ...  # train model

# Example with validation_split
dl = ROOT.ML.RDataLoader(...)
ds_train, ds_validation = dl(validation_split=0.2)
# Continue as now
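
To make the last two options concrete, here is a small pure-Python mock of how a callable class with an extra `Create` factory could behave. `DataLoaderSketch` and its `n_entries` parameter are hypothetical stand-ins, plain `range` objects take the place of real datasets, and none of this is the ROOT API.

```python
class DataLoaderSketch:
    """Hypothetical stand-in for the class-based options; not the ROOT API."""

    def __init__(self, n_entries: int, batch_size: int = 1024) -> None:
        self.n_entries = n_entries
        self.batch_size = batch_size

    def __call__(self, validation_split: float = 0.0):
        """Return one dataset, or a (train, validation) pair if a split is requested."""
        n_val = int(self.n_entries * validation_split)
        train = range(0, self.n_entries - n_val)
        if validation_split == 0.0:
            return train
        return train, range(self.n_entries - n_val, self.n_entries)

    @classmethod
    def Create(cls, n_entries: int, validation_split: float = 0.0, **kwargs):
        """Factory variant: build the loader and apply the split in one call."""
        return cls(n_entries, **kwargs)(validation_split)


dl = DataLoaderSketch(100)
single = dl()                                   # no split: one value
ds_train, ds_validation = dl(validation_split=0.2)
```

One thing the mock makes visible is that the `__call__` variant returns either one object or a pair depending on its argument, which may be worth weighing against the simpler `Create` form.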

time. Larger values improve shuffle quality across cluster
boundaries at the cost of higher memory usage. Acts as a soft
cap: the buffer may temporarily exceed this by up to one
cluster's worth of rows. Defaults to 10.
Member

Note, not necessarily for this PR: for consistency with other parts of ROOT, we should probably refer to entries of the dataset rather than rows.

*/
struct RClusterRange {
std::size_t rdfIdx; // which rdf this cluster belongs to
ULong64_t start; // first raw entry (incl)
Member

Given we are already TTree-agnostic, we can use standard integer types here

Suggested change
- ULong64_t start; // first raw entry (incl)
+ std::uint64_t start; // first raw entry (incl)


Labels

in:ML (Everything under ROOT/ML), in:Python Interface

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ML] ROOT.Experimental.ML.CreatePyTorchGenerators only uses the first dataframe

2 participants