From 5bc1adf9574f1404e030a5e361c504250894a6d1 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Wed, 4 Mar 2026 00:32:53 +0800 Subject: [PATCH 01/11] [WIP]fix npu grpo --- .../hccl_checkpoint_engine.py | 533 ++++++++++-------- src/twinkle/infra/_ray/resource_manager.py | 16 +- .../sampler/vllm_sampler/vllm_engine.py | 123 ++-- .../vllm_sampler/vllm_worker_extension.py | 131 ++++- src/twinkle/utils/framework.py | 4 +- 5 files changed, 508 insertions(+), 299 deletions(-) diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index e6b9cdde..fa25a5d0 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -2,23 +2,23 @@ # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/hccl_checkpoint_engine.py """HCCL-based checkpoint engine for Ascend NPU. -This engine uses HCCL broadcast for efficient NPU-to-NPU weight transfer -across different processes/nodes. It supports: -- Double buffering for pipelined transfer -- ZMQ for metadata, HCCL for weight data -- Streaming weight transfer to avoid OOM +This implementation keeps HCCL for weight payload transfer and uses a +reliable ZMQ REQ/REP control channel for bucket metadata handshakes. """ -import asyncio +from __future__ import annotations + +import os import time -import torch -import zmq from dataclasses import dataclass from typing import Any, AsyncGenerator, Generator +import torch +import zmq + from twinkle import get_logger from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group -from .base import CheckpointEngine, TensorMeta +from .base import CheckpointEngine logger = get_logger() @@ -26,79 +26,15 @@ @dataclass class MasterMetadata: """Metadata from the master for process group initialization.""" + zmq_ip: str zmq_port: int dist_ip: str dist_port: int -class BroadcastOperation: - """Async broadcast operation with HCCL in separate thread. - - Args: - rank: The rank of the current process. - process_group: The HCCL process group. - bucket: The tensor buffer to broadcast. - metadata: The metadata of tensors in the bucket. - socket: The ZMQ socket for metadata communication. - topic: The ZMQ topic for pub/sub. - """ - - def __init__( - self, - rank: int, - process_group, - bucket: torch.Tensor, - metadata: dict[str, TensorMeta], - socket: zmq.Socket, - topic: str, - ) -> None: - self.rank = rank - self.pyhccl = process_group - self.bucket = bucket - self.metadata = metadata - self.socket = socket - self.topic = topic - - loop = asyncio.get_running_loop() - self._task = loop.run_in_executor(None, self._run) - - def _run(self): - """Execute the broadcast operation in a thread.""" - # Broadcast tensor metadata via ZMQ PUB/SUB - if self.rank == 0: - self.socket.send_string(self.topic, flags=zmq.SNDMORE) - self.socket.send_pyobj(self.metadata) - else: - self.socket.recv_string() - self.metadata = self.socket.recv_pyobj() - - # Broadcast tensor data via HCCL - self.pyhccl.broadcast(self.bucket, src=0) - - async def wait_for_complete(self) -> dict[str, TensorMeta]: - """Wait for the broadcast operation to complete. - - Returns: - The bucket metadata after broadcast. - """ - await self._task - return self.metadata - - class HCCLCheckpointEngine(CheckpointEngine): - """HCCL checkpoint engine for Ascend NPU. - - Same lifecycle and semantics as NCCLCheckpointEngine but uses HCCL - instead of NCCL and stateless_init_process_group instead of - ray.util.collective. - - Args: - bucket_size: Bucket size in bytes for weight transfer. - group_name: Name of the process group. - rebuild_group: Whether to rebuild the group each sync. - rollout_dtype: Target dtype for weights. - """ + """HCCL checkpoint engine for Ascend NPU.""" def __init__( self, @@ -108,76 +44,92 @@ def __init__( rollout_dtype: torch.dtype = torch.bfloat16, **kwargs, ) -> None: + bucket_mb_env = os.environ.get('TWINKLE_CKPT_HCCL_BUCKET_MB') + if bucket_mb_env: + bucket_size = int(bucket_mb_env) << 20 + self.bucket_size = bucket_size self.group_name = group_name self.rebuild_group = rebuild_group self.rollout_dtype = rollout_dtype self.pyhccl = None - # Get current NPU device + self.meta_timeout_s = int(os.environ.get('TWINKLE_CKPT_HCCL_META_TIMEOUT_S', '300')) + self.meta_timeout_ms = self.meta_timeout_s * 1000 + try: self.device = torch.npu.current_device() except Exception: self.device = 0 - # Set by Manager before prepare() via attribute assignment self.is_master = False - self.topic = 'bucket_metadata' - # Will be set during prepare / init_process_group - self.rank = None - self.world_size = None - self.send_buf = None - self.recv_buf = None - self.socket = None + self.rank: int | None = None + self.world_size: int | None = None + self.send_buf: torch.Tensor | None = None + self.recv_buf: torch.Tensor | None = None + self.socket: zmq.Socket | None = None + self._zmq_ctx: zmq.Context | None = None - # Track whether resources are ready for reuse self._prepared = False self._group_initialized = False - - # ── ZMQ helpers ────────────────────────────────────────────────────── + self.ip: str | None = None + self.zmq_port: int | None = None + self.dist_port: int | None = None + + def _new_socket(self, socket_type: int) -> zmq.Socket: + assert self._zmq_ctx is not None + socket = self._zmq_ctx.socket(socket_type) + socket.setsockopt(zmq.RCVTIMEO, self.meta_timeout_ms) + socket.setsockopt(zmq.SNDTIMEO, self.meta_timeout_ms) + socket.setsockopt(zmq.LINGER, 0) + return socket + + def _recv_pyobj(self, where: str) -> Any: + assert self.socket is not None + try: + return self.socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError( + f'HCCL metadata timeout ({self.meta_timeout_s}s) waiting at {where}.' + ) from e + + def _send_pyobj(self, payload: Any, where: str) -> None: + assert self.socket is not None + try: + self.socket.send_pyobj(payload) + except zmq.error.Again as e: + raise RuntimeError( + f'HCCL metadata timeout ({self.meta_timeout_s}s) sending at {where}.' + ) from e def _start_zmq_server(self): - """Start ZMQ PUB server for metadata broadcast (master only).""" self.ip = find_node_ip() self.zmq_port = find_free_port() self.dist_port = find_free_port() - context = zmq.Context() - self.socket = context.socket(zmq.PUB) + self._zmq_ctx = zmq.Context() + self.socket = self._new_socket(zmq.REP) if is_valid_ipv6_address(self.ip): address = f'tcp://[{self.ip}]:{self.zmq_port}' self.socket.setsockopt(zmq.IPV6, 1) else: address = f'tcp://{self.ip}:{self.zmq_port}' - self.socket.bind(address) - logger.debug(f'ZMQ PUB server started at {address}') + logger.debug(f'ZMQ REP server started at {address}') def _connect_zmq_client(self, metadata: MasterMetadata): - """Connect to the ZMQ PUB server as a subscriber (receiver only).""" - context = zmq.Context() - self.socket = context.socket(zmq.SUB) + self._zmq_ctx = zmq.Context() + self.socket = self._new_socket(zmq.REQ) if is_valid_ipv6_address(metadata.zmq_ip): address = f'tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}' self.socket.setsockopt(zmq.IPV6, 1) else: address = f'tcp://{metadata.zmq_ip}:{metadata.zmq_port}' - self.socket.connect(address) - self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) - logger.debug(f'ZMQ SUB client connected to {address}') - - # ── Core lifecycle ─────────────────────────────────────────────────── + logger.debug(f'ZMQ REQ client connected to {address}') def prepare(self) -> MasterMetadata | None: - """Allocate double buffers and start ZMQ server (master only). - - Idempotent: skips if already prepared. - - Returns: - MasterMetadata with ZMQ/dist IP/port if master, else None. - """ if self._prepared: if self.is_master: return MasterMetadata( @@ -200,15 +152,11 @@ def prepare(self) -> MasterMetadata | None: dist_ip=self.ip, dist_port=self.dist_port, ) + self._prepared = True return None def finalize(self): - """Clean up resources after a sync. - - When ``rebuild_group=False``: keeps everything alive for reuse. - When ``rebuild_group=True``: full teardown. - """ if self.rebuild_group: if self.socket is not None: try: @@ -217,6 +165,13 @@ def finalize(self): logger.warning(f'Error closing ZMQ socket: {e}') self.socket = None + if self._zmq_ctx is not None: + try: + self._zmq_ctx.term() + except Exception as e: + logger.warning(f'Error terminating ZMQ context: {e}') + self._zmq_ctx = None + if self.rank is not None and self.rank >= 0 and self.pyhccl is not None: try: self.pyhccl.destroyComm(self.pyhccl.comm) @@ -238,10 +193,6 @@ def build_topology( rollout_world_size: int, metadata: list[dict], ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: - """Build communication topology for HCCL broadcast. - - Same topology as NCCLCheckpointEngine. - """ master_metadata = None for m in metadata: if m is not None: @@ -261,24 +212,12 @@ def build_topology( return trainer_kwargs, rollout_kwargs def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): - """Initialize the HCCL process group. - - Idempotent: if already initialized and ``rebuild_group`` is False, - this is a fast no-op. - - Args: - rank: The rank of this worker (-1 for non-participating trainers). - world_size: Total number of workers in the sync group. - master_metadata: Metadata from the master. - """ - # Non-participating trainer ranks if rank < 0: self.rank = rank self.world_size = world_size self._group_initialized = True return - # Fast path: already initialized if self._group_initialized and not self.rebuild_group: return @@ -297,143 +236,273 @@ def init_process_group(self, rank: int, world_size: int, master_metadata: Master assert self.rank == rank assert self.world_size == world_size - # Receivers connect to master's ZMQ PUB server if self.rank > 0 and self.socket is None: self._connect_zmq_client(master_metadata) - # Barrier using all_reduce signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) self.pyhccl.all_reduce(signal) self._group_initialized = True logger.info(f'init_process_group: rank={self.rank}, world_size={self.world_size}') - # ── Send / Receive ─────────────────────────────────────────────────── + def _serve_bucket_requests(self, bucket_id: int, metadata: dict[str, Any]) -> None: + assert self.rank == 0 + assert self.world_size is not None + + if self.world_size <= 1: + return + + pending = set(range(1, self.world_size)) + while pending: + req = self._recv_pyobj(f'NEXT request for bucket={bucket_id}') + + if not isinstance(req, dict) or req.get('type') != 'NEXT': + self._send_pyobj({'ok': False, 'error': f'unexpected message: {req}'}, 'NEXT reply') + continue + + req_rank = int(req.get('rank', -1)) + req_bucket_id = int(req.get('bucket_id', -1)) + + if req_rank not in pending: + self._send_pyobj( + {'ok': False, 'error': f'unexpected/duplicate rank={req_rank}'}, + 'NEXT reply', + ) + continue + if req_bucket_id != bucket_id: + self._send_pyobj( + { + 'ok': False, + 'error': f'bucket mismatch rank={req_rank} got={req_bucket_id} expected={bucket_id}', + }, + 'NEXT reply', + ) + continue + + self._send_pyobj({'ok': True, 'metadata': metadata}, 'NEXT reply') + pending.remove(req_rank) + + def _request_bucket(self, bucket_id: int) -> dict[str, Any]: + assert self.rank is not None and self.rank > 0 + + self._send_pyobj( + {'type': 'NEXT', 'rank': self.rank, 'bucket_id': bucket_id}, + f'NEXT send bucket={bucket_id}', + ) + resp = self._recv_pyobj(f'NEXT recv bucket={bucket_id}') + + if not isinstance(resp, dict): + raise RuntimeError(f'Invalid metadata response for bucket {bucket_id}: {resp}') + if not resp.get('ok', False): + raise RuntimeError( + f'Metadata request failed for bucket {bucket_id}: {resp.get("error", "unknown")}' + ) + metadata = resp.get('metadata') + if not isinstance(metadata, dict): + raise RuntimeError(f'Invalid metadata payload for bucket {bucket_id}: {metadata}') + got_bucket_id = int(metadata.get('bucket_id', -1)) + if got_bucket_id != bucket_id: + raise RuntimeError(f'Metadata bucket mismatch: got {got_bucket_id}, expected {bucket_id}') + return metadata + + @staticmethod + def _view_from_u8_buffer( + buffer: torch.Tensor, + offset: int, + size: int, + dtype: torch.dtype, + shape: torch.Size, + ) -> torch.Tensor: + raw = buffer[offset:offset + size] + itemsize = int(dtype.itemsize) + if itemsize > 1 and offset % itemsize != 0: + aligned = torch.empty(size, dtype=torch.uint8, device=buffer.device) + aligned.copy_(raw) + raw = aligned + return raw.view(dtype=dtype).view(shape) @torch.no_grad() async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): - """Send model weights via HCCL broadcast.""" assert self.rank is not None and self.rank <= 0 - if self.rank < 0: - for name, weight in weights: + for _name, _weight in weights: pass return - send_buf, recv_buf = self.send_buf, self.recv_buf - broadcast_op = None + assert self.send_buf is not None + send_buf = self.send_buf start_time = time.time() - bucket_meta: dict[str, TensorMeta] = {} + bucket_meta: list[dict[str, Any]] = [] offset = 0 + bucket_id = 0 + total_params = 0 + total_chunks = 0 + + def _flush(is_last: bool): + nonlocal bucket_meta, offset, bucket_id, total_chunks + if not bucket_meta and not is_last: + return + + metadata = { + 'bucket_id': bucket_id, + 'is_last': is_last, + 'bucket_meta': bucket_meta, + 'payload_size': offset, + } + self._serve_bucket_requests(bucket_id, metadata) + self.pyhccl.broadcast(send_buf, src=0) + torch.npu.synchronize() + + total_chunks += len(bucket_meta) + bucket_id += 1 + bucket_meta = [] + offset = 0 for name, weight in weights: - if offset + weight.nbytes > self.bucket_size: - torch.npu.synchronize() - - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={ - 'bucket_meta': bucket_meta, - 'is_last': False - }, - socket=self.socket, - topic=self.topic, + total_params += 1 + if weight.device.type != 'npu': + weight = weight.to('npu') + if not weight.is_contiguous(): + weight = weight.contiguous() + + weight_u8 = weight.view(-1).view(torch.uint8) + nbytes = int(weight_u8.numel()) + if nbytes == 0: + if offset >= self.bucket_size: + _flush(is_last=False) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': 0, + 'chunk_offset': 0, + 'total_nbytes': 0, + }) + continue + + chunk_offset = 0 + while chunk_offset < nbytes: + if offset >= self.bucket_size: + _flush(is_last=False) + + chunk_nbytes = min(self.bucket_size - offset, nbytes - chunk_offset) + send_buf[offset:offset + chunk_nbytes].copy_( + weight_u8[chunk_offset:chunk_offset + chunk_nbytes] ) - - send_buf, recv_buf = recv_buf, send_buf - bucket_meta = {} - offset = 0 - - assert offset + weight.nbytes <= self.bucket_size - - bucket_meta[name] = { - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - } - send_buf[offset:offset + weight.nbytes] = weight.view(-1).view(torch.uint8) - offset += weight.nbytes - - torch.npu.synchronize() - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={ - 'bucket_meta': bucket_meta, - 'is_last': True - }, - socket=self.socket, - topic=self.topic, - ) - await broadcast_op.wait_for_complete() + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes + + _flush(is_last=True) elapsed = time.time() - start_time - logger.info(f'send_weights done: rank={self.rank}, time={elapsed:.2f}s') + logger.info( + f'send_weights done: rank={self.rank}, params={total_params}, ' + f'chunks={total_chunks}, time={elapsed:.2f}s' + ) @torch.no_grad() async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: - """Receive model weights via HCCL broadcast.""" assert self.rank is not None and self.rank > 0 + assert self.recv_buf is not None - send_buf, recv_buf = self.send_buf, self.recv_buf - total_bytes, total_params = 0, 0 - + recv_buf = self.recv_buf + bucket_id = 0 + total_params = 0 + total_chunks = 0 + total_bytes = 0 start_time = time.time() - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata['bucket_meta']) - - send_buf, recv_buf = recv_buf, send_buf - - while not metadata['is_last']: - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - - for name, meta in metadata['bucket_meta'].items(): - dtype, shape = meta['dtype'], meta['shape'] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape) - yield name, tensor - - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata['bucket_meta']) + partial_tensors: dict[str, dict[str, Any]] = {} + while True: + metadata = self._request_bucket(bucket_id) + self.pyhccl.broadcast(recv_buf, src=0) torch.npu.synchronize() - send_buf, recv_buf = recv_buf, send_buf - for name, meta in metadata['bucket_meta'].items(): - dtype, shape = meta['dtype'], meta['shape'] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape) - yield name, tensor + bucket_meta = metadata['bucket_meta'] + if isinstance(bucket_meta, dict): + entries = bucket_meta.values() + else: + entries = bucket_meta + + payload_size = int(metadata.get('payload_size', self.bucket_size)) + total_bytes += payload_size + + for meta in entries: + name = meta['name'] + dtype = meta['dtype'] + shape = meta['shape'] + shape = shape if isinstance(shape, torch.Size) else torch.Size(shape) + offset = int(meta['offset']) + nbytes = int(meta.get('nbytes', int(dtype.itemsize * shape.numel()))) + chunk_offset = int(meta.get('chunk_offset', 0)) + total_nbytes = int(meta.get('total_nbytes', int(dtype.itemsize * shape.numel()))) + total_chunks += 1 + + if nbytes == total_nbytes and chunk_offset == 0: + tensor = self._view_from_u8_buffer(recv_buf, offset, nbytes, dtype, shape) + yield name, tensor + total_params += 1 + continue + + state = partial_tensors.get(name) + if state is None: + state = { + 'buffer': torch.empty(total_nbytes, dtype=torch.uint8, device=recv_buf.device), + 'dtype': dtype, + 'shape': shape, + 'total': total_nbytes, + 'received': 0, + } + partial_tensors[name] = state + else: + if state['total'] != total_nbytes or state['dtype'] != dtype or state['shape'] != shape: + raise RuntimeError( + f'Inconsistent chunk metadata for weight {name}: ' + f'expected total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}; ' + f'got total={total_nbytes}, dtype={dtype}, shape={shape}.' + ) + + if nbytes > 0: + state['buffer'][chunk_offset:chunk_offset + nbytes].copy_( + recv_buf[offset:offset + nbytes] + ) + state['received'] += nbytes + + if state['received'] > state['total']: + raise RuntimeError( + f'Chunk overrun for weight {name}: received={state["received"]}, total={state["total"]}.' + ) + if state['received'] == state['total']: + full_size = int(dtype.itemsize * shape.numel()) + tensor = self._view_from_u8_buffer(state['buffer'], 0, full_size, dtype, shape) + yield name, tensor + total_params += 1 + del partial_tensors[name] + + if bool(metadata['is_last']): + if partial_tensors: + pending = ', '.join(sorted(partial_tensors.keys())[:8]) + raise RuntimeError( + 'Incomplete chunked weights at end of stream. ' + f'Pending {len(partial_tensors)} weight(s): {pending}' + ) + break + bucket_id += 1 elapsed = time.time() - start_time - bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) - logger.info(f'receive_weights done: rank={self.rank}, params={total_params}, ' - f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s') + bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) if elapsed > 0 else 0.0 + logger.info( + f'receive_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' + f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s' + ) diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index 09149503..67be367c 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -68,17 +68,29 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De self.nnodes = math.ceil(cpu_proc_count / ncpu_proc_per_node) self.nodes = [] + cluster_resource_totals = {} for node in ray.nodes(): # get available nodes resource = node['Resources'] + for name, amount in resource.items(): + if isinstance(amount, (int, float)): + cluster_resource_totals[name] = cluster_resource_totals.get(name, 0.0) + float(amount) node_device_num = int(resource.get(device_type, 0)) if device_type != 'CPU' and node_device_num >= nproc_per_node: self.nodes.append(node) if device_type == 'CPU' and int(node['Resources']['CPU']) // 4 >= ncpu_proc_per_node: self.nodes.append(node) - assert self.nnodes <= len( - self.nodes), f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}' + if self.nnodes > len(self.nodes): + hint = '' + if device_type == 'GPU' and cluster_resource_totals.get('NPU', 0) > 0 and cluster_resource_totals.get( + 'GPU', 0) == 0: + hint = " Hint: Ray cluster exposes 'NPU' resources but no 'GPU'. Set DeviceGroup.device_type='NPU'." + raise AssertionError( + f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}. ' + f"requested device: '{device_type}', cluster total for requested device: " + f"{int(cluster_resource_totals.get(device_type, 0))}. " + f'cluster resource keys: {sorted(cluster_resource_totals.keys())}.{hint}') bundles = [] cpu_bundles = [] diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index c7b886fe..03114993 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -492,6 +492,10 @@ async def update_weights( start_time = time.time() + bucket_size_mb = int(os.environ.get('TWINKLE_VLLM_IPC_BUCKET_MB', str(bucket_size_mb))) + if bucket_size_mb <= 0: + raise ValueError(f'Invalid TWINKLE_VLLM_IPC_BUCKET_MB={bucket_size_mb}, must be > 0') + # Normalise *weights* into an async iterator regardless of input type. if isinstance(weights, dict): @@ -520,12 +524,11 @@ async def _sync_iter(): use_gpu_ipc = first_tensor.is_cuda use_shm = not use_gpu_ipc - # fix: On NPU, current_platform.get_device_uuid may be unimplemented and break receive_weights flow. - # fix: Route through platform-level fallback so IPC socket name remains stable. - # Get device UUID for ZMQ handle. - # For NPU, this is resolved from `npu-smi info` Bus-Id when needed. + # Use a per-sync unique IPC endpoint to avoid cross-actor collisions + # when multiple sampler actors share the same device UUID. device_uuid = Platform.get_vllm_device_uuid(0) - zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock' + sync_id = uuid.uuid4().hex + zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}-{os.getpid()}-{sync_id}.sock' bucket_size = bucket_size_mb << 20 @@ -546,6 +549,10 @@ async def _sync_iter(): # Setup ZMQ socket FIRST (bind before worker connects) zmq_ctx = zmq.Context() socket = zmq_ctx.socket(zmq.REQ) + zmq_timeout_s = int(os.environ.get('TWINKLE_VLLM_IPC_TIMEOUT_S', '300')) + socket.setsockopt(zmq.RCVTIMEO, zmq_timeout_s * 1000) + socket.setsockopt(zmq.SNDTIMEO, zmq_timeout_s * 1000) + socket.setsockopt(zmq.LINGER, 0) socket.bind(zmq_handle) loop = asyncio.get_running_loop() @@ -555,9 +562,14 @@ async def _sync_iter(): # critical when TP > 1: collective_rpc is an async task on the # same loop, and blocking socket.recv() would prevent it from # being scheduled, causing a deadlock. - def _zmq_send_recv(payload): - socket.send_pyobj(payload) - return socket.recv() + def _zmq_send_recv(payload, where: str): + try: + socket.send_pyobj(payload) + return socket.recv() + except zmq.error.Again as e: + raise RuntimeError( + f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}' + ) from e # Launch worker side concurrently worker_task = asyncio.ensure_future( @@ -567,12 +579,13 @@ def _zmq_send_recv(payload): 'peft_config': peft_config, 'base_sync_done': base_sync_done, 'use_shm': use_shm, + 'zmq_handle': zmq_handle, }, )) # Send IPC/SHM handle, wait for worker ready (non-blocking) handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size} - await loop.run_in_executor(None, _zmq_send_recv, handle_payload) + await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake') # Stream weights into buckets and send to worker async def _chain_first(): @@ -582,53 +595,60 @@ async def _chain_first(): yield item offset = 0 - bucket_meta: dict = {} + bucket_meta: list[dict] = [] n_weights = 0 + async def _flush_bucket(is_last: bool) -> None: + nonlocal offset, bucket_meta + if not bucket_meta and not is_last: + return + if use_gpu_ipc: + torch.cuda.synchronize() + await loop.run_in_executor( + None, + _zmq_send_recv, + { + 'bucket_meta': bucket_meta, + 'is_last': is_last, + }, + 'final bucket' if is_last else 'bucket flush', + ) + offset = 0 + bucket_meta = [] + async for name, weight in _chain_first(): - if use_shm and weight.is_cuda: + if use_shm and weight.device.type != 'cpu': weight = weight.cpu() - - if weight.nbytes > bucket_size: - raise ValueError(f'Weight {name} ({weight.nbytes / (1 << 20):.1f} MB) exceeds ' - f'bucket size ({bucket_size_mb} MB). Increase bucket_size_mb.') - - # Flush current bucket if it would overflow - if offset + weight.nbytes > bucket_size: - if use_gpu_ipc: - torch.cuda.synchronize() - await loop.run_in_executor( - None, - _zmq_send_recv, - { - 'bucket_meta': bucket_meta, - 'is_last': False - }, + if not weight.is_contiguous(): + weight = weight.contiguous() + + weight_u8 = weight.view(-1).view(torch.uint8) + total_nbytes = int(weight_u8.numel()) + chunk_offset = 0 + while chunk_offset < total_nbytes: + if offset >= bucket_size: + await _flush_bucket(is_last=False) + + chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset) + buffer[offset:offset + chunk_nbytes].copy_( + weight_u8[chunk_offset:chunk_offset + chunk_nbytes], + non_blocking=True, ) - bucket_meta = {} - offset = 0 - - bucket_meta[name] = { - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - } - buffer[offset:offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': total_nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes n_weights += 1 # Send last bucket - if use_gpu_ipc: - torch.cuda.synchronize() - await loop.run_in_executor( - None, - _zmq_send_recv, - { - 'bucket_meta': bucket_meta, - 'is_last': True - }, - ) + await _flush_bucket(is_last=True) # Wait for worker to finish loading await worker_task @@ -636,6 +656,13 @@ async def _chain_first(): # Clean up socket.close() zmq_ctx.term() + if zmq_handle.startswith('ipc://'): + ipc_path = zmq_handle[len('ipc://'):] + try: + if os.path.exists(ipc_path): + os.remove(ipc_path) + except OSError: + pass del buffer if shm is not None: shm.close() diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index fa2fb748..1ed6cdd8 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -97,6 +97,7 @@ def update_weights_from_ipc( peft_config: Optional[Dict] = None, base_sync_done: bool = False, use_shm: bool = False, + zmq_handle: Optional[str] = None, ) -> None: """Receive and load weights via ZMQ + CUDA IPC/SHM. @@ -121,8 +122,10 @@ def update_weights_from_ipc( if self.device is None: # fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues. # fix: Pass local_rank when available so each worker binds to the expected local device. - print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}") - self.device = torch.device(Torch.get_device(getattr(self, 'local_rank', None))) + local_rank = getattr(self, 'local_rank', None) + device_str = Torch.get_device(local_rank) + logger.info(f'vLLM worker bind device: local_rank={local_rank}, device={device_str}') + self.device = torch.device(device_str) if peft_config and base_sync_done: self.remove_lora(VLLM_LORA_INT_ID) @@ -155,17 +158,27 @@ def _broadcast_obj(obj): # ── Step 1: Establish ZMQ connection (driver only) ── socket = None + zmq_timeout_s = int(os.environ.get('TWINKLE_VLLM_IPC_TIMEOUT_S', '300')) + endpoint = zmq_handle or self._get_zmq_handle() if is_driver: if not hasattr(self, '_zmq_ctx') or self._zmq_ctx is None: self._zmq_ctx = zmq.Context() socket = self._zmq_ctx.socket(zmq.REP) - socket.connect(self._get_zmq_handle()) + socket.setsockopt(zmq.RCVTIMEO, zmq_timeout_s * 1000) + socket.setsockopt(zmq.SNDTIMEO, zmq_timeout_s * 1000) + socket.setsockopt(zmq.LINGER, 0) + socket.connect(endpoint) # ── Step 2: Receive and broadcast IPC/SHM handle ── buffer, shm = None, None if is_driver: - comm_metadata = socket.recv_pyobj() + try: + comm_metadata = socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError( + f'IPC timeout ({zmq_timeout_s}s) waiting handle on {endpoint}' + ) from e else: comm_metadata = None @@ -188,10 +201,16 @@ def _broadcast_obj(obj): socket.send(b'') # Ready # ── Step 3: Receive and process weight buckets ── + partial_tensors: dict = {} while True: # Only the driver receives bucket metadata from VLLMEngine. if is_driver: - metadata = socket.recv_pyobj() + try: + metadata = socket.recv_pyobj() + except zmq.error.Again as e: + raise RuntimeError( + f'IPC timeout ({zmq_timeout_s}s) waiting bucket metadata on {endpoint}' + ) from e else: metadata = None @@ -199,15 +218,77 @@ def _broadcast_obj(obj): metadata = _broadcast_obj(metadata) weights = [] - for name, meta in metadata['bucket_meta'].items(): - shape, dtype, offset = meta['shape'], meta['dtype'], meta['offset'] - size = dtype.itemsize * shape.numel() - tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape) - if not use_shm: - tensor = tensor.clone() + bucket_meta = metadata.get('bucket_meta', []) + if isinstance(bucket_meta, dict): + entries = list(bucket_meta.values()) + else: + entries = list(bucket_meta) + + # Drop old slice refs before creating new views into shared memory. + raw_u8 = None + cpu_u8 = None + tensor = None + assembled = None + state = None + for meta in entries: + name = meta['name'] + dtype = meta['dtype'] + shape = meta['shape'] + shape = shape if isinstance(shape, torch.Size) else torch.Size(shape) + offset = int(meta['offset']) + full_size = int(dtype.itemsize * shape.numel()) + nbytes = int(meta.get('nbytes', full_size)) + chunk_offset = int(meta.get('chunk_offset', 0)) + total_nbytes = int(meta.get('total_nbytes', full_size)) + + raw_u8 = buffer[offset:offset + nbytes] + + if nbytes == total_nbytes and chunk_offset == 0: + if use_shm: + cpu_u8 = raw_u8.clone() + # Keep SHM tensors on CPU; loading will copy into model params + # without allocating an extra full-size temporary tensor on NPU. + tensor = cpu_u8.view(dtype=dtype).view(shape) + else: + tensor = raw_u8.view(dtype=dtype).view(shape).clone() + weights.append((name, tensor)) + continue + + state = partial_tensors.get(name) + if state is None: + state = { + 'buffer': torch.empty(total_nbytes, dtype=torch.uint8, device=buffer.device), + 'dtype': dtype, + 'shape': shape, + 'total': total_nbytes, + 'received': 0, + } + partial_tensors[name] = state else: - tensor = tensor.to(self.device) - weights.append((name, tensor)) + if state['total'] != total_nbytes or state['dtype'] != dtype or state['shape'] != shape: + raise RuntimeError( + f'Inconsistent chunk metadata for {name}: ' + f'expected(total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}), ' + f'got(total={total_nbytes}, dtype={dtype}, shape={shape})' + ) + + if nbytes > 0: + state['buffer'][chunk_offset:chunk_offset + nbytes].copy_(raw_u8) + state['received'] += nbytes + + if state['received'] > state['total']: + raise RuntimeError( + f'Chunk overrun for {name}: received={state["received"]}, total={state["total"]}' + ) + + if state['received'] == state['total']: + assembled = state['buffer'].view(dtype=state['dtype']).view(state['shape']) + if use_shm: + tensor = assembled + else: + tensor = assembled.clone() + weights.append((name, tensor)) + del partial_tensors[name] Torch.synchronize() @@ -223,15 +304,35 @@ def _broadcast_obj(obj): del weights if metadata['is_last']: + if partial_tensors: + pending = ', '.join(sorted(partial_tensors.keys())[:8]) + raise RuntimeError( + f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})' + ) break + partial_tensors.clear() + metadata = None + raw_u8 = None + cpu_u8 = None + tensor = None + assembled = None + state = None if is_driver and socket is not None: socket.close() del buffer + gc.collect() if shm is not None: - shm.close() + try: + shm.close() + except BufferError: + # Best effort: some temporary views may still be held by runtime internals. + gc.collect() + try: + shm.close() + except BufferError as e: + logger.warning(f'SharedMemory close skipped due to exported pointers: {e}') del shm - gc.collect() Torch.ipc_collect() Torch.empty_cache() diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index 5a23623e..0cdeb81d 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -126,10 +126,10 @@ def get_device(local_rank) -> str: local_rank = max(0, Platform.get_local_rank()) local_rank = str(local_rank) if Torch.is_gpu_available(): - from .platform import GPU + from .platforms import GPU device = f'{GPU.device_prefix()}:{local_rank}' elif Torch.is_npu_available(): - from .platform import NPU + from .platforms import NPU device = f'{NPU.device_prefix()}:{local_rank}' else: device = 'cpu' From 34cbceb4ba02d90a0cc8261293988b793a67e0a6 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Wed, 4 Mar 2026 00:40:56 +0800 Subject: [PATCH 02/11] drop useless code --- .../hccl_checkpoint_engine.py | 20 ++----------------- .../vllm_sampler/vllm_worker_extension.py | 2 -- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index fa25a5d0..4a8053b5 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -305,22 +305,6 @@ def _request_bucket(self, bucket_id: int) -> dict[str, Any]: raise RuntimeError(f'Metadata bucket mismatch: got {got_bucket_id}, expected {bucket_id}') return metadata - @staticmethod - def _view_from_u8_buffer( - buffer: torch.Tensor, - offset: int, - size: int, - dtype: torch.dtype, - shape: torch.Size, - ) -> torch.Tensor: - raw = buffer[offset:offset + size] - itemsize = int(dtype.itemsize) - if itemsize > 1 and offset % itemsize != 0: - aligned = torch.empty(size, dtype=torch.uint8, device=buffer.device) - aligned.copy_(raw) - raw = aligned - return raw.view(dtype=dtype).view(shape) - @torch.no_grad() async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): assert self.rank is not None and self.rank <= 0 @@ -450,7 +434,7 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None total_chunks += 1 if nbytes == total_nbytes and chunk_offset == 0: - tensor = self._view_from_u8_buffer(recv_buf, offset, nbytes, dtype, shape) + tensor = recv_buf[offset:offset + nbytes].view(dtype=dtype).view(shape) yield name, tensor total_params += 1 continue @@ -485,7 +469,7 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None ) if state['received'] == state['total']: full_size = int(dtype.itemsize * shape.numel()) - tensor = self._view_from_u8_buffer(state['buffer'], 0, full_size, dtype, shape) + tensor = state['buffer'][:full_size].view(dtype=dtype).view(shape) yield name, tensor total_params += 1 del partial_tensors[name] diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index 1ed6cdd8..b6fa6823 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -246,8 +246,6 @@ def _broadcast_obj(obj): if nbytes == total_nbytes and chunk_offset == 0: if use_shm: cpu_u8 = raw_u8.clone() - # Keep SHM tensors on CPU; loading will copy into model params - # without allocating an extra full-size temporary tensor on NPU. tensor = cpu_u8.view(dtype=dtype).view(shape) else: tensor = raw_u8.view(dtype=dtype).view(shape).clone() From ebf6177b167c3304bb8716746c08ffb5338f0adb Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Wed, 4 Mar 2026 01:16:49 +0800 Subject: [PATCH 03/11] [WIP] remove useless code --- .../hccl_checkpoint_engine.py | 61 ++++++------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index 4a8053b5..a697e9df 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -44,10 +44,6 @@ def __init__( rollout_dtype: torch.dtype = torch.bfloat16, **kwargs, ) -> None: - bucket_mb_env = os.environ.get('TWINKLE_CKPT_HCCL_BUCKET_MB') - if bucket_mb_env: - bucket_size = int(bucket_mb_env) << 20 - self.bucket_size = bucket_size self.group_name = group_name self.rebuild_group = rebuild_group @@ -321,10 +317,9 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, offset = 0 bucket_id = 0 total_params = 0 - total_chunks = 0 def _flush(is_last: bool): - nonlocal bucket_meta, offset, bucket_id, total_chunks + nonlocal bucket_meta, offset, bucket_id if not bucket_meta and not is_last: return @@ -338,7 +333,6 @@ def _flush(is_last: bool): self.pyhccl.broadcast(send_buf, src=0) torch.npu.synchronize() - total_chunks += len(bucket_meta) bucket_id += 1 bucket_meta = [] offset = 0 @@ -352,48 +346,27 @@ def _flush(is_last: bool): weight_u8 = weight.view(-1).view(torch.uint8) nbytes = int(weight_u8.numel()) - if nbytes == 0: - if offset >= self.bucket_size: - _flush(is_last=False) - bucket_meta.append({ - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - 'nbytes': 0, - 'chunk_offset': 0, - 'total_nbytes': 0, - }) - continue - - chunk_offset = 0 - while chunk_offset < nbytes: - if offset >= self.bucket_size: - _flush(is_last=False) - - chunk_nbytes = min(self.bucket_size - offset, nbytes - chunk_offset) - send_buf[offset:offset + chunk_nbytes].copy_( - weight_u8[chunk_offset:chunk_offset + chunk_nbytes] + if nbytes > self.bucket_size: + raise ValueError( + f'Weight {name}({tuple(weight.shape)}, {weight.dtype}) is too large ' + f'for bucket ({self.bucket_size / (1 << 20):.1f} MB). Increase bucket size.' ) - bucket_meta.append({ - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - 'nbytes': chunk_nbytes, - 'chunk_offset': chunk_offset, - 'total_nbytes': nbytes, - }) - offset += chunk_nbytes - chunk_offset += chunk_nbytes + if offset + nbytes > self.bucket_size: + _flush(is_last=False) + + send_buf[offset:offset + nbytes].copy_(weight_u8) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + }) + offset += nbytes _flush(is_last=True) elapsed = time.time() - start_time - logger.info( - f'send_weights done: rank={self.rank}, params={total_params}, ' - f'chunks={total_chunks}, time={elapsed:.2f}s' - ) + logger.info(f'send_weights done: rank={self.rank}, params={total_params}, time={elapsed:.2f}s') @torch.no_grad() async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: From d928e403cba1c453ef00d61c305e2ae967675a43 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Wed, 4 Mar 2026 01:28:29 +0800 Subject: [PATCH 04/11] fix hccl resource problem --- src/twinkle/checkpoint_engine/mixin.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index 1a4c4466..a38b8e7e 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os + from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine @@ -16,7 +18,13 @@ def _get_or_create_checkpoint_engine(self) -> 'CheckpointEngine': self._checkpoint_engine = NCCLCheckpointEngine(self._bucket_size) elif Platform.get_platform().__name__ == 'NPU': from twinkle.checkpoint_engine import HCCLCheckpointEngine - self._checkpoint_engine = HCCLCheckpointEngine(self._bucket_size) + # Reusing HCCL communicator across sync steps avoids frequent + # stream/channel allocation and reduces resource exhaustion risk. + rebuild_group = bool(int(os.environ.get('TWINKLE_CKPT_HCCL_REBUILD_GROUP', '0'))) + self._checkpoint_engine = HCCLCheckpointEngine( + self._bucket_size, + rebuild_group=rebuild_group, + ) return self._checkpoint_engine @remote_function(collect='first', lazy_collect=False) From f636a586daff5c44e9d518a0a037d445fe896dec Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 09:22:08 +0800 Subject: [PATCH 05/11] fix lint --- .../hccl_checkpoint_engine.py | 60 ++++++++----------- src/twinkle/checkpoint_engine/mixin.py | 1 + src/twinkle/infra/_ray/resource_manager.py | 9 ++- .../vllm_sampler/vllm_worker_extension.py | 17 ++---- 4 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index a697e9df..9b72302e 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -10,11 +10,10 @@ import os import time -from dataclasses import dataclass -from typing import Any, AsyncGenerator, Generator - import torch import zmq +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Generator from twinkle import get_logger from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group @@ -86,18 +85,14 @@ def _recv_pyobj(self, where: str) -> Any: try: return self.socket.recv_pyobj() except zmq.error.Again as e: - raise RuntimeError( - f'HCCL metadata timeout ({self.meta_timeout_s}s) waiting at {where}.' - ) from e + raise RuntimeError(f'HCCL metadata timeout ({self.meta_timeout_s}s) waiting at {where}.') from e def _send_pyobj(self, payload: Any, where: str) -> None: assert self.socket is not None try: self.socket.send_pyobj(payload) except zmq.error.Again as e: - raise RuntimeError( - f'HCCL metadata timeout ({self.meta_timeout_s}s) sending at {where}.' - ) from e + raise RuntimeError(f'HCCL metadata timeout ({self.meta_timeout_s}s) sending at {where}.') from e def _start_zmq_server(self): self.ip = find_node_ip() @@ -261,7 +256,10 @@ def _serve_bucket_requests(self, bucket_id: int, metadata: dict[str, Any]) -> No if req_rank not in pending: self._send_pyobj( - {'ok': False, 'error': f'unexpected/duplicate rank={req_rank}'}, + { + 'ok': False, + 'error': f'unexpected/duplicate rank={req_rank}' + }, 'NEXT reply', ) continue @@ -282,7 +280,11 @@ def _request_bucket(self, bucket_id: int) -> dict[str, Any]: assert self.rank is not None and self.rank > 0 self._send_pyobj( - {'type': 'NEXT', 'rank': self.rank, 'bucket_id': bucket_id}, + { + 'type': 'NEXT', + 'rank': self.rank, + 'bucket_id': bucket_id + }, f'NEXT send bucket={bucket_id}', ) resp = self._recv_pyobj(f'NEXT recv bucket={bucket_id}') @@ -290,9 +292,7 @@ def _request_bucket(self, bucket_id: int) -> dict[str, Any]: if not isinstance(resp, dict): raise RuntimeError(f'Invalid metadata response for bucket {bucket_id}: {resp}') if not resp.get('ok', False): - raise RuntimeError( - f'Metadata request failed for bucket {bucket_id}: {resp.get("error", "unknown")}' - ) + raise RuntimeError(f'Metadata request failed for bucket {bucket_id}: {resp.get("error", "unknown")}') metadata = resp.get('metadata') if not isinstance(metadata, dict): raise RuntimeError(f'Invalid metadata payload for bucket {bucket_id}: {metadata}') @@ -302,7 +302,7 @@ def _request_bucket(self, bucket_id: int) -> dict[str, Any]: return metadata @torch.no_grad() - async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor]]): assert self.rank is not None and self.rank <= 0 if self.rank < 0: for _name, _weight in weights: @@ -347,10 +347,8 @@ def _flush(is_last: bool): weight_u8 = weight.view(-1).view(torch.uint8) nbytes = int(weight_u8.numel()) if nbytes > self.bucket_size: - raise ValueError( - f'Weight {name}({tuple(weight.shape)}, {weight.dtype}) is too large ' - f'for bucket ({self.bucket_size / (1 << 20):.1f} MB). Increase bucket size.' - ) + raise ValueError(f'Weight {name}({tuple(weight.shape)}, {weight.dtype}) is too large ' + f'for bucket ({self.bucket_size / (1 << 20):.1f} MB). Increase bucket size.') if offset + nbytes > self.bucket_size: _flush(is_last=False) @@ -369,7 +367,7 @@ def _flush(is_last: bool): logger.info(f'send_weights done: rank={self.rank}, params={total_params}, time={elapsed:.2f}s') @torch.no_grad() - async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor]]: assert self.rank is not None and self.rank > 0 assert self.recv_buf is not None @@ -427,19 +425,15 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None raise RuntimeError( f'Inconsistent chunk metadata for weight {name}: ' f'expected total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}; ' - f'got total={total_nbytes}, dtype={dtype}, shape={shape}.' - ) + f'got total={total_nbytes}, dtype={dtype}, shape={shape}.') if nbytes > 0: - state['buffer'][chunk_offset:chunk_offset + nbytes].copy_( - recv_buf[offset:offset + nbytes] - ) + state['buffer'][chunk_offset:chunk_offset + nbytes].copy_(recv_buf[offset:offset + nbytes]) state['received'] += nbytes if state['received'] > state['total']: raise RuntimeError( - f'Chunk overrun for weight {name}: received={state["received"]}, total={state["total"]}.' - ) + f'Chunk overrun for weight {name}: received={state["received"]}, total={state["total"]}.') if state['received'] == state['total']: full_size = int(dtype.itemsize * shape.numel()) tensor = state['buffer'][:full_size].view(dtype=dtype).view(shape) @@ -450,16 +444,12 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None if bool(metadata['is_last']): if partial_tensors: pending = ', '.join(sorted(partial_tensors.keys())[:8]) - raise RuntimeError( - 'Incomplete chunked weights at end of stream. ' - f'Pending {len(partial_tensors)} weight(s): {pending}' - ) + raise RuntimeError('Incomplete chunked weights at end of stream. ' + f'Pending {len(partial_tensors)} weight(s): {pending}') break bucket_id += 1 elapsed = time.time() - start_time bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) if elapsed > 0 else 0.0 - logger.info( - f'receive_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' - f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s' - ) + logger.info(f'receive_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' + f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s') diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index a38b8e7e..8adcc9f5 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -18,6 +18,7 @@ def _get_or_create_checkpoint_engine(self) -> 'CheckpointEngine': self._checkpoint_engine = NCCLCheckpointEngine(self._bucket_size) elif Platform.get_platform().__name__ == 'NPU': from twinkle.checkpoint_engine import HCCLCheckpointEngine + # Reusing HCCL communicator across sync steps avoids frequent # stream/channel allocation and reduces resource exhaustion risk. rebuild_group = bool(int(os.environ.get('TWINKLE_CKPT_HCCL_REBUILD_GROUP', '0'))) diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index 67be367c..455cff67 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -86,11 +86,10 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De if device_type == 'GPU' and cluster_resource_totals.get('NPU', 0) > 0 and cluster_resource_totals.get( 'GPU', 0) == 0: hint = " Hint: Ray cluster exposes 'NPU' resources but no 'GPU'. Set DeviceGroup.device_type='NPU'." - raise AssertionError( - f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}. ' - f"requested device: '{device_type}', cluster total for requested device: " - f"{int(cluster_resource_totals.get(device_type, 0))}. " - f'cluster resource keys: {sorted(cluster_resource_totals.keys())}.{hint}') + raise AssertionError(f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}. ' + f"requested device: '{device_type}', cluster total for requested device: " + f'{int(cluster_resource_totals.get(device_type, 0))}. ' + f'cluster resource keys: {sorted(cluster_resource_totals.keys())}.{hint}') bundles = [] cpu_bundles = [] diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index b6fa6823..f4f258ca 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -176,9 +176,7 @@ def _broadcast_obj(obj): try: comm_metadata = socket.recv_pyobj() except zmq.error.Again as e: - raise RuntimeError( - f'IPC timeout ({zmq_timeout_s}s) waiting handle on {endpoint}' - ) from e + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) waiting handle on {endpoint}') from e else: comm_metadata = None @@ -208,9 +206,7 @@ def _broadcast_obj(obj): try: metadata = socket.recv_pyobj() except zmq.error.Again as e: - raise RuntimeError( - f'IPC timeout ({zmq_timeout_s}s) waiting bucket metadata on {endpoint}' - ) from e + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) waiting bucket metadata on {endpoint}') from e else: metadata = None @@ -267,8 +263,7 @@ def _broadcast_obj(obj): raise RuntimeError( f'Inconsistent chunk metadata for {name}: ' f'expected(total={state["total"]}, dtype={state["dtype"]}, shape={state["shape"]}), ' - f'got(total={total_nbytes}, dtype={dtype}, shape={shape})' - ) + f'got(total={total_nbytes}, dtype={dtype}, shape={shape})') if nbytes > 0: state['buffer'][chunk_offset:chunk_offset + nbytes].copy_(raw_u8) @@ -276,8 +271,7 @@ def _broadcast_obj(obj): if state['received'] > state['total']: raise RuntimeError( - f'Chunk overrun for {name}: received={state["received"]}, total={state["total"]}' - ) + f'Chunk overrun for {name}: received={state["received"]}, total={state["total"]}') if state['received'] == state['total']: assembled = state['buffer'].view(dtype=state['dtype']).view(state['shape']) @@ -305,8 +299,7 @@ def _broadcast_obj(obj): if partial_tensors: pending = ', '.join(sorted(partial_tensors.keys())[:8]) raise RuntimeError( - f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})' - ) + f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})') break partial_tensors.clear() From d83dd50e9d91dcf384d5ed1364749cbe499f4630 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 09:25:49 +0800 Subject: [PATCH 06/11] fix lint --- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 03114993..d8bafc97 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -567,9 +567,7 @@ def _zmq_send_recv(payload, where: str): socket.send_pyobj(payload) return socket.recv() except zmq.error.Again as e: - raise RuntimeError( - f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}' - ) from e + raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}') from e # Launch worker side concurrently worker_task = asyncio.ensure_future( From 8e8f0373dcdbd2d798ae24741f9db68fb7dec69f Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 09:31:51 +0800 Subject: [PATCH 07/11] fix send_weights --- .../hccl_checkpoint_engine.py | 56 +++++++++++++------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index 9b72302e..e3ffdd0b 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -317,9 +317,11 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor]]): offset = 0 bucket_id = 0 total_params = 0 + total_chunks = 0 + total_bytes = 0 def _flush(is_last: bool): - nonlocal bucket_meta, offset, bucket_id + nonlocal bucket_meta, offset, bucket_id, total_chunks, total_bytes if not bucket_meta and not is_last: return @@ -333,6 +335,8 @@ def _flush(is_last: bool): self.pyhccl.broadcast(send_buf, src=0) torch.npu.synchronize() + total_chunks += len(bucket_meta) + total_bytes += offset bucket_id += 1 bucket_meta = [] offset = 0 @@ -346,25 +350,45 @@ def _flush(is_last: bool): weight_u8 = weight.view(-1).view(torch.uint8) nbytes = int(weight_u8.numel()) - if nbytes > self.bucket_size: - raise ValueError(f'Weight {name}({tuple(weight.shape)}, {weight.dtype}) is too large ' - f'for bucket ({self.bucket_size / (1 << 20):.1f} MB). Increase bucket size.') - if offset + nbytes > self.bucket_size: - _flush(is_last=False) - - send_buf[offset:offset + nbytes].copy_(weight_u8) - bucket_meta.append({ - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - }) - offset += nbytes + if nbytes == 0: + if offset >= self.bucket_size: + _flush(is_last=False) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': 0, + 'chunk_offset': 0, + 'total_nbytes': 0, + }) + continue + + chunk_offset = 0 + while chunk_offset < nbytes: + if offset >= self.bucket_size: + _flush(is_last=False) + + chunk_nbytes = min(self.bucket_size - offset, nbytes - chunk_offset) + send_buf[offset:offset + chunk_nbytes].copy_(weight_u8[chunk_offset:chunk_offset + chunk_nbytes]) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes _flush(is_last=True) elapsed = time.time() - start_time - logger.info(f'send_weights done: rank={self.rank}, params={total_params}, time={elapsed:.2f}s') + bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) if elapsed > 0 else 0.0 + logger.info(f'send_weights done: rank={self.rank}, params={total_params}, chunks={total_chunks}, ' + f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s') @torch.no_grad() async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor]]: From 54d6ca7452f7be14e45247cd76820ee9766dcabe Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 09:43:51 +0800 Subject: [PATCH 08/11] drop useless code --- src/twinkle/infra/_ray/resource_manager.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index 455cff67..09149503 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -68,28 +68,17 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De self.nnodes = math.ceil(cpu_proc_count / ncpu_proc_per_node) self.nodes = [] - cluster_resource_totals = {} for node in ray.nodes(): # get available nodes resource = node['Resources'] - for name, amount in resource.items(): - if isinstance(amount, (int, float)): - cluster_resource_totals[name] = cluster_resource_totals.get(name, 0.0) + float(amount) node_device_num = int(resource.get(device_type, 0)) if device_type != 'CPU' and node_device_num >= nproc_per_node: self.nodes.append(node) if device_type == 'CPU' and int(node['Resources']['CPU']) // 4 >= ncpu_proc_per_node: self.nodes.append(node) - if self.nnodes > len(self.nodes): - hint = '' - if device_type == 'GPU' and cluster_resource_totals.get('NPU', 0) > 0 and cluster_resource_totals.get( - 'GPU', 0) == 0: - hint = " Hint: Ray cluster exposes 'NPU' resources but no 'GPU'. Set DeviceGroup.device_type='NPU'." - raise AssertionError(f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}. ' - f"requested device: '{device_type}', cluster total for requested device: " - f'{int(cluster_resource_totals.get(device_type, 0))}. ' - f'cluster resource keys: {sorted(cluster_resource_totals.keys())}.{hint}') + assert self.nnodes <= len( + self.nodes), f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}' bundles = [] cpu_bundles = [] From a77258ee3367e810017122198328be8e1c71f3c7 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 14:30:27 +0800 Subject: [PATCH 09/11] fix review --- src/twinkle/checkpoint_engine/mixin.py | 3 +-- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 9 +++------ .../sampler/vllm_sampler/vllm_worker_extension.py | 1 + 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index 8adcc9f5..e2e5d94d 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -21,10 +21,9 @@ def _get_or_create_checkpoint_engine(self) -> 'CheckpointEngine': # Reusing HCCL communicator across sync steps avoids frequent # stream/channel allocation and reduces resource exhaustion risk. - rebuild_group = bool(int(os.environ.get('TWINKLE_CKPT_HCCL_REBUILD_GROUP', '0'))) self._checkpoint_engine = HCCLCheckpointEngine( self._bucket_size, - rebuild_group=rebuild_group, + rebuild_group=False, ) return self._checkpoint_engine diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index d8bafc97..5e04ebad 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -9,6 +9,7 @@ from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason from twinkle.sampler.base_engine import BaseSamplerEngine from twinkle.utils import Platform +from twinkle.utils.framework import Torch logger = get_logger() @@ -492,10 +493,6 @@ async def update_weights( start_time = time.time() - bucket_size_mb = int(os.environ.get('TWINKLE_VLLM_IPC_BUCKET_MB', str(bucket_size_mb))) - if bucket_size_mb <= 0: - raise ValueError(f'Invalid TWINKLE_VLLM_IPC_BUCKET_MB={bucket_size_mb}, must be > 0') - # Normalise *weights* into an async iterator regardless of input type. if isinstance(weights, dict): @@ -600,8 +597,8 @@ async def _flush_bucket(is_last: bool) -> None: nonlocal offset, bucket_meta if not bucket_meta and not is_last: return - if use_gpu_ipc: - torch.cuda.synchronize() + if buffer.device.type != 'cpu': + Torch.synchronize() await loop.run_in_executor( None, _zmq_send_recv, diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index f4f258ca..899b1783 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -115,6 +115,7 @@ def update_weights_from_ipc( peft_config: If provided with base_sync_done, loads as LoRA. base_sync_done: If True and peft_config, replaces existing LoRA. use_shm: If True, use shared memory instead of CUDA IPC. + zmq_handle: Optional ZMQ IPC endpoint for per-bucket handshake/control messages. If None, falls back to _get_zmq_handle(). """ import torch.distributed as dist import zmq From abf77ea9a88b59006220910814ac4b046d8f087e Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 14:52:19 +0800 Subject: [PATCH 10/11] fix review --- .../hccl_checkpoint_engine.py | 5 ++-- .../sampler/vllm_sampler/vllm_engine.py | 7 +++-- .../vllm_sampler/vllm_worker_extension.py | 9 +++---- src/twinkle/utils/zmq_utils.py | 26 +++++++++++++++++++ 4 files changed, 35 insertions(+), 12 deletions(-) create mode 100644 src/twinkle/utils/zmq_utils.py diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py index e3ffdd0b..33db5764 100644 --- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -17,6 +17,7 @@ from twinkle import get_logger from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group +from twinkle.utils.zmq_utils import configure_zmq_socket from .base import CheckpointEngine logger = get_logger() @@ -75,9 +76,7 @@ def __init__( def _new_socket(self, socket_type: int) -> zmq.Socket: assert self._zmq_ctx is not None socket = self._zmq_ctx.socket(socket_type) - socket.setsockopt(zmq.RCVTIMEO, self.meta_timeout_ms) - socket.setsockopt(zmq.SNDTIMEO, self.meta_timeout_ms) - socket.setsockopt(zmq.LINGER, 0) + configure_zmq_socket(socket, timeout_ms=self.meta_timeout_ms, linger=0) return socket def _recv_pyobj(self, where: str) -> Any: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 5e04ebad..a12a1e40 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -10,6 +10,7 @@ from twinkle.sampler.base_engine import BaseSamplerEngine from twinkle.utils import Platform from twinkle.utils.framework import Torch +from twinkle.utils.zmq_utils import configure_zmq_socket, get_timeout_s_from_env logger = get_logger() @@ -546,10 +547,8 @@ async def _sync_iter(): # Setup ZMQ socket FIRST (bind before worker connects) zmq_ctx = zmq.Context() socket = zmq_ctx.socket(zmq.REQ) - zmq_timeout_s = int(os.environ.get('TWINKLE_VLLM_IPC_TIMEOUT_S', '300')) - socket.setsockopt(zmq.RCVTIMEO, zmq_timeout_s * 1000) - socket.setsockopt(zmq.SNDTIMEO, zmq_timeout_s * 1000) - socket.setsockopt(zmq.LINGER, 0) + zmq_timeout_s = get_timeout_s_from_env('TWINKLE_VLLM_IPC_TIMEOUT_S', 300) + configure_zmq_socket(socket, timeout_ms=zmq_timeout_s * 1000, linger=0) socket.bind(zmq_handle) loop = asyncio.get_running_loop() diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index 899b1783..42be5095 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -23,6 +23,7 @@ from twinkle import get_logger from twinkle.utils import Platform from twinkle.utils.framework import Torch +from twinkle.utils.zmq_utils import configure_zmq_socket, get_timeout_s_from_env logger = get_logger() @@ -115,7 +116,7 @@ def update_weights_from_ipc( peft_config: If provided with base_sync_done, loads as LoRA. base_sync_done: If True and peft_config, replaces existing LoRA. use_shm: If True, use shared memory instead of CUDA IPC. - zmq_handle: Optional ZMQ IPC endpoint for per-bucket handshake/control messages. If None, falls back to _get_zmq_handle(). + zmq_handle: Optional ZMQ IPC endpoint. If None, uses _get_zmq_handle(). """ import torch.distributed as dist import zmq @@ -159,15 +160,13 @@ def _broadcast_obj(obj): # ── Step 1: Establish ZMQ connection (driver only) ── socket = None - zmq_timeout_s = int(os.environ.get('TWINKLE_VLLM_IPC_TIMEOUT_S', '300')) + zmq_timeout_s = get_timeout_s_from_env('TWINKLE_VLLM_IPC_TIMEOUT_S', 300) endpoint = zmq_handle or self._get_zmq_handle() if is_driver: if not hasattr(self, '_zmq_ctx') or self._zmq_ctx is None: self._zmq_ctx = zmq.Context() socket = self._zmq_ctx.socket(zmq.REP) - socket.setsockopt(zmq.RCVTIMEO, zmq_timeout_s * 1000) - socket.setsockopt(zmq.SNDTIMEO, zmq_timeout_s * 1000) - socket.setsockopt(zmq.LINGER, 0) + configure_zmq_socket(socket, timeout_ms=zmq_timeout_s * 1000, linger=0) socket.connect(endpoint) # ── Step 2: Receive and broadcast IPC/SHM handle ── diff --git a/src/twinkle/utils/zmq_utils.py b/src/twinkle/utils/zmq_utils.py new file mode 100644 index 00000000..d6754123 --- /dev/null +++ b/src/twinkle/utils/zmq_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Utilities for configuring ZeroMQ sockets consistently.""" + +from __future__ import annotations + +import os +import zmq + + +def get_timeout_s_from_env(env_name: str, default: int) -> int: + """Read timeout seconds from env and validate it.""" + raw_value = os.environ.get(env_name, str(default)) + try: + timeout_s = int(raw_value) + except ValueError as e: + raise ValueError(f'Invalid {env_name}={raw_value}, must be an integer > 0') from e + if timeout_s <= 0: + raise ValueError(f'Invalid {env_name}={timeout_s}, must be > 0') + return timeout_s + + +def configure_zmq_socket(socket: zmq.Socket, timeout_ms: int, linger: int = 0) -> None: + """Apply timeout/linger options to a ZMQ socket.""" + socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + socket.setsockopt(zmq.SNDTIMEO, timeout_ms) + socket.setsockopt(zmq.LINGER, linger) From e55cd2b323e50f24a31a483fb9265e6d9611dc10 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Thu, 5 Mar 2026 15:02:24 +0800 Subject: [PATCH 11/11] add env_var doc --- .../Components/Checkpoint Engine/HCCLCheckpointEngine.md | 7 +++++++ docs/source_en/Components/Sampler/vLLMSampler.md | 7 +++++++ .../HCCLCheckpointEngine.md" | 6 ++++++ .../\351\207\207\346\240\267\345\231\250/vLLMSampler.md" | 6 ++++++ 4 files changed, 26 insertions(+) diff --git a/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md index 585031ca..e2824265 100644 --- a/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md +++ b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md @@ -25,4 +25,11 @@ HCCLCheckpointEngine is specifically designed for Ascend NPU environments: - Synchronizing model weights between NPUs - Large-scale NPU cluster deployment +## Environment Variables + +- `TWINKLE_CKPT_HCCL_META_TIMEOUT_S`: + Controls the timeout (in seconds) for the HCCL CheckpointEngine + metadata handshake channel (ZMQ REQ/REP). + Default is `300`. This value should be an integer greater than `0`. + > In Ascend NPU environments, HCCLCheckpointEngine provides performance comparable to NCCL. diff --git a/docs/source_en/Components/Sampler/vLLMSampler.md b/docs/source_en/Components/Sampler/vLLMSampler.md index f6eb1fa6..95e7d283 100644 --- a/docs/source_en/Components/Sampler/vLLMSampler.md +++ b/docs/source_en/Components/Sampler/vLLMSampler.md @@ -69,4 +69,11 @@ sampler = vLLMSampler( response = sampler.sample(trajectories, sampling_params=params) ``` +## Environment Variables + +- `TWINKLE_VLLM_IPC_TIMEOUT_S`: + Controls the timeout (in seconds) for the IPC channel (ZMQ REQ/REP) + between `vLLMSampler` and the vLLM worker extension. + Default is `300`. This value must be greater than `0`. + > In RLHF training, vLLMSampler is typically separated from the Actor model, using different hardware resources to avoid interference between inference and training. diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" index 0aaf6d9b..b58000c4 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" @@ -25,4 +25,10 @@ HCCLCheckpointEngine 专门用于昇腾 NPU 环境: - 需要在 NPU 间同步模型权重 - 大规模 NPU 集群部署 +## 环境变量 + +- `TWINKLE_CKPT_HCCL_META_TIMEOUT_S`: + 控制 HCCL CheckpointEngine 元数据握手通道(ZMQ REQ/REP)的超时时间(秒)。 + 默认值为 `300`。该值应设置为大于 `0` 的整数。 + > 在昇腾 NPU 环境中,HCCLCheckpointEngine 提供了与 NCCL 相当的性能。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" index eced51f7..38b4e5be 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" @@ -69,4 +69,10 @@ sampler = vLLMSampler( response = sampler.sample(trajectories, sampling_params=params) ``` +## 环境变量 + +- `TWINKLE_VLLM_IPC_TIMEOUT_S`: + 控制 `vLLMSampler` 与 vLLM worker extension 之间 IPC 通道(ZMQ REQ/REP)的超时时间(秒)。 + 默认值为 `300`。该值必须大于 `0`。 + > vLLMSampler 在 RLHF 训练中通常与 Actor 模型分离,使用不同的硬件资源,避免推理和训练相互干扰。