Support Parallel Data Loading Shufflable Iterable Datasets/DataStreams #100
Open
alex-jw-brooks wants to merge 6 commits into caikit:main from
Conversation
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
73e7af8 to 1392991
alex-jw-brooks (Collaborator, Author) commented:
Minimal example showing shuffling (run from a venv on Linux):

```python
import torch

from caikit.core.data_model import DataStream
from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper

SAMPLE_DATA = [{"label": "a"}, {"label": "b"}, {"label": "c"}, {"label": "d"}]
SAMPLE_STREAM = DataStream.from_iterable(SAMPLE_DATA)

wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=True)
torch_loader = torch.utils.data.DataLoader(
    wrapper,
    num_workers=2,
    persistent_workers=True,  # Needed, otherwise every iteration shuffles the same way!
)

for epoch in range(3):
    for idx, x in enumerate(torch_loader):
        print(x)
    print("Finished iteration: {}".format(epoch))
```

Sample output:

In some preliminary benchmarking I did, this is unfortunately slower than running with no worker processes, at least for the way we handle tokenizer mapping onto train streams in prompt tuning (on the order of 2-3x slower). While a bit of a bummer, this is a generic utility for data streams, and it may still be beneficial for re-entrant streams with hefty per-iteration costs, e.g. loading from files.
alex-jw-brooks (Collaborator, Author) commented:
There are some other potential optimizations that could be made around this, but they do break the genericism a bit; it might be better to get this change in first and handle the optimizations as a follow-up. The two main ones I can think of are:
```python
import caikit

s = caikit.core.data_model.DataStream.from_iterable([1])

def map_func(example):
    # Printed 10 times, since every iteration calls this again upon reentry
    print(f"Called the map func on example: {example}")
    return example + 1

mapped_s = s.map(map_func)
for _ in range(10):
    for x in mapped_s:
        pass
```
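Just to illustrate the kind of trade-off involved (this is only a sketch of the general idea, not necessarily what a follow-up would look like): a mapped stream could be materialized once up front so the map function is only paid a single time, at the cost of holding the results in memory and giving up the lazy, generic stream behavior.

```python
import caikit

s = caikit.core.data_model.DataStream.from_iterable([1, 2, 3])

def map_func(example):
    print(f"Called the map func on example: {example}")  # now printed once per example
    return example + 1

# Materialize the mapped results a single time, then wrap them back into a
# stream; re-entering this stream no longer re-runs map_func.
materialized = caikit.core.data_model.DataStream.from_iterable(
    [map_func(x) for x in s]
)

for _ in range(10):
    for x in materialized:
        pass
```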
This PR adds support for multiple workers when processing an iterable dataset, in such a way that the data is divided across the workers without duplication and can still be shuffled consistently between epochs.
The caveat to this is that, for the shuffling to work correctly, we need to use persistent_workers=True when creating our data loader. This is accomplished by defining a shuffle_seed, which is essentially a random seed that gets incremented every time we cycle through our data. It is used as the random seed when creating the shuffled stream generator; the workers must be persistent, otherwise the shuffle_seed will get reset with every iteration, but this approach lets us shuffle consistently across workers without them communicating. Then, to divide the data, we create an iterator yielding every nth item of the preprocessed stream (which has been shuffled by this point) given n workers, with an offset based on the worker ID.
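For intuition, that division scheme looks roughly like the sketch below. This is not the literal code in the PR; the function name is illustrative, the sketch materializes the stream into a list for simplicity (the wrapper presumably operates on the DataStream itself), and it uses torch.utils.data.get_worker_info() to look up the worker ID and worker count.

```python
import random
from itertools import islice

from torch.utils.data import get_worker_info

def iter_worker_shard(stream, shuffle_seed):
    """Yield this worker's share of the (shuffled) stream.

    Every worker shuffles with the same shuffle_seed, so they all see the
    same order; each worker then keeps every nth item, offset by its ID.
    """
    items = list(stream)
    random.Random(shuffle_seed).shuffle(items)

    info = get_worker_info()
    if info is None:  # num_workers=0 -> single-process loading
        worker_id, num_workers = 0, 1
    else:
        worker_id, num_workers = info.id, info.num_workers

    # Every nth item with an offset based on the worker ID, so shards never overlap
    yield from islice(items, worker_id, None, num_workers)

# With persistent_workers=True, the wrapper can bump shuffle_seed after each
# full pass so the next epoch reshuffles; non-persistent workers are rebuilt
# every epoch, which resets the seed and repeats the same order.
```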
Also adds docstrings to the stream wrapper & caches the stream length, since len() is an expensive operation for the data stream.

Closes: #74
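For reference, the length caching amounts to something like this sketch (class and attribute names here are illustrative, not the wrapper's actual ones):

```python
class LengthCachingStream:
    """Illustrative sketch: compute the expensive stream length only once."""

    def __init__(self, stream):
        self.stream = stream
        self._length = None

    def __len__(self):
        # len() on a DataStream walks the entire stream, so cache the result
        if self._length is None:
            self._length = len(self.stream)
        return self._length
```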