Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 141 additions & 3 deletions swift/dataloader/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,63 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.distributed as dist
from tqdm import tqdm

from swift.utils import to_device


class _TensorMeta:
    """Sentinel replacing a tensor in the schema, carrying metadata for buffer allocation.

    Attributes are declared through ``__slots__`` so instances carry no
    per-instance ``__dict__``.
    """
    # idx:   position of the replaced tensor in the flat tensor list
    # shape: shape of the replaced tensor, stored as given by the caller
    # dtype: dtype of the replaced tensor, stored as given by the caller
    __slots__ = ('idx', 'shape', 'dtype')

    def __init__(self, idx, shape, dtype):
        """Store the index, shape and dtype describing one tensor."""
        self.idx = idx
        self.shape = shape
        self.dtype = dtype

    def __repr__(self):
        """Return a readable form so schemas containing sentinels are easy to inspect."""
        return (f'{type(self).__name__}(idx={self.idx!r}, '
                f'shape={self.shape!r}, dtype={self.dtype!r})')


def _flatten_for_scatter(obj, tensors):
"""Recursively separate tensors from a nested structure.

Tensors are appended to `tensors` and replaced by _TensorMeta sentinels.
The returned schema is lightweight and can be pickled efficiently.
"""
if torch.is_tensor(obj):
idx = len(tensors)
tensors.append(obj)
return _TensorMeta(idx, tuple(obj.shape), obj.dtype)
elif isinstance(obj, dict):
return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
elif isinstance(obj, (tuple, list)):
return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
Comment on lines +29 to +32

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Preserve the original dictionary subclass (such as Hugging Face's BatchEncoding) and support namedtuple during flattening. Returning a plain dict for BatchEncoding will break downstream code that accesses keys as attributes (e.g., batch.input_ids). Additionally, instantiating a namedtuple with a generator directly (i.e., type(obj)(generator)) raises a TypeError because its __new__ expects separate positional arguments.

Suggested change
elif isinstance(obj, dict):
return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
elif isinstance(obj, (tuple, list)):
return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
elif isinstance(obj, dict):
return type(obj)({k: _flatten_for_scatter(v, tensors) for k, v in obj.items()})
elif isinstance(obj, (tuple, list)):
if hasattr(obj, '_fields'): # namedtuple
return type(obj)(*(_flatten_for_scatter(v, tensors) for v in obj))
return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)

Comment thread
Jintao-Huang marked this conversation as resolved.
else:
return obj
Comment on lines +19 to +34

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the input data contains namedtuple instances (which are common in PyTorch/HuggingFace ecosystems), calling type(obj)(...) with a generator will fail because namedtuple constructors expect individual field arguments rather than an iterable. Additionally, we should preserve the original dictionary subclass (like BatchEncoding or OrderedDict) by using type(obj)(...) instead of returning a plain dict.

Suggested change
def _flatten_for_scatter(obj, tensors):
"""Recursively separate tensors from a nested structure.
Tensors are appended to `tensors` and replaced by _TensorMeta sentinels.
The returned schema is lightweight and can be pickled efficiently.
"""
if torch.is_tensor(obj):
idx = len(tensors)
tensors.append(obj)
return _TensorMeta(idx, tuple(obj.shape), obj.dtype)
elif isinstance(obj, dict):
return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
elif isinstance(obj, (tuple, list)):
return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
else:
return obj
def _flatten_for_scatter(obj, tensors):
    """Recursively separate tensors from a nested structure.

    Tensors are appended to `tensors` and replaced by _TensorMeta sentinels.
    The returned schema is lightweight and can be pickled efficiently.

    Args:
        obj: A tensor, a dict, a tuple/list (including namedtuples), or any
            other object. Containers are traversed recursively; anything else
            is returned unchanged.
        tensors: List extended in place with every tensor found, in
            traversal order.

    Returns:
        A structure mirroring `obj` in which each tensor is replaced by a
        _TensorMeta holding its index in `tensors`, its shape and its dtype.
        Dict subclasses and namedtuples keep their type when their
        constructor allows it; otherwise a plain dict is returned.
    """
    if torch.is_tensor(obj):
        idx = len(tensors)
        tensors.append(obj)
        return _TensorMeta(idx, tuple(obj.shape), obj.dtype)
    if isinstance(obj, dict):
        # Build the children exactly once: recursing a second time after a
        # failed constructor call would append the same tensors twice.
        flat = {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
        if type(obj) is dict:
            return flat
        try:
            return type(obj)(flat)
        except TypeError:
            # Some dict subclasses do not take a mapping as their first
            # positional argument (e.g. defaultdict expects the factory).
            # Fall back to a plain dict rather than crashing.
            return flat
    if isinstance(obj, (tuple, list)):
        items = [_flatten_for_scatter(v, tensors) for v in obj]
        if hasattr(obj, '_fields'):
            # namedtuple: the constructor wants one positional arg per field
            return type(obj)(*items)
        return type(obj)(items)
    return obj



def _unflatten_from_scatter(schema, tensors):
"""Reconstruct the original nested structure from schema and flat tensors list."""
if isinstance(schema, _TensorMeta):
return tensors[schema.idx]
elif isinstance(schema, dict):
return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
elif isinstance(schema, (tuple, list)):
return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)
Comment on lines +41 to +44

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Preserve the original dictionary subclass and support namedtuple during unflattening to match the types of the original batch containers.

Suggested change
elif isinstance(schema, dict):
return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
elif isinstance(schema, (tuple, list)):
return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)
elif isinstance(schema, dict):
return type(schema)({k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()})
elif isinstance(schema, (tuple, list)):
if hasattr(schema, '_fields'): # namedtuple
return type(schema)(*(_unflatten_from_scatter(v, tensors) for v in schema))
return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)

Comment thread
Jintao-Huang marked this conversation as resolved.
else:
return schema
Comment on lines +37 to +46

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similarly to _flatten_for_scatter, we should handle namedtuple instances and preserve dictionary subclasses when reconstructing the original nested structure.

Suggested change
def _unflatten_from_scatter(schema, tensors):
"""Reconstruct the original nested structure from schema and flat tensors list."""
if isinstance(schema, _TensorMeta):
return tensors[schema.idx]
elif isinstance(schema, dict):
return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
elif isinstance(schema, (tuple, list)):
return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)
else:
return schema
def _unflatten_from_scatter(schema, tensors):
    """Reconstruct the original nested structure from schema and flat tensors list.

    Args:
        schema: Structure in which tensors have been replaced by _TensorMeta
            sentinels. Dicts and tuples/lists (including namedtuples) are
            traversed recursively; any other value is returned unchanged.
        tensors: Flat list indexed by each sentinel's `idx`.

    Returns:
        A structure mirroring `schema` with every _TensorMeta replaced by
        `tensors[meta.idx]`. Dict subclasses and namedtuples keep their type
        when their constructor allows it; otherwise a plain dict is returned.
    """
    if isinstance(schema, _TensorMeta):
        return tensors[schema.idx]
    if isinstance(schema, dict):
        rebuilt = {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
        if type(schema) is dict:
            return rebuilt
        try:
            return type(schema)(rebuilt)
        except TypeError:
            # Some dict subclasses do not take a mapping as their first
            # positional argument (e.g. defaultdict expects the factory).
            # Fall back to a plain dict rather than crashing.
            return rebuilt
    if isinstance(schema, (tuple, list)):
        items = [_unflatten_from_scatter(v, tensors) for v in schema]
        if hasattr(schema, '_fields'):
            # namedtuple: the constructor wants one positional arg per field
            return type(schema)(*items)
        return type(schema)(items)
    return schema



def _collect_tensor_metas(schema, metas):
    """Gather every _TensorMeta found in `schema` into `metas`.

    The schema is walked depth-first, visiting dict values and tuple/list
    elements in their iteration order, so sentinels are appended in the
    same order in which the flatten step encountered the tensors.

    Args:
        schema: Nested structure of dicts, tuples and lists whose leaves may
            be _TensorMeta sentinels. Other leaves are ignored.
        metas: List extended in place; nothing is returned.
    """
    if isinstance(schema, _TensorMeta):
        metas.append(schema)
        return

    # Pick the children to descend into; non-container leaves have none.
    if isinstance(schema, dict):
        children = schema.values()
    elif isinstance(schema, (tuple, list)):
        children = schema
    else:
        return

    for child in children:
        _collect_tensor_metas(child, metas)


class DataLoaderDispatcher:

def __init__(self, base_dataloader, device=None, skip_batches: int = 0):
Expand All @@ -24,13 +77,98 @@ def world_size(self):
def group(self):
return dist.group.WORLD if dist.is_initialized() else 1

@property
def _scatter_device(self):
    """Pick the device tensors are placed on before being scattered.

    The choice follows the backend reported by ``dist.get_backend`` for
    ``self.group``: 'nccl' maps to the current CUDA device and 'hccl' to
    the current NPU device.

    Returns:
        A ``torch.device`` for the 'nccl' and 'hccl' backends, otherwise
        ``None``, meaning tensors are kept on their original device.
    """
    backend_name = dist.get_backend(self.group)
    if backend_name == 'nccl':
        cuda_index = torch.cuda.current_device()
        return torch.device('cuda', cuda_index)
    if backend_name == 'hccl':
        npu_index = torch.npu.current_device()
        return torch.device('npu', npu_index)
    # Any other backend: keep tensors on their original device.
    return None
Comment thread
Jintao-Huang marked this conversation as resolved.

def _scatter_object_list(self, inputs):
"""Scatter data from rank 0 to all ranks.

Optimization: separates tensors from non-tensor structure (schema) so that
schemas are scattered via pickle (lightweight) and tensors are transferred
via P2P isend/irecv (efficient NCCL/Gloo tensor transfer, zero padding waste).
Naturally handles variable-size tensors across ranks.
"""
if not dist.is_initialized():
return inputs[0]
outputs = [None]

global_src_rank = dist.get_global_rank(self.group, 0)
dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group)
return outputs[0]
scatter_device = self._scatter_device

if self.rank == 0:
# Flatten each rank's data: separate tensors from schema
schemas = []
per_rank_tensors = []
for item in inputs:
if item is None:
schemas.append(None)
per_rank_tensors.append([])
else:
tensors = []
schema = _flatten_for_scatter(item, tensors)
schemas.append(schema)
per_rank_tensors.append(tensors)

# Scatter lightweight schemas (no tensor payload, fast pickle)
schema_out = [None]
dist.scatter_object_list(schema_out, schemas, global_src_rank, group=self.group)
my_schema = schema_out[0]

# Send tensors to other ranks via async P2P
handles = []
send_bufs = [] # keep tensors alive until sends complete
for r in range(1, self.world_size):
dst_rank = dist.get_global_rank(self.group, r)
for t in per_rank_tensors[r]:
tensor = t.contiguous()
if scatter_device is not None:
tensor = tensor.to(scatter_device)
send_bufs.append(tensor)
handles.append(dist.isend(tensor, dst=dst_rank, group=self.group))
Comment on lines +126 to +133

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In PyTorch distributed, when a process group is specified in point-to-point communication APIs like dist.isend and dist.irecv, the destination/source rank must be the group-relative rank (local rank within the group), not the global rank. Using the global rank dst_rank with group=self.group will cause a ValueError (out of bounds) or incorrect/corrupted communication if self.group is a sub-group (e.g., in pipeline or tensor parallel setups).

Suggested change
for r in range(1, self.world_size):
dst_rank = dist.get_global_rank(self.group, r)
for t in per_rank_tensors[r]:
tensor = t.contiguous()
if scatter_device is not None:
tensor = tensor.to(scatter_device)
send_bufs.append(tensor)
handles.append(dist.isend(tensor, dst=dst_rank, group=self.group))
for r in range(1, self.world_size):
for t in per_rank_tensors[r]:
tensor = t.contiguous()
if scatter_device is not None:
tensor = tensor.to(scatter_device)
send_bufs.append(tensor)
handles.append(dist.isend(tensor, dst=r, group=self.group))


# Rank 0 keeps its own tensors (move to device if needed)
my_tensors = per_rank_tensors[0]
if scatter_device is not None:
my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]

# Wait for all sends to complete
for h in handles:
h.wait()
Comment thread
Jintao-Huang marked this conversation as resolved.
del send_bufs # safe to release after all sends finished
else:
# Receive schema (lightweight)
schema_out = [None]
dist.scatter_object_list(schema_out, None, global_src_rank, group=self.group)
my_schema = schema_out[0]

if my_schema is None:
return None

# Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
metas = []
_collect_tensor_metas(my_schema, metas)
metas.sort(key=lambda m: m.idx)
device = scatter_device if scatter_device is not None else 'cpu'
my_tensors = []
handles = []
for meta in metas:
recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group))
my_tensors.append(recv_buf)
Comment on lines +160 to +163

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similarly, for dist.irecv, when group=self.group is specified, the source rank must be the group-relative rank (which is 0 for the sender rank 0), not the global rank global_src_rank.

Suggested change
for meta in metas:
recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group))
my_tensors.append(recv_buf)
for meta in metas:
recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
handles.append(dist.irecv(recv_buf, src=0, group=self.group))
my_tensors.append(recv_buf)

Comment on lines 101 to +163

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In PyTorch Distributed, when a custom group is specified in collective or P2P operations (such as dist.scatter_object_list, dist.isend, and dist.irecv), the rank arguments (like src and dst) must be relative to the group (i.e., group-relative ranks), not global ranks.

Using global ranks with a non-WORLD group will cause runtime errors or incorrect communication. We should use 0 as the source rank and r as the destination rank directly, which also simplifies the code by removing the need for dist.get_global_rank calls.

        scatter_device = self._scatter_device

        if self.rank == 0:
            # Flatten each rank's data: separate tensors from schema
            schemas = []
            per_rank_tensors = []
            for item in inputs:
                if item is None:
                    schemas.append(None)
                    per_rank_tensors.append([])
                else:
                    tensors = []
                    schema = _flatten_for_scatter(item, tensors)
                    schemas.append(schema)
                    per_rank_tensors.append(tensors)

            # Scatter lightweight schemas (no tensor payload, fast pickle)
            schema_out = [None]
            dist.scatter_object_list(schema_out, schemas, src=0, group=self.group)
            my_schema = schema_out[0]

            # Send tensors to other ranks via async P2P
            handles = []
            send_bufs = []  # keep tensors alive until sends complete
            for r in range(1, self.world_size):
                for t in per_rank_tensors[r]:
                    tensor = t.contiguous()
                    if scatter_device is not None:
                        tensor = tensor.to(scatter_device)
                    send_bufs.append(tensor)
                    handles.append(dist.isend(tensor, dst=r, group=self.group))

            # Rank 0 keeps its own tensors (move to device if needed)
            my_tensors = per_rank_tensors[0]
            if scatter_device is not None:
                my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]

            # Wait for all sends to complete
            for h in handles:
                h.wait()
            del send_bufs  # safe to release after all sends finished
        else:
            # Receive schema (lightweight)
            schema_out = [None]
            dist.scatter_object_list(schema_out, None, src=0, group=self.group)
            my_schema = schema_out[0]

            if my_schema is None:
                return None

            # Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
            metas = []
            _collect_tensor_metas(my_schema, metas)
            metas.sort(key=lambda m: m.idx)
            device = scatter_device if scatter_device is not None else 'cpu'
            my_tensors = []
            handles = []
            for meta in metas:
                recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
                handles.append(dist.irecv(recv_buf, src=0, group=self.group))
                my_tensors.append(recv_buf)


# Wait for all receives to complete
for h in handles:
h.wait()

if my_schema is None:
return None
return _unflatten_from_scatter(my_schema, my_tensors)

def _skip_batches(self, base_iter):
if self.rank == 0 and self.skip_batches > 0:
Expand Down
Loading