-
Notifications
You must be signed in to change notification settings - Fork 1.5k
optimize streaming #9425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
optimize streaming #9425
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,10 +1,63 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from swift.utils import to_device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class _TensorMeta: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Sentinel replacing a tensor in the schema, carrying metadata for buffer allocation.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __slots__ = ('idx', 'shape', 'dtype') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, idx, shape, dtype): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.idx = idx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.shape = shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dtype = dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _flatten_for_scatter(obj, tensors): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Recursively separate tensors from a nested structure. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Tensors are appended to `tensors` and replaced by _TensorMeta sentinels. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The returned schema is lightweight and can be pickled efficiently. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if torch.is_tensor(obj): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx = len(tensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensors.append(obj) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _TensorMeta(idx, tuple(obj.shape), obj.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(obj, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(obj, (tuple, list)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Jintao-Huang marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return obj | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+19
to
+34
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the input data contains
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _unflatten_from_scatter(schema, tensors): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Reconstruct the original nested structure from schema and flat tensors list.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(schema, _TensorMeta): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensors[schema.idx] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(schema, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(schema, (tuple, list)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+41
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Preserve the original dictionary subclass and support
Suggested change
Jintao-Huang marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return schema | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+37
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _collect_tensor_metas(schema, metas): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Collect _TensorMeta from schema in DFS order (same order as flatten).""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(schema, _TensorMeta): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metas.append(schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(schema, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for v in schema.values(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _collect_tensor_metas(v, metas) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(schema, (tuple, list)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for v in schema: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _collect_tensor_metas(v, metas) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class DataLoaderDispatcher: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, base_dataloader, device=None, skip_batches: int = 0): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -24,13 +77,98 @@ def world_size(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def group(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return dist.group.WORLD if dist.is_initialized() else 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _scatter_device(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Determine the correct device for dist.scatter based on backend.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| backend = dist.get_backend(self.group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if backend == 'nccl': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.device('cuda', torch.cuda.current_device()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif backend == 'hccl': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.device('npu', torch.npu.current_device()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None # keep tensors on their original device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Jintao-Huang marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _scatter_object_list(self, inputs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Scatter data from rank 0 to all ranks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optimization: separates tensors from non-tensor structure (schema) so that | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schemas are scattered via pickle (lightweight) and tensors are transferred | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| via P2P isend/irecv (efficient NCCL/Gloo tensor transfer, zero padding waste). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Naturally handles variable-size tensors across ranks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not dist.is_initialized(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return inputs[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = [None] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_src_rank = dist.get_global_rank(self.group, 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return outputs[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scatter_device = self._scatter_device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Flatten each rank's data: separate tensors from schema | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schemas = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| per_rank_tensors = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for item in inputs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if item is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schemas.append(None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| per_rank_tensors.append([]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensors = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schema = _flatten_for_scatter(item, tensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schemas.append(schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| per_rank_tensors.append(tensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Scatter lightweight schemas (no tensor payload, fast pickle) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schema_out = [None] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.scatter_object_list(schema_out, schemas, global_src_rank, group=self.group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_schema = schema_out[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Send tensors to other ranks via async P2P | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| handles = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| send_bufs = [] # keep tensors alive until sends complete | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for r in range(1, self.world_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst_rank = dist.get_global_rank(self.group, r) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for t in per_rank_tensors[r]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor = t.contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if scatter_device is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor = tensor.to(scatter_device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| send_bufs.append(tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| handles.append(dist.isend(tensor, dst=dst_rank, group=self.group)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+126
to
+133
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In PyTorch distributed, when a process group is specified in point-to-point communication APIs like
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Rank 0 keeps its own tensors (move to device if needed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_tensors = per_rank_tensors[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if scatter_device is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Wait for all sends to complete | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for h in handles: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h.wait() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Jintao-Huang marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| del send_bufs # safe to release after all sends finished | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Receive schema (lightweight) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| schema_out = [None] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.scatter_object_list(schema_out, None, global_src_rank, group=self.group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_schema = schema_out[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if my_schema is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Receive tensors via async P2P (shape/dtype from _TensorMeta in schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metas = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _collect_tensor_metas(my_schema, metas) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metas.sort(key=lambda m: m.idx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device = scatter_device if scatter_device is not None else 'cpu' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_tensors = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| handles = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for meta in metas: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| my_tensors.append(recv_buf) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+160
to
+163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, for
Suggested change
Comment on lines
101
to
+163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In PyTorch Distributed, when a custom Using global ranks with a non-WORLD group will cause runtime errors or incorrect communication. We should use scatter_device = self._scatter_device
if self.rank == 0:
# Flatten each rank's data: separate tensors from schema
schemas = []
per_rank_tensors = []
for item in inputs:
if item is None:
schemas.append(None)
per_rank_tensors.append([])
else:
tensors = []
schema = _flatten_for_scatter(item, tensors)
schemas.append(schema)
per_rank_tensors.append(tensors)
# Scatter lightweight schemas (no tensor payload, fast pickle)
schema_out = [None]
dist.scatter_object_list(schema_out, schemas, src=0, group=self.group)
my_schema = schema_out[0]
# Send tensors to other ranks via async P2P
handles = []
send_bufs = [] # keep tensors alive until sends complete
for r in range(1, self.world_size):
for t in per_rank_tensors[r]:
tensor = t.contiguous()
if scatter_device is not None:
tensor = tensor.to(scatter_device)
send_bufs.append(tensor)
handles.append(dist.isend(tensor, dst=r, group=self.group))
# Rank 0 keeps its own tensors (move to device if needed)
my_tensors = per_rank_tensors[0]
if scatter_device is not None:
my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]
# Wait for all sends to complete
for h in handles:
h.wait()
del send_bufs # safe to release after all sends finished
else:
# Receive schema (lightweight)
schema_out = [None]
dist.scatter_object_list(schema_out, None, src=0, group=self.group)
my_schema = schema_out[0]
if my_schema is None:
return None
# Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
metas = []
_collect_tensor_metas(my_schema, metas)
metas.sort(key=lambda m: m.idx)
device = scatter_device if scatter_device is not None else 'cpu'
my_tensors = []
handles = []
for meta in metas:
recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
handles.append(dist.irecv(recv_buf, src=0, group=self.group))
my_tensors.append(recv_buf) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Wait for all receives to complete | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for h in handles: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h.wait() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if my_schema is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _unflatten_from_scatter(my_schema, my_tensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _skip_batches(self, base_iter): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.rank == 0 and self.skip_batches > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Preserve the original dictionary subclass (such as Hugging Face's
BatchEncoding) and supportnamedtupleduring flattening. Returning a plaindictforBatchEncodingwill break downstream code that accesses keys as attributes (e.g.,batch.input_ids). Additionally, instantiating anamedtuplewith a generator directly (i.e.,type(obj)(generator)) raises aTypeErrorbecause its__new__expects separate positional arguments.