diff --git a/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md index 585031ca..e2824265 100644 --- a/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md +++ b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md @@ -25,4 +25,11 @@ HCCLCheckpointEngine is specifically designed for Ascend NPU environments: - Synchronizing model weights between NPUs - Large-scale NPU cluster deployment +## Environment Variables + +- `TWINKLE_CKPT_HCCL_META_TIMEOUT_S`: + Controls the timeout (in seconds) for the HCCL CheckpointEngine + metadata handshake channel (ZMQ REQ/REP). + Default is `300`. This value should be an integer greater than `0`. + > In Ascend NPU environments, HCCLCheckpointEngine provides performance comparable to NCCL. diff --git a/docs/source_en/Components/Sampler/vLLMSampler.md b/docs/source_en/Components/Sampler/vLLMSampler.md index f6eb1fa6..95e7d283 100644 --- a/docs/source_en/Components/Sampler/vLLMSampler.md +++ b/docs/source_en/Components/Sampler/vLLMSampler.md @@ -69,4 +69,11 @@ sampler = vLLMSampler( response = sampler.sample(trajectories, sampling_params=params) ``` +## Environment Variables + +- `TWINKLE_VLLM_IPC_TIMEOUT_S`: + Controls the timeout (in seconds) for the IPC channel (ZMQ REQ/REP) + between `vLLMSampler` and the vLLM worker extension. + Default is `300`. This value must be greater than `0`. + > In RLHF training, vLLMSampler is typically separated from the Actor model, using different hardware resources to avoid interference between inference and training. diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" index 0aaf6d9b..b58000c4 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" @@ -25,4 +25,10 @@ HCCLCheckpointEngine 专门用于昇腾 NPU 环境: - 需要在 NPU 间同步模型权重 - 大规模 NPU 集群部署 +## 环境变量 + +- `TWINKLE_CKPT_HCCL_META_TIMEOUT_S`: + 控制 HCCL CheckpointEngine 元数据握手通道(ZMQ REQ/REP)的超时时间(秒)。 + 默认值为 `300`。该值应设置为大于 `0` 的整数。 + > 在昇腾 NPU 环境中,HCCLCheckpointEngine 提供了与 NCCL 相当的性能。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" index eced51f7..38b4e5be 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" @@ -69,4 +69,10 @@ sampler = vLLMSampler( response = sampler.sample(trajectories, sampling_params=params) ``` +## 环境变量 + +- `TWINKLE_VLLM_IPC_TIMEOUT_S`: + 控制 `vLLMSampler` 与 vLLM worker extension 之间 IPC 通道(ZMQ REQ/REP)的超时时间(秒)。 + 默认值为 `300`。该值必须大于 `0`。 + > vLLMSampler 在 RLHF 训练中通常与 Actor 模型分离,使用不同的硬件资源,避免推理和训练相互干扰。 diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index e6b9cdde..33db5764 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -2,14 +2,13 @@ # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/hccl_checkpoint_engine.py """HCCL-based checkpoint engine for Ascend NPU. -This engine uses HCCL broadcast for efficient NPU-to-NPU weight transfer -across different processes/nodes. It supports: -- Double buffering for pipelined transfer -- ZMQ for metadata, HCCL for weight data -- Streaming weight transfer to avoid OOM +This implementation keeps HCCL for weight payload transfer and uses a +reliable ZMQ REQ/REP control channel for bucket metadata handshakes. """ -import asyncio +from __future__ import annotations + +import os import time import torch import zmq @@ -18,7 +17,8 @@ from twinkle import get_logger from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group -from .base import CheckpointEngine, TensorMeta +from twinkle.utils.zmq_utils import configure_zmq_socket +from .base import CheckpointEngine logger = get_logger() @@ -26,79 +26,15 @@ @dataclass class MasterMetadata: """Metadata from the master for process group initialization.""" + zmq_ip: str zmq_port: int dist_ip: str dist_port: int -class BroadcastOperation: - """Async broadcast operation with HCCL in separate thread. - - Args: - rank: The rank of the current process. - process_group: The HCCL process group. - bucket: The tensor buffer to broadcast. - metadata: The metadata of tensors in the bucket. - socket: The ZMQ socket for metadata communication. - topic: The ZMQ topic for pub/sub. - """ - - def __init__( - self, - rank: int, - process_group, - bucket: torch.Tensor, - metadata: dict[str, TensorMeta], - socket: zmq.Socket, - topic: str, - ) -> None: - self.rank = rank - self.pyhccl = process_group - self.bucket = bucket - self.metadata = metadata - self.socket = socket - self.topic = topic - - loop = asyncio.get_running_loop() - self._task = loop.run_in_executor(None, self._run) - - def _run(self): - """Execute the broadcast operation in a thread.""" - # Broadcast tensor metadata via ZMQ PUB/SUB - if self.rank == 0: - self.socket.send_string(self.topic, flags=zmq.SNDMORE) - self.socket.send_pyobj(self.metadata) - else: - self.socket.recv_string() - self.metadata = self.socket.recv_pyobj() - - # Broadcast tensor data via HCCL - self.pyhccl.broadcast(self.bucket, src=0) - - async def wait_for_complete(self) -> dict[str, TensorMeta]: - """Wait for the broadcast operation to complete. - - Returns: - The bucket metadata after broadcast. - """ - await self._task - return self.metadata - - class HCCLCheckpointEngine(CheckpointEngine): - """HCCL checkpoint engine for Ascend NPU. - - Same lifecycle and semantics as NCCLCheckpointEngine but uses HCCL - instead of NCCL and stateless_init_process_group instead of - ray.util.collective. - - Args: - bucket_size: Bucket size in bytes for weight transfer. - group_name: Name of the process group. - rebuild_group: Whether to rebuild the group each sync. - rollout_dtype: Target dtype for weights. - """ + """HCCL checkpoint engine for Ascend NPU.""" def __init__( self, @@ -114,70 +50,76 @@ def __init__( self.rollout_dtype = rollout_dtype self.pyhccl = None - # Get current NPU device + self.meta_timeout_s = int(os.environ.get('TWINKLE_CKPT_HCCL_META_TIMEOUT_S', '300')) + self.meta_timeout_ms = self.meta_timeout_s * 1000 + try: self.device = torch.npu.current_device() except Exception: self.device = 0 - # Set by Manager before prepare() via attribute assignment self.is_master = False - self.topic = 'bucket_metadata' - # Will be set during prepare / init_process_group - self.rank = None - self.world_size = None - self.send_buf = None - self.recv_buf = None - self.socket = None + self.rank: int | None = None + self.world_size: int | None = None + self.send_buf: torch.Tensor | None = None + self.recv_buf: torch.Tensor | None = None + self.socket: zmq.Socket | None = None + self._zmq_ctx: zmq.Context | None = None - # Track whether resources are ready for reuse self._prepared = False self._group_initialized = False + self.ip: str | None = None + self.zmq_port: int | None = None + self.dist_port: int | None = None + + def _new_socket(self, socket_type: int) -> zmq.Socket: + assert self._zmq_ctx is not None + socket = self._zmq_ctx.socket(socket_type) + configure_zmq_socket(socket, timeout_ms=self.meta_timeout_ms, linger=0) + return socket + + def _recv_pyobj(self, where: str) -> Any: + assert self.socket is not None + try: + return self.socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError(f'HCCL metadata timeout ({self.meta_timeout_s}s) waiting at {where}.') from e - # ── ZMQ helpers ────────────────────────────────────────────────────── + def _send_pyobj(self, payload: Any, where: str) -> None: + assert self.socket is not None + try: + self.socket.send_pyobj(payload) + except zmq.error.Again as e: + raise RuntimeError(f'HCCL metadata timeout ({self.meta_timeout_s}s) sending at {where}.') from e def _start_zmq_server(self): - """Start ZMQ PUB server for metadata broadcast (master only).""" self.ip = find_node_ip() self.zmq_port = find_free_port() self.dist_port = find_free_port() - context = zmq.Context() - self.socket = context.socket(zmq.PUB) + self._zmq_ctx = zmq.Context() + self.socket = self._new_socket(zmq.REP) if is_valid_ipv6_address(self.ip): address = f'tcp://[{self.ip}]:{self.zmq_port}' self.socket.setsockopt(zmq.IPV6, 1) else: address = f'tcp://{self.ip}:{self.zmq_port}' - self.socket.bind(address) - logger.debug(f'ZMQ PUB server started at {address}') + logger.debug(f'ZMQ REP server started at {address}') def _connect_zmq_client(self, metadata: MasterMetadata): - """Connect to the ZMQ PUB server as a subscriber (receiver only).""" - context = zmq.Context() - self.socket = context.socket(zmq.SUB) + self._zmq_ctx = zmq.Context() + self.socket = self._new_socket(zmq.REQ) if is_valid_ipv6_address(metadata.zmq_ip): address = f'tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}' self.socket.setsockopt(zmq.IPV6, 1) else: address = f'tcp://{metadata.zmq_ip}:{metadata.zmq_port}' - self.socket.connect(address) - self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) - logger.debug(f'ZMQ SUB client connected to {address}') - - # ── Core lifecycle ─────────────────────────────────────────────────── + logger.debug(f'ZMQ REQ client connected to {address}') def prepare(self) -> MasterMetadata | None: - """Allocate double buffers and start ZMQ server (master only). - - Idempotent: skips if already prepared. - - Returns: - MasterMetadata with ZMQ/dist IP/port if master, else None. - """ if self._prepared: if self.is_master: return MasterMetadata( @@ -200,15 +142,11 @@ def prepare(self) -> MasterMetadata | None: dist_ip=self.ip, dist_port=self.dist_port, ) + self._prepared = True return None def finalize(self): - """Clean up resources after a sync. - - When ``rebuild_group=False``: keeps everything alive for reuse. - When ``rebuild_group=True``: full teardown. - """ if self.rebuild_group: if self.socket is not None: try: @@ -217,6 +155,13 @@ def finalize(self): logger.warning(f'Error closing ZMQ socket: {e}') self.socket = None + if self._zmq_ctx is not None: + try: + self._zmq_ctx.term() + except Exception as e: + logger.warning(f'Error terminating ZMQ context: {e}') + self._zmq_ctx = None + if self.rank is not None and self.rank >= 0 and self.pyhccl is not None: try: self.pyhccl.destroyComm(self.pyhccl.comm) @@ -238,10 +183,6 @@ def build_topology( rollout_world_size: int, metadata: list[dict], ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: - """Build communication topology for HCCL broadcast. - - Same topology as NCCLCheckpointEngine. - """ master_metadata = None for m in metadata: if m is not None: @@ -261,24 +202,12 @@ def build_topology( return trainer_kwargs, rollout_kwargs def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): - """Initialize the HCCL process group. - - Idempotent: if already initialized and ``rebuild_group`` is False, - this is a fast no-op. - - Args: - rank: The rank of this worker (-1 for non-participating trainers). - world_size: Total number of workers in the sync group. - master_metadata: Metadata from the master. - """ - # Non-participating trainer ranks if rank < 0: self.rank = rank self.world_size = world_size self._group_initialized = True return - # Fast path: already initialized if self._group_initialized and not self.rebuild_group: return @@ -297,143 +226,253 @@ def init_process_group(self, rank: int, world_size: int, master_metadata: Master assert self.rank == rank assert self.world_size == world_size - # Receivers connect to master's ZMQ PUB server if self.rank > 0 and self.socket is None: self._connect_zmq_client(master_metadata) - # Barrier using all_reduce signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) self.pyhccl.all_reduce(signal) self._group_initialized = True logger.info(f'init_process_group: rank={self.rank}, world_size={self.world_size}') - # ── Send / Receive ─────────────────────────────────────────────────── - - @torch.no_grad() - async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): - """Send model weights via HCCL broadcast.""" - assert self.rank is not None and self.rank <= 0 + def _serve_bucket_requests(self, bucket_id: int, metadata: dict[str, Any]) -> None: + assert self.rank == 0 + assert self.world_size is not None - if self.rank < 0: - for name, weight in weights: - pass + if self.world_size <= 1: return - send_buf, recv_buf = self.send_buf, self.recv_buf - broadcast_op = None + pending = set(range(1, self.world_size)) + while pending: + req = self._recv_pyobj(f'NEXT request for bucket={bucket_id}') - start_time = time.time() - bucket_meta: dict[str, TensorMeta] = {} - offset = 0 + if not isinstance(req, dict) or req.get('type') != 'NEXT': + self._send_pyobj({'ok': False, 'error': f'unexpected message: {req}'}, 'NEXT reply') + continue - for name, weight in weights: - if offset + weight.nbytes > self.bucket_size: - torch.npu.synchronize() - - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={ - 'bucket_meta': bucket_meta, - 'is_last': False + req_rank = int(req.get('rank', -1)) + req_bucket_id = int(req.get('bucket_id', -1)) + + if req_rank not in pending: + self._send_pyobj( + { + 'ok': False, + 'error': f'unexpected/duplicate rank={req_rank}' + }, + 'NEXT reply', + ) + continue + if req_bucket_id != bucket_id: + self._send_pyobj( + { + 'ok': False, + 'error': f'bucket mismatch rank={req_rank} got={req_bucket_id} expected={bucket_id}', }, - socket=self.socket, - topic=self.topic, + 'NEXT reply', ) + continue - send_buf, recv_buf = recv_buf, send_buf - bucket_meta = {} - offset = 0 + self._send_pyobj({'ok': True, 'metadata': metadata}, 'NEXT reply') + pending.remove(req_rank) - assert offset + weight.nbytes <= self.bucket_size + def _request_bucket(self, bucket_id: int) -> dict[str, Any]: + assert self.rank is not None and self.rank > 0 - bucket_meta[name] = { - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - } - send_buf[offset:offset + weight.nbytes] = weight.view(-1).view(torch.uint8) - offset += weight.nbytes - - torch.npu.synchronize() - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={ - 'bucket_meta': bucket_meta, - 'is_last': True + self._send_pyobj( + { + 'type': 'NEXT', + 'rank': self.rank, + 'bucket_id': bucket_id }, - socket=self.socket, - topic=self.topic, + f'NEXT send bucket={bucket_id}', ) - await broadcast_op.wait_for_complete() - - elapsed = time.time() - start_time - logger.info(f'send_weights done: rank={self.rank}, time={elapsed:.2f}s') + resp = self._recv_pyobj(f'NEXT recv bucket={bucket_id}') + + if not isinstance(resp, dict): + raise RuntimeError(f'Invalid metadata response for bucket {bucket_id}: {resp}') + if not resp.get('ok', False): + raise RuntimeError(f'Metadata request failed for bucket {bucket_id}: {resp.get("error", "unknown")}') + metadata = resp.get('metadata') + if not isinstance(metadata, dict): + raise RuntimeError(f'Invalid metadata payload for bucket {bucket_id}: {metadata}') + got_bucket_id = int(metadata.get('bucket_id', -1)) + if got_bucket_id != bucket_id: + raise RuntimeError(f'Metadata bucket mismatch: got {got_bucket_id}, expected {bucket_id}') + return metadata @torch.no_grad() - async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: - """Receive model weights via HCCL broadcast.""" - assert self.rank is not None and self.rank > 0 + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor]]): + assert self.rank is not None and self.rank <= 0 + if self.rank < 0: + for _name, _weight in weights: + pass + return - send_buf, recv_buf = self.send_buf, self.recv_buf - total_bytes, total_params = 0, 0 + assert self.send_buf is not None + send_buf = self.send_buf start_time = time.time() - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata['bucket_meta']) - - send_buf, recv_buf = recv_buf, send_buf - - while not metadata['is_last']: - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) + bucket_meta: list[dict[str, Any]] = [] + offset = 0 + bucket_id = 0 + total_params = 0 + total_chunks = 0 + total_bytes = 0 + + def _flush(is_last: bool): + nonlocal bucket_meta, offset, bucket_id, total_chunks, total_bytes + if not bucket_meta and not is_last: + return + + metadata = { + 'bucket_id': bucket_id, + 'is_last': is_last, + 'bucket_meta': bucket_meta, + 'payload_size': offset, + } + self._serve_bucket_requests(bucket_id, metadata) + self.pyhccl.broadcast(send_buf, src=0) + torch.npu.synchronize() - for name, meta in metadata['bucket_meta'].items(): - dtype, shape = meta['dtype'], meta['shape'] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape) - yield name, tensor + total_chunks += len(bucket_meta) + total_bytes += offset + bucket_id += 1 + bucket_meta = [] + offset = 0 - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata['bucket_meta']) + for name, weight in weights: + total_params += 1 + if weight.device.type != 'npu': + weight = weight.to('npu') + if not weight.is_contiguous(): + weight = weight.contiguous() + + weight_u8 = weight.view(-1).view(torch.uint8) + nbytes = int(weight_u8.numel()) + if nbytes == 0: + if offset >= self.bucket_size: + _flush(is_last=False) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': 0, + 'chunk_offset': 0, + 'total_nbytes': 0, + }) + continue + + chunk_offset = 0 + while chunk_offset < nbytes: + if offset >= self.bucket_size: + _flush(is_last=False) + + chunk_nbytes = min(self.bucket_size - offset, nbytes - chunk_offset) + send_buf[offset:offset + chunk_nbytes].copy_(weight_u8[chunk_offset:chunk_offset + chunk_nbytes]) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes + + _flush(is_last=True) + elapsed = time.time() - start_time + bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) if elapsed > 0 else 0.0 + logger.info(f'send_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' + f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s') + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor]]: + assert self.rank is not None and self.rank > 0 + assert self.recv_buf is not None + + recv_buf = self.recv_buf + bucket_id = 0 + total_params = 0 + total_chunks = 0 + total_bytes = 0 + start_time = time.time() + partial_tensors: dict[str, dict[str, Any]] = {} + + while True: + metadata = self._request_bucket(bucket_id) + self.pyhccl.broadcast(recv_buf, src=0) torch.npu.synchronize() - send_buf, recv_buf = recv_buf, send_buf - for name, meta in metadata['bucket_meta'].items(): - dtype, shape = meta['dtype'], meta['shape'] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape) - yield name, tensor + bucket_meta = metadata['bucket_meta'] + if isinstance(bucket_meta, dict): + entries = bucket_meta.values() + else: + entries = bucket_meta + + payload_size = int(metadata.get('payload_size', self.bucket_size)) + total_bytes += payload_size + + for meta in entries: + name = meta['name'] + dtype = meta['dtype'] + shape = meta['shape'] + shape = shape if isinstance(shape, torch.Size) else torch.Size(shape) + offset = int(meta['offset']) + nbytes = int(meta.get('nbytes', int(dtype.itemsize * shape.numel()))) + chunk_offset = int(meta.get('chunk_offset', 0)) + total_nbytes = int(meta.get('total_nbytes', int(dtype.itemsize * shape.numel()))) + total_chunks += 1 + + if nbytes == total_nbytes and chunk_offset == 0: + tensor = recv_buf[offset:offset + nbytes].view(dtype=dtype).view(shape) + yield name, tensor + total_params += 1 + continue + + state = partial_tensors.get(name) + if state is None: + state = { + 'buffer': torch.empty(total_nbytes, dtype=torch.uint8, device=recv_buf.device), + 'dtype': dtype, + 'shape': shape, + 'total': total_nbytes, + 'received': 0, + } + partial_tensors[name] = state + else: + if state['total'] != total_nbytes or state['dtype'] != dtype or state['shape'] != shape: + raise RuntimeError( + f'Inconsistent chunk metadata for weight {name}: ' + f'expected total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}; ' + f'got total={total_nbytes}, dtype={dtype}, shape={shape}.') + + if nbytes > 0: + state['buffer'][chunk_offset:chunk_offset + nbytes].copy_(recv_buf[offset:offset + nbytes]) + state['received'] += nbytes + + if state['received'] > state['total']: + raise RuntimeError( + f'Chunk overrun for weight {name}: received={state["received"]}, total={state["total"]}.') + if state['received'] == state['total']: + full_size = int(dtype.itemsize * shape.numel()) + tensor = state['buffer'][:full_size].view(dtype=dtype).view(shape) + yield name, tensor + total_params += 1 + del partial_tensors[name] + + if bool(metadata['is_last']): + if partial_tensors: + pending = ', '.join(sorted(partial_tensors.keys())[:8]) + raise RuntimeError('Incomplete chunked weights at end of stream. ' + f'Pending {len(partial_tensors)} weight(s): {pending}') + break + bucket_id += 1 elapsed = time.time() - start_time - bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) - logger.info(f'receive_weights done: rank={self.rank}, params={total_params}, ' + bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) if elapsed > 0 else 0.0 + logger.info(f'receive_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s') diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index 1a4c4466..e2e5d94d 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os + from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine @@ -16,7 +18,13 @@ def _get_or_create_checkpoint_engine(self) -> 'CheckpointEngine': self._checkpoint_engine = NCCLCheckpointEngine(self._bucket_size) elif Platform.get_platform().__name__ == 'NPU': from twinkle.checkpoint_engine import HCCLCheckpointEngine - self._checkpoint_engine = HCCLCheckpointEngine(self._bucket_size) + + # Reusing HCCL communicator across sync steps avoids frequent + # stream/channel allocation and reduces resource exhaustion risk. + self._checkpoint_engine = HCCLCheckpointEngine( + self._bucket_size, + rebuild_group=False, + ) return self._checkpoint_engine @remote_function(collect='first', lazy_collect=False) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index c7b886fe..a12a1e40 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -9,6 +9,8 @@ from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason from twinkle.sampler.base_engine import BaseSamplerEngine from twinkle.utils import Platform +from twinkle.utils.framework import Torch +from twinkle.utils.zmq_utils import configure_zmq_socket, get_timeout_s_from_env logger = get_logger() @@ -520,12 +522,11 @@ async def _sync_iter(): use_gpu_ipc = first_tensor.is_cuda use_shm = not use_gpu_ipc - # fix: On NPU, current_platform.get_device_uuid may be unimplemented and break receive_weights flow. - # fix: Route through platform-level fallback so IPC socket name remains stable. - # Get device UUID for ZMQ handle. - # For NPU, this is resolved from `npu-smi info` Bus-Id when needed. + # Use a per-sync unique IPC endpoint to avoid cross-actor collisions + # when multiple sampler actors share the same device UUID. device_uuid = Platform.get_vllm_device_uuid(0) - zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock' + sync_id = uuid.uuid4().hex + zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}-{os.getpid()}-{sync_id}.sock' bucket_size = bucket_size_mb << 20 @@ -546,6 +547,8 @@ async def _sync_iter(): # Setup ZMQ socket FIRST (bind before worker connects) zmq_ctx = zmq.Context() socket = zmq_ctx.socket(zmq.REQ) + zmq_timeout_s = get_timeout_s_from_env('TWINKLE_VLLM_IPC_TIMEOUT_S', 300) + configure_zmq_socket(socket, timeout_ms=zmq_timeout_s * 1000, linger=0) socket.bind(zmq_handle) loop = asyncio.get_running_loop() @@ -555,9 +558,12 @@ async def _sync_iter(): # critical when TP > 1: collective_rpc is an async task on the # same loop, and blocking socket.recv() would prevent it from # being scheduled, causing a deadlock. - def _zmq_send_recv(payload): - socket.send_pyobj(payload) - return socket.recv() + def _zmq_send_recv(payload, where: str): + try: + socket.send_pyobj(payload) + return socket.recv() + except zmq.error.Again as e: + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}') from e # Launch worker side concurrently worker_task = asyncio.ensure_future( @@ -567,12 +573,13 @@ def _zmq_send_recv(payload): 'peft_config': peft_config, 'base_sync_done': base_sync_done, 'use_shm': use_shm, + 'zmq_handle': zmq_handle, }, )) # Send IPC/SHM handle, wait for worker ready (non-blocking) handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size} - await loop.run_in_executor(None, _zmq_send_recv, handle_payload) + await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake') # Stream weights into buckets and send to worker async def _chain_first(): @@ -582,53 +589,60 @@ async def _chain_first(): yield item offset = 0 - bucket_meta: dict = {} + bucket_meta: list[dict] = [] n_weights = 0 + async def _flush_bucket(is_last: bool) -> None: + nonlocal offset, bucket_meta + if not bucket_meta and not is_last: + return + if buffer.device.type != 'cpu': + Torch.synchronize() + await loop.run_in_executor( + None, + _zmq_send_recv, + { + 'bucket_meta': bucket_meta, + 'is_last': is_last, + }, + 'final bucket' if is_last else 'bucket flush', + ) + offset = 0 + bucket_meta = [] + async for name, weight in _chain_first(): - if use_shm and weight.is_cuda: + if use_shm and weight.device.type != 'cpu': weight = weight.cpu() - - if weight.nbytes > bucket_size: - raise ValueError(f'Weight {name} ({weight.nbytes / (1 << 20):.1f} MB) exceeds ' - f'bucket size ({bucket_size_mb} MB). Increase bucket_size_mb.') - - # Flush current bucket if it would overflow - if offset + weight.nbytes > bucket_size: - if use_gpu_ipc: - torch.cuda.synchronize() - await loop.run_in_executor( - None, - _zmq_send_recv, - { - 'bucket_meta': bucket_meta, - 'is_last': False - }, + if not weight.is_contiguous(): + weight = weight.contiguous() + + weight_u8 = weight.view(-1).view(torch.uint8) + total_nbytes = int(weight_u8.numel()) + chunk_offset = 0 + while chunk_offset < total_nbytes: + if offset >= bucket_size: + await _flush_bucket(is_last=False) + + chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset) + buffer[offset:offset + chunk_nbytes].copy_( + weight_u8[chunk_offset:chunk_offset + chunk_nbytes], + non_blocking=True, ) - bucket_meta = {} - offset = 0 - - bucket_meta[name] = { - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - } - buffer[offset:offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': total_nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes n_weights += 1 # Send last bucket - if use_gpu_ipc: - torch.cuda.synchronize() - await loop.run_in_executor( - None, - _zmq_send_recv, - { - 'bucket_meta': bucket_meta, - 'is_last': True - }, - ) + await _flush_bucket(is_last=True) # Wait for worker to finish loading await worker_task @@ -636,6 +650,13 @@ async def _chain_first(): # Clean up socket.close() zmq_ctx.term() + if zmq_handle.startswith('ipc://'): + ipc_path = zmq_handle[len('ipc://'):] + try: + if os.path.exists(ipc_path): + os.remove(ipc_path) + except OSError: + pass del buffer if shm is not None: shm.close() diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index fa2fb748..42be5095 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -23,6 +23,7 @@ from twinkle import get_logger from twinkle.utils import Platform from twinkle.utils.framework import Torch +from twinkle.utils.zmq_utils import configure_zmq_socket, get_timeout_s_from_env logger = get_logger() @@ -97,6 +98,7 @@ def update_weights_from_ipc( peft_config: Optional[Dict] = None, base_sync_done: bool = False, use_shm: bool = False, + zmq_handle: Optional[str] = None, ) -> None: """Receive and load weights via ZMQ + CUDA IPC/SHM. @@ -114,6 +116,7 @@ def update_weights_from_ipc( peft_config: If provided with base_sync_done, loads as LoRA. base_sync_done: If True and peft_config, replaces existing LoRA. use_shm: If True, use shared memory instead of CUDA IPC. + zmq_handle: Optional ZMQ IPC endpoint. If None, uses _get_zmq_handle(). """ import torch.distributed as dist import zmq @@ -121,8 +124,10 @@ def update_weights_from_ipc( if self.device is None: # fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues. # fix: Pass local_rank when available so each worker binds to the expected local device. - print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}") - self.device = torch.device(Torch.get_device(getattr(self, 'local_rank', None))) + local_rank = getattr(self, 'local_rank', None) + device_str = Torch.get_device(local_rank) + logger.info(f'vLLM worker bind device: local_rank={local_rank}, device={device_str}') + self.device = torch.device(device_str) if peft_config and base_sync_done: self.remove_lora(VLLM_LORA_INT_ID) @@ -155,17 +160,23 @@ def _broadcast_obj(obj): # ── Step 1: Establish ZMQ connection (driver only) ── socket = None + zmq_timeout_s = get_timeout_s_from_env('TWINKLE_VLLM_IPC_TIMEOUT_S', 300) + endpoint = zmq_handle or self._get_zmq_handle() if is_driver: if not hasattr(self, '_zmq_ctx') or self._zmq_ctx is None: self._zmq_ctx = zmq.Context() socket = self._zmq_ctx.socket(zmq.REP) - socket.connect(self._get_zmq_handle()) + configure_zmq_socket(socket, timeout_ms=zmq_timeout_s * 1000, linger=0) + socket.connect(endpoint) # ── Step 2: Receive and broadcast IPC/SHM handle ── buffer, shm = None, None if is_driver: - comm_metadata = socket.recv_pyobj() + try: + comm_metadata = socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) waiting handle on {endpoint}') from e else: comm_metadata = None @@ -188,10 +199,14 @@ def _broadcast_obj(obj): socket.send(b'') # Ready # ── Step 3: Receive and process weight buckets ── + partial_tensors: dict = {} while True: # Only the driver receives bucket metadata from VLLMEngine. if is_driver: - metadata = socket.recv_pyobj() + try: + metadata = socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) waiting bucket metadata on {endpoint}') from e else: metadata = None @@ -199,15 +214,73 @@ def _broadcast_obj(obj): metadata = _broadcast_obj(metadata) weights = [] - for name, meta in metadata['bucket_meta'].items(): - shape, dtype, offset = meta['shape'], meta['dtype'], meta['offset'] - size = dtype.itemsize * shape.numel() - tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape) - if not use_shm: - tensor = tensor.clone() + bucket_meta = metadata.get('bucket_meta', []) + if isinstance(bucket_meta, dict): + entries = list(bucket_meta.values()) + else: + entries = list(bucket_meta) + + # Drop old slice refs before creating new views into shared memory. + raw_u8 = None + cpu_u8 = None + tensor = None + assembled = None + state = None + for meta in entries: + name = meta['name'] + dtype = meta['dtype'] + shape = meta['shape'] + shape = shape if isinstance(shape, torch.Size) else torch.Size(shape) + offset = int(meta['offset']) + full_size = int(dtype.itemsize * shape.numel()) + nbytes = int(meta.get('nbytes', full_size)) + chunk_offset = int(meta.get('chunk_offset', 0)) + total_nbytes = int(meta.get('total_nbytes', full_size)) + + raw_u8 = buffer[offset:offset + nbytes] + + if nbytes == total_nbytes and chunk_offset == 0: + if use_shm: + cpu_u8 = raw_u8.clone() + tensor = cpu_u8.view(dtype=dtype).view(shape) + else: + tensor = raw_u8.view(dtype=dtype).view(shape).clone() + weights.append((name, tensor)) + continue + + state = partial_tensors.get(name) + if state is None: + state = { + 'buffer': torch.empty(total_nbytes, dtype=torch.uint8, device=buffer.device), + 'dtype': dtype, + 'shape': shape, + 'total': total_nbytes, + 'received': 0, + } + partial_tensors[name] = state else: - tensor = tensor.to(self.device) - weights.append((name, tensor)) + if state['total'] != total_nbytes or state['dtype'] != dtype or state['shape'] != shape: + raise RuntimeError( + f'Inconsistent chunk metadata for {name}: ' + f'expected(total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}), ' + f'got(total={total_nbytes}, dtype={dtype}, shape={shape})') + + if nbytes > 0: + state['buffer'][chunk_offset:chunk_offset + nbytes].copy_(raw_u8) + state['received'] += nbytes + + if state['received'] > state['total']: + raise RuntimeError( + f'Chunk overrun for {name}: received={state["received"]}, total={state["total"]}') + + if state['received'] == state['total']: + assembled = state['buffer'].view(dtype=state['dtype']).view(state['shape']) + if use_shm: + tensor = assembled + else: + tensor = assembled.clone() + weights.append((name, tensor)) + del partial_tensors[name] Torch.synchronize() @@ -223,15 +296,34 @@ def _broadcast_obj(obj): del weights if metadata['is_last']: + if partial_tensors: + pending = ', '.join(sorted(partial_tensors.keys())[:8]) + raise RuntimeError( + f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})') break + partial_tensors.clear() + metadata = None + raw_u8 = None + cpu_u8 = None + tensor = None + assembled = None + state = None if is_driver and socket is not None: socket.close() del buffer + gc.collect() if shm is not None: - shm.close() + try: + shm.close() + except BufferError: + # Best effort: some temporary views may still be held by runtime internals. + gc.collect() + try: + shm.close() + except BufferError as e: + logger.warning(f'SharedMemory close skipped due to exported pointers: {e}') del shm - gc.collect() Torch.ipc_collect() Torch.empty_cache() diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index 5a23623e..0cdeb81d 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -126,10 +126,10 @@ def get_device(local_rank) -> str: local_rank = max(0, Platform.get_local_rank()) local_rank = str(local_rank) if Torch.is_gpu_available(): - from .platform import GPU + from .platforms import GPU device = f'{GPU.device_prefix()}:{local_rank}' elif Torch.is_npu_available(): - from .platform import NPU + from .platforms import NPU device = f'{NPU.device_prefix()}:{local_rank}' else: device = 'cpu' diff --git a/src/twinkle/utils/zmq_utils.py b/src/twinkle/utils/zmq_utils.py new file mode 100644 index 00000000..d6754123 --- /dev/null +++ b/src/twinkle/utils/zmq_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Utilities for configuring ZeroMQ sockets consistently.""" + +from __future__ import annotations + +import os +import zmq + + +def get_timeout_s_from_env(env_name: str, default: int) -> int: + """Read timeout seconds from env and validate it.""" + raw_value = os.environ.get(env_name, str(default)) + try: + timeout_s = int(raw_value) + except ValueError as e: + raise ValueError(f'Invalid {env_name}={raw_value}, must be an integer > 0') from e + if timeout_s <= 0: + raise ValueError(f'Invalid {env_name}={timeout_s}, must be > 0') + return timeout_s + + +def configure_zmq_socket(socket: zmq.Socket, timeout_ms: int, linger: int = 0) -> None: + """Apply timeout/linger options to a ZMQ socket.""" + socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + socket.setsockopt(zmq.SNDTIMEO, timeout_ms) + socket.setsockopt(zmq.LINGER, linger)