From f1a510dafefb8a16568c065bde306af14b1dbfb7 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 26 Jan 2026 17:00:20 +0800
Subject: [PATCH 1/8] optimize pyproject

Signed-off-by: 0oshowero0
---
 .github/workflows/python-package.yml |  7 ++++---
 .github/workflows/sanity.yml         |  7 -------
 pyproject.toml                       | 25 +++++++++++++------------
 3 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 2917c78..0b42eae 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -31,16 +31,17 @@ jobs:
     - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
-        python -m pip install flake8 pytest build pytest_asyncio
-        python -m build --wheel
        pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
+        pip install -e ".[test,build,yuanrong]"
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test Build
+      run: |
+        python -m build --wheel
    - name: Test with pytest
      run: |
        pytest
\ No newline at end of file
diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml
index 3c689b2..29e7375 100644
--- a/.github/workflows/sanity.yml
+++ b/.github/workflows/sanity.yml
@@ -38,13 +38,6 @@ jobs:
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        python -m pip install build
-        python -m build --wheel
-        pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
    - name: Run license test
      run: |
        python3 tests/sanity/check_license.py --directories .
diff --git a/pyproject.toml b/pyproject.toml
index f6243e0..cee0331 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,9 +73,7 @@ pretty = true
 ignore_missing_imports = true
 explicit_package_bases = true
 follow_imports = "skip"
-
-# Blanket silence
-ignore_errors = true
+ignore_errors = false

 # -------------------------------
 # tool.pytest - pytest config
@@ -85,15 +83,6 @@ filterwarnings = [
     "ignore:.*PyTorch API of nested tensors.*prototype.*:UserWarning",
 ]

-[[tool.mypy.overrides]]
-module = [
-    "transfer_queue.data_system.*",
-    "transfer_queue.utils.utils.*",
-    "transfer_queue.utils.zmq_utils.*",
-    "transfer_queue.utils.serial_utils.*",
-]
-ignore_errors = false
-
 # -------------------------------
 # tool.setuptools - Additional config
 # -------------------------------
@@ -108,11 +97,23 @@ version = {file = "transfer_queue/version/version"}
 dependencies = {file = "requirements.txt"}

 [project.optional-dependencies]
+
+build = [
+    "build"
+]
+
 test = [
     "pytest>=7.0.0",
     "pytest-asyncio>=0.20.0",
+    "flake8",
+    "pytest-mock",
 ]

+yuanrong = [
+    "openyuanrong-datasystem"
+]
+
+
 # If you need to mimic `package_dir={'': '.'}`:
 [tool.setuptools.package-dir]
 "" = "."
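Patch 1 above folds the CI's ad-hoc pip installs into pyproject extras, so a single editable install (`pip install -e ".[test,build,yuanrong]"`) pulls the test, build, and yuanrong dependencies. Because `openyuanrong-datasystem` is now an optional extra, any code path that touches the yuanrong backend must tolerate its absence. A hedged sketch of such a guard follows; the import name `datasystem` is a hypothetical placeholder, not necessarily the module the real package exposes:

    # Hedged sketch, not the project's actual code: probe for the optional
    # yuanrong backend and degrade gracefully when the extra is not installed.
    try:
        import datasystem  # hypothetical import name for openyuanrong-datasystem
        HAS_YUANRONG = True
    except ImportError:
        datasystem = None
        HAS_YUANRONG = False

With a guard like this, a factory can register the yuanrong client only when HAS_YUANRONG is true, while the base install keeps working.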
From 61aadaf35d96a112e55a917b9991d7c2fa3d9c8c Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 26 Jan 2026 17:30:22 +0800
Subject: [PATCH 2/8] fix partial pre-commit

Signed-off-by: 0oshowero0
---
 pyproject.toml                          | 10 +++++++++-
 transfer_queue/controller.py            | 25 +++++++++++++----------
 transfer_queue/metadata.py              |  2 +-
 transfer_queue/storage/managers/base.py |  2 +-
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index cee0331..05918ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,9 @@ pretty = true
 ignore_missing_imports = true
 explicit_package_bases = true
 follow_imports = "skip"
-ignore_errors = false
+
+# Blanket silence
+ignore_errors = true

 # -------------------------------
 # tool.pytest - pytest config
@@ -83,6 +85,12 @@ filterwarnings = [
     "ignore:.*PyTorch API of nested tensors.*prototype.*:UserWarning",
 ]

+[[tool.mypy.overrides]]
+module = [
+    "transfer_queue.*",
+]
+ignore_errors = false
+
 # -------------------------------
 # tool.setuptools - Additional config
 # -------------------------------
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 2b6767f..9dafcfe 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -188,7 +188,7 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
         if not partition_indexes:
             self.partition_to_indexes.pop(partition_id, None)

-    def get_indexes_for_partition(self, partition_id) -> set[int]:
+    def get_indexes_for_partition(self, partition_id) -> list[int]:
        """
        Get all global_indexes for the specified partition.

@@ -196,9 +196,9 @@ def get_indexes_for_partition(self, partition_id) -> set[int]:
            partition_id: Partition ID

        Returns:
-            set: Set of global_indexes for this partition
+            list: List of global_indexes for this partition
        """
-        return self.partition_to_indexes.get(partition_id, set()).copy()
+        return list(self.partition_to_indexes.get(partition_id, set()).copy())


 @dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:

     # Production status tensor - dynamically expandable
     # Values: 0 = not produced, 1 = ready for consumption
-    production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
+    production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)

     # Consumption status per task - task_name -> consumption_tensor
     # Each tensor tracks which samples have been consumed by that task
@@ -260,7 +260,7 @@ def allocated_samples_num(self) -> int:

     # ==================== Dynamic Expansion Methods ====================

-    def ensure_samples_capacity(self, required_samples: int) -> bool:
+    def ensure_samples_capacity(self, required_samples: int):
        """
        Ensure the production status tensor has enough rows for the required samples.
        Dynamically expands if needed using unified minimum expansion size.
@@ -498,7 +498,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
         return partition_global_index, consumption_status

     # ==================== Production Status Interface ====================
-    def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
+    def get_production_status_for_fields(
+        self, field_names: list[str], mask: bool = False
+    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
        """
        Check if all samples for specified fields are fully produced and ready.

@@ -512,12 +514,12 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
            - Production status tensor for the specified task. 1 for ready, 0 for not ready.
        """
        if self.production_status is None or field_names is None or len(field_names) == 0:
-            return False
+            return None, None

        # Check if all requested fields are registered
        for field_name in field_names:
            if field_name not in self.field_name_mapping:
-                return False
+                return None, None

        # Create column mask for requested fields
        col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
@@ -837,7 +839,7 @@ def list_partitions(self) -> list[str]:

     # ==================== Partition Index Management API ====================

-    def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
+    def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
        """
        Get all indexes for a specific partition.

@@ -845,7 +847,7 @@ def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
            partition: Partition identifier

        Returns:
-            Set of indexes allocated to the partition
+            List of indexes allocated to the partition
        """
        return self.index_manager.get_indexes_for_partition(partition)

@@ -980,6 +982,9 @@ def get_metadata(

        if mode == "fetch":
            # Find ready samples within current data partition and package into BatchMeta when reading
+            if batch_size is None:
+                raise ValueError("must provide batch_size in fetch mode")
+
            start_time = time.time()
            while True:
                # ready_for_consume_indexes: samples where all required fields are produced
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index af3f321..003fd38 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
        """Get the entire custom meta dictionary"""
        return copy.deepcopy(self._custom_meta)

-    def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None):
+    def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]):
        """Update custom meta with a new dictionary"""
        if new_custom_meta:
            self._custom_meta.update(new_custom_meta)
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index 942ea68..e9c79e7 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -60,7 +60,7 @@ class TransferQueueStorageManager(ABC):
     def __init__(self, config: dict[str, Any]):
         self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
         self.config = config
-        self.controller_info = config.get("controller_info", None)  # type: ZMQServerInfo
+        self.controller_info = config.get("controller_info")  # type: ZMQServerInfo

         self.data_status_update_socket = None
         self.controller_handshake_socket = None

From ea6b902afe0cef5a5360f320a78761129047bc55 Mon Sep 17 00:00:00 2001
From: tianyi-ge
Date: Mon, 26 Jan 2026 19:22:20 +0800
Subject: [PATCH 3/8] fix type checks

Signed-off-by: tianyi-ge
---
 transfer_queue/client.py                   |  5 ++--
 .../dataloader/streaming_dataset.py        |  1 +
 transfer_queue/storage/clients/base.py     |  9 ++++++
 transfer_queue/storage/clients/factory.py  |  5 ++--
 .../storage/clients/yuanrong_client.py     |  9 ++++++
 transfer_queue/storage/managers/base.py    | 30 ++++++++++++-------
 .../managers/simple_backend_manager.py     |  8 ++---
 7 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 3bcfcc5..8d06c15 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -757,6 +757,7 @@ async def async_get_partition_list(
         )

         try:
+            assert socket is not None
             await socket.send_multipart(request_msg.serialize())
             response_serialized = await socket.recv_multipart()
             response_msg = ZMQMessage.deserialize(response_serialized)
@@ -991,10 +992,10 @@ def process_zmq_server_info(
        >>> info_dict = process_zmq_server_info(handlers)"""
     # Handle single handler object case
     if not isinstance(handlers, dict):
-        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
     else:
         # Handle dictionary case
         server_info = {}
         for name, handler in handlers.items():
-            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
         return server_info
diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py
index 510a497..97cff38 100644
--- a/transfer_queue/dataloader/streaming_dataset.py
+++ b/transfer_queue/dataloader/streaming_dataset.py
@@ -170,6 +170,7 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
         if self._tq_client is None:
             self._create_client()
+        assert self._tq_client is not None, "Failed to create TransferQueue client"

         # TODO: need to consider async scenario where the samples in partition is dynamically increasing
         while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
             try:
diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py
index d84b381..f414a72 100644
--- a/transfer_queue/storage/clients/base.py
+++ b/transfer_queue/storage/clients/base.py
@@ -23,6 +23,15 @@ class TransferQueueStorageKVClient(ABC):
     Subclasses must implement the core methods: put, get, and clear.
     """

+    @abstractmethod
+    def __init__(self, config: dict[str, Any]):
+        """
+        Initialize the storage client with configuration.
+        Args:
+            config (dict[str, Any]): Configuration dictionary for the storage client.
+        """
+        ...
+
     @abstractmethod
     def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
         """
diff --git a/transfer_queue/storage/clients/factory.py b/transfer_queue/storage/clients/factory.py
index 43c5d6d..1d9bda5 100644
--- a/transfer_queue/storage/clients/factory.py
+++ b/transfer_queue/storage/clients/factory.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from transfer_queue.storage.clients.base import TransferQueueStorageKVClient


@@ -23,7 +24,7 @@ class StorageClientFactory:
     """

     # Class variable: maps client names to their corresponding classes
-    _registry: dict[str, TransferQueueStorageKVClient] = {}
+    _registry: dict[str, type[TransferQueueStorageKVClient]] = {}

     @classmethod
     def register(cls, client_type: str):
@@ -35,7 +36,7 @@ def register(cls, client_type: str):
         Callable: The decorator function that returns the original class
         """

-        def decorator(client_class: TransferQueueStorageKVClient) -> TransferQueueStorageKVClient:
+        def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]:
             cls._registry[client_type] = client_class
             return client_class
diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py
index c233472..7d46193 100644
--- a/transfer_queue/storage/clients/yuanrong_client.py
+++ b/transfer_queue/storage/clients/yuanrong_client.py
@@ -191,6 +191,7 @@ def mset_zcopy(self, keys: list[str], objs: list[Any]):
            keys (list[str]): List of string keys under which the objects will be stored.
            objs (list[Any]): List of Python objects to store (e.g., tensors, strings).
        """
+        assert self._cpu_ds_client is not None, "CPU DS client is not available"
        items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs]
        packed_sizes = [calc_packed_size(items) for items in items_list]
        status, buffers = self._cpu_ds_client.mcreate(keys, packed_sizes)
@@ -208,6 +209,7 @@ def mget_zcopy(self, keys: list[str]) -> list[Any]:
        Returns:
            list[Any]: List of deserialized objects corresponding to the input keys.
        """
+        assert self._cpu_ds_client is not None, "CPU DS client is not available"
        status, buffers = self._cpu_ds_client.get_buffers(keys, timeout_ms=500)
        return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers]
@@ -241,6 +243,7 @@ def _batch_put(self, keys: list[str], values: list[Any]):
                cpu_values.append(pickle.dumps(value))

        # put NPU data
+        assert self._npu_ds_client is not None
        for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT):
            batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
            batch_values = npu_values[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -253,6 +256,7 @@ def _batch_put(self, keys: list[str], values: list[Any]):
            self._npu_ds_client.dev_mset(batch_keys, batch_values)

        # put CPU data
+        assert self._cpu_ds_client is not None
        for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT):
            batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
            batch_values = cpu_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
@@ -320,6 +324,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]:
        results = [None] * len(keys)

        # Fetch NPU tensors
+        assert self._npu_ds_client is not None
        for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT):
            batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
            batch_shapes = npu_shapes[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -347,6 +352,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]:
            cpu_indices.extend([batch_indices[j] for j, k in enumerate(batch_keys) if k in failed_set])

        # Fetch CPU/general objects (including NPU fallbacks)
+        assert self._cpu_ds_client is not None
        for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT):
            batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
            batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
@@ -398,6 +404,8 @@ def _batch_clear(self, keys: list[str]):
            keys (List[str]): Keys to delete.
        """
        if self.npu_ds_client_is_available():
+            assert self._npu_ds_client is not None
+            assert self._cpu_ds_client is not None
            # Try to delete all keys via npu client
            for i in range(0, len(keys), NPU_DS_CLIENT_KEYS_LIMIT):
                batch = keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -408,6 +416,7 @@ def _batch_clear(self, keys: list[str]):
                sub_batch = keys[j : j + CPU_DS_CLIENT_KEYS_LIMIT]
                self._cpu_ds_client.delete(sub_batch)
        else:
+            assert self._cpu_ds_client is not None
            for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT):
                batch = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
                self._cpu_ds_client.delete(batch)
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index e9c79e7..dbcc480 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -20,7 +20,7 @@
 import weakref
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Optional
 from uuid import uuid4

 import ray
@@ -60,12 +60,14 @@ class TransferQueueStorageManager(ABC):
     def __init__(self, config: dict[str, Any]):
         self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
         self.config = config
-        self.controller_info = config.get("controller_info")  # type: ZMQServerInfo
+        controller_info = config.get("controller_info")
+        assert controller_info is not None, "controller_info is required"
+        self.controller_info: ZMQServerInfo = controller_info

-        self.data_status_update_socket = None
-        self.controller_handshake_socket = None
+        self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None
+        self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None

-        self.zmq_context = None
+        self.zmq_context: Optional[zmq.Context[Any]] = None

         self._connect_to_controller()

     def _connect_to_controller(self) -> None:
@@ -88,6 +90,7 @@ def _connect_to_controller(self) -> None:
             zmq.DEALER,
             identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(),
         )
+        assert self.data_status_update_socket is not None
         self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket"))

         # do handshake with controller
@@ -106,6 +109,7 @@ def _do_handshake_with_controller(self) -> None:

         # Create zmq poller for handshake confirmation between controller and storage manager
         poller = zmq.Poller()
+        assert self.controller_handshake_socket is not None
         self.controller_handshake_socket.connect(self.controller_info.to_addr("handshake_socket"))
         logger.debug(
             f"[{self.storage_manager_id}]: Handshake connection from storage manager id #{self.storage_manager_id} "
@@ -170,6 +174,7 @@ def _do_handshake_with_controller(self) -> None:

     def _send_handshake_requests(self) -> None:
         """Send handshake request to controller."""
+        assert self.controller_handshake_socket is not None
         request_msg = ZMQMessage.create(
             request_type=ZMQRequestType.HANDSHAKE,
             sender_id=self.storage_manager_id,
@@ -191,7 +196,7 @@ async def notify_data_update(
         global_indexes: list[int],
         dtypes: dict[int, dict[str, Any]],
         shapes: dict[int, dict[str, Any]],
-        custom_meta: dict[int, dict[str, Any]] = None,
+        custom_meta: Optional[dict[int, dict[str, Any]]] = None,
     ) -> None:
         """
         Notify controller that new data is ready.
@@ -213,6 +218,7 @@ async def notify_data_update(
         # Create zmq poller for notifying data update information
         poller = zmq.Poller()
         # Note: data_status_update_socket is already connected during initialization
+        assert self.data_status_update_socket is not None

         try:
             poller.register(self.data_status_update_socket, zmq.POLLIN)
@@ -352,7 +358,7 @@ def __init__(self, config: dict[str, Any]):
             raise ValueError("Missing client_name in config")
         super().__init__(config)
         self.storage_client = StorageClientFactory.create(client_name, config)
-        self._multi_threads_executor = None
+        self._multi_threads_executor: Optional[ThreadPoolExecutor] = None

         # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected.
         self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor)
@@ -390,7 +396,7 @@ def _generate_values(data: TensorDict) -> list[Tensor]:
         return [row_data for field in sorted(data.keys()) for row_data in data[field]]

     @staticmethod
-    def _shutdown_executor(thread_executor: ThreadPoolExecutor):
+    def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None:
         """
         A static method to ensure no strong reference to 'self' is held within the
         finalizer's callback, enabling proper garbage collection.
@@ -421,6 +427,7 @@ def _get_executor(self) -> ThreadPoolExecutor:
                 max_workers=self._num_threads, thread_name_prefix="KVStorageManager"
             )

+        assert self._multi_threads_executor is not None
         return self._multi_threads_executor

     def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Tensor]) -> TensorDict:
@@ -519,6 +526,7 @@ def _get_shape_type_custom_meta_list(metadata: BatchMeta):
         for field_name in sorted(metadata.field_names):
             for index in range(len(metadata)):
                 field = metadata.samples[index].get_field_by_name(field_name)
+                assert field is not None, f"Field {field_name} not found in sample {index}"
                 shapes.append(field.shape)
                 dtypes.append(field.dtype)
                 global_index = metadata.global_indexes[index]
@@ -547,8 +555,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
         loop = asyncio.get_event_loop()
         custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)

-        per_field_dtypes = {}
-        per_field_shapes = {}
+        per_field_dtypes: dict[int, dict[str, Any]] = {}
+        per_field_shapes: dict[int, dict[str, Any]] = {}

         # Initialize the data structure for each global index
         for global_idx in metadata.global_indexes:
@@ -567,7 +575,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
         )

         # Prepare per-field custom_meta if available
-        per_field_custom_meta = {}
+        per_field_custom_meta: dict[int, dict[str, Any]] = {}
         if custom_meta:
             if len(custom_meta) != len(keys):
                 raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})")
diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py
index 92ec897..250438c 100644
--- a/transfer_queue/storage/managers/simple_backend_manager.py
+++ b/transfer_queue/storage/managers/simple_backend_manager.py
@@ -57,7 +57,7 @@ def __init__(self, config: dict[str, Any]):
         super().__init__(config)

         self.config = config
-        server_infos = config.get("storage_unit_infos", None)  # type: ZMQServerInfo | dict[str, ZMQServerInfo]
+        server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None)
         if server_infos is None:
             raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.")

@@ -198,8 +198,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:

         # Gather per-field dtype and shape information for each field
         # global_indexes, local_indexes, and field_data correspond one-to-one
-        per_field_dtypes = {}
-        per_field_shapes = {}
+        per_field_dtypes: dict[int, dict[str, Any]] = {}
+        per_field_shapes: dict[int, dict[str, Any]] = {}

         # Initialize the data structure for each global index
         for global_idx in metadata.global_indexes:
@@ -440,7 +440,7 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict)
         """

         # We use dict here instead of TensorDict to avoid unnecessary TensorDict overhead
-        results = {}
+        results: dict[str, Any] = {}
         batch_indexes = storage_meta_group.get_batch_indexes()

         if not batch_indexes:

From b92295ef21ec8f09febc5e5dd080e3aafc9b5c9f Mon Sep 17 00:00:00 2001
From: tianyi-ge
Date: Mon, 26 Jan 2026 20:45:04 +0800
Subject: [PATCH 4/8] fix unittests: 1. set -> list 2. controller info is required

Signed-off-by: tianyi-ge
---
 tests/test_controller.py         | 12 ++++++------
 tests/test_kv_storage_manager.py | 14 ++++++++++----
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/tests/test_controller.py b/tests/test_controller.py
index 3d1379e..6684526 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -75,7 +75,7 @@ def test_controller_with_single_partition(self, ray_setup):
             ProductionStatus.NOT_PRODUCED
         )
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set(range(gbs * num_n_samples))
+        assert partition_index_range == list(range(gbs * num_n_samples))

         print("✓ Initial get metadata correct")

@@ -194,7 +194,7 @@ def test_controller_with_single_partition(self, ray_setup):
         ray.get(tq_controller.clear_partition.remote(partition_id))
         partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set()
+        assert partition_index_range == []
         assert partition is None

         print("✓ Clear partition correct")
@@ -307,7 +307,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
+        assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range))

         # Update production status
         dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in val_metadata.global_indexes}
@@ -359,11 +359,11 @@ def test_controller_with_multi_partitions(self, ray_setup):

         assert not partition_index_range_1_after_clear
         assert partition_1_after_clear is None
-        assert partition_index_range_1_after_clear == set()
+        assert partition_index_range_1_after_clear == []

         partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
+        assert partition_index_range_2 == [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
         assert torch.all(
             partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
         )
@@ -387,7 +387,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
-        assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
+        assert partition_index_range == list(range(32)) + list(range(48, 80))
         print("✓ Correctly assign partition_3")

     def test_controller_clear_meta(self, ray_setup):
diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py
index f65a360..5cfb6fe 100644
--- a/tests/test_kv_storage_manager.py
+++ b/tests/test_kv_storage_manager.py
@@ -57,7 +57,13 @@ def get_meta(data, global_indexes=None):
 @pytest.fixture
 def test_data():
     """Fixture providing test configuration, data, and metadata."""
-    cfg = {"client_name": "YuanrongStorageClient", "host": "127.0.0.1", "port": 31501, "device_id": 0}
+    cfg = {
+        "controller_info": MagicMock(),
+        "client_name": "YuanrongStorageClient",
+        "host": "127.0.0.1",
+        "port": 31501,
+        "device_id": 0,
+    }

     global_indexes = [8, 9, 10]
     data = TensorDict(
@@ -288,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
     mock_storage_client.put.return_value = mock_custom_meta

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

@@ -338,7 +344,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
     mock_storage_client.put.return_value = None

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

@@ -361,7 +367,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
     mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

From 08c545a6c01be4ca71b045b98e66f05bb777a0db Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Tue, 27 Jan 2026 09:55:52 +0800
Subject: [PATCH 5/8] fix comments

Signed-off-by: 0oshowero0
---
 transfer_queue/controller.py           | 11 +++--------
 transfer_queue/storage/clients/base.py |  3 +--
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 9dafcfe..2db0e7d 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -260,7 +260,7 @@ def allocated_samples_num(self) -> int:

     # ==================== Dynamic Expansion Methods ====================

-    def ensure_samples_capacity(self, required_samples: int):
+    def ensure_samples_capacity(self, required_samples: int) -> None:
        """
        Ensure the production status tensor has enough rows for the required samples.
        Dynamically expands if needed using unified minimum expansion size.
@@ -291,7 +291,7 @@ def ensure_samples_capacity(self, required_samples: int):
                 f"to {new_samples} samples (added {min_expansion} samples)"
             )

-    def ensure_fields_capacity(self, required_fields: int):
+    def ensure_fields_capacity(self, required_fields: int) -> None:
        """
        Ensure the production status tensor has enough columns for the required fields.
        Dynamically expands if needed using unified minimum expansion size.
@@ -299,9 +299,6 @@ def ensure_fields_capacity(self, required_fields: int):
        Args:
            required_fields: Minimum number of fields needed
        """
-        if self.production_status is None:
-            # Will be initialized when samples are added
-            return
        current_fields = self.production_status.shape[1]

        if required_fields > current_fields:
@@ -513,7 +510,7 @@ def get_production_status_for_fields(
            - Partition global index tensor
            - Production status tensor for the specified task. 1 for ready, 0 for not ready.
        """
-        if self.production_status is None or field_names is None or len(field_names) == 0:
+        if field_names is None or len(field_names) == 0:
            return None, None

        # Check if all requested fields are registered
@@ -550,8 +547,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
        Returns:
            List of sample indices that are ready for consumption
        """
-        if self.production_status is None:
-            return []

        # Check if all requested fields are registered
        for field_name in field_names:
diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py
index f414a72..1457b36 100644
--- a/transfer_queue/storage/clients/base.py
+++ b/transfer_queue/storage/clients/base.py
@@ -23,14 +23,13 @@ class TransferQueueStorageKVClient(ABC):
     Subclasses must implement the core methods: put, get, and clear.
     """

-    @abstractmethod
     def __init__(self, config: dict[str, Any]):
         """
         Initialize the storage client with configuration.
         Args:
             config (dict[str, Any]): Configuration dictionary for the storage client.
         """
-        ...
+        self.config = config

     @abstractmethod
     def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:

From f52694a500632c75ba968e25ddcd9004d9e03b0d Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Tue, 27 Jan 2026 09:59:16 +0800
Subject: [PATCH 6/8] add package install test

Signed-off-by: 0oshowero0
---
 .github/workflows/python-package.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 0b42eae..72c0032 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -39,9 +39,10 @@ jobs:
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-    - name: Test Build
+    - name: Build wheel and test installed distribution
      run: |
        python -m build --wheel
+        pip install dist/*.whl --force-reinstall
    - name: Test with pytest
      run: |
        pytest
\ No newline at end of file

From 3952bf385667fb319e38c9a2f9c564ab8496c430 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Tue, 27 Jan 2026 10:07:01 +0800
Subject: [PATCH 7/8] minor fix

Signed-off-by: 0oshowero0
---
 pyproject.toml                                    |  4 ----
 transfer_queue/dataloader/streaming_dataset.py    |  1 +
 transfer_queue/storage/clients/yuanrong_client.py | 14 +++++++-------
 transfer_queue/storage/managers/base.py           |  8 ++++----
 4 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 05918ee..f853970 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,23 +105,19 @@ version = {file = "transfer_queue/version/version"}
 dependencies = {file = "requirements.txt"}

 [project.optional-dependencies]
-
 build = [
     "build"
 ]
-
 test = [
     "pytest>=7.0.0",
     "pytest-asyncio>=0.20.0",
     "flake8",
     "pytest-mock",
 ]
-
 yuanrong = [
     "openyuanrong-datasystem"
 ]
-

 # If you need to mimic `package_dir={'': '.'}`:
 [tool.setuptools.package-dir]
 "" = "."
diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py
index 97cff38..77d7186 100644
--- a/transfer_queue/dataloader/streaming_dataset.py
+++ b/transfer_queue/dataloader/streaming_dataset.py
@@ -171,6 +171,7 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
             self._create_client()
         assert self._tq_client is not None, "Failed to create TransferQueue client"
+
         # TODO: need to consider async scenario where the samples in partition is dynamically increasing
         while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
             try:
diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py
index 7d46193..08d35cc 100644
--- a/transfer_queue/storage/clients/yuanrong_client.py
+++ b/transfer_queue/storage/clients/yuanrong_client.py
@@ -243,7 +243,7 @@ def _batch_put(self, keys: list[str], values: list[Any]):
             cpu_values.append(pickle.dumps(value))

         # put NPU data
-        assert self._npu_ds_client is not None
+        assert self._npu_ds_client is not None, "NPU DS client is not available"
         for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT):
             batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
             batch_values = npu_values[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -256,7 +256,7 @@ def _batch_put(self, keys: list[str], values: list[Any]):
             self._npu_ds_client.dev_mset(batch_keys, batch_values)

         # put CPU data
-        assert self._cpu_ds_client is not None
+        assert self._cpu_ds_client is not None, "CPU DS client is not available"
         for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT):
             batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
             batch_values = cpu_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
@@ -324,7 +324,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]:
         results = [None] * len(keys)

         # Fetch NPU tensors
-        assert self._npu_ds_client is not None
+        assert self._npu_ds_client is not None, "NPU DS client is not available"
         for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT):
             batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
             batch_shapes = npu_shapes[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -352,7 +352,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]:
             cpu_indices.extend([batch_indices[j] for j, k in enumerate(batch_keys) if k in failed_set])

         # Fetch CPU/general objects (including NPU fallbacks)
-        assert self._cpu_ds_client is not None
+        assert self._cpu_ds_client is not None, "CPU DS client is not available"
         for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT):
             batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
             batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
@@ -404,8 +404,8 @@ def _batch_clear(self, keys: list[str]):
             keys (List[str]): Keys to delete.
         """
         if self.npu_ds_client_is_available():
-            assert self._npu_ds_client is not None
-            assert self._cpu_ds_client is not None
+            assert self._npu_ds_client is not None, "NPU DS client is not available"
+            assert self._cpu_ds_client is not None, "CPU DS client is not available"
             # Try to delete all keys via npu client
             for i in range(0, len(keys), NPU_DS_CLIENT_KEYS_LIMIT):
                 batch = keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT]
@@ -416,7 +416,7 @@ def _batch_clear(self, keys: list[str]):
                 sub_batch = keys[j : j + CPU_DS_CLIENT_KEYS_LIMIT]
                 self._cpu_ds_client.delete(sub_batch)
         else:
-            assert self._cpu_ds_client is not None
+            assert self._cpu_ds_client is not None, "CPU DS client is not available"
             for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT):
                 batch = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
                 self._cpu_ds_client.delete(batch)
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index dbcc480..94854ce 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -90,7 +90,7 @@ def _connect_to_controller(self) -> None:
             zmq.DEALER,
             identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(),
         )
-        assert self.data_status_update_socket is not None
+        assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized"
         self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket"))

         # do handshake with controller
@@ -109,7 +109,7 @@ def _do_handshake_with_controller(self) -> None:

         # Create zmq poller for handshake confirmation between controller and storage manager
         poller = zmq.Poller()
-        assert self.controller_handshake_socket is not None
+        assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized"
         self.controller_handshake_socket.connect(self.controller_info.to_addr("handshake_socket"))
         logger.debug(
             f"[{self.storage_manager_id}]: Handshake connection from storage manager id #{self.storage_manager_id} "
@@ -174,7 +174,7 @@ def _do_handshake_with_controller(self) -> None:

     def _send_handshake_requests(self) -> None:
         """Send handshake request to controller."""
-        assert self.controller_handshake_socket is not None
+        assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized"
         request_msg = ZMQMessage.create(
             request_type=ZMQRequestType.HANDSHAKE,
             sender_id=self.storage_manager_id,
@@ -218,7 +218,7 @@ async def notify_data_update(
         # Create zmq poller for notifying data update information
         poller = zmq.Poller()
         # Note: data_status_update_socket is already connected during initialization
-        assert self.data_status_update_socket is not None
+        assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized"

         try:
             poller.register(self.data_status_update_socket, zmq.POLLIN)

From ea540a85fac2700c8859ce06cf64a273454fe982 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Tue, 27 Jan 2026 10:26:19 +0800
Subject: [PATCH 8/8] del redundant check

Signed-off-by: 0oshowero0
---
 transfer_queue/controller.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 2db0e7d..eda9267 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -251,12 +251,12 @@ def total_fields_num(self) -> int:

     @property
     def allocated_fields_num(self) -> int:
         """Current number of allocated columns in the tensor."""
-        return self.production_status.shape[1] if self.production_status is not None else 0
+        return self.production_status.shape[1]

     @property
     def allocated_samples_num(self) -> int:
         """Current number of allocated rows in the tensor."""
-        return self.production_status.shape[0] if self.production_status is not None else 0
+        return self.production_status.shape[0]

     # ==================== Dynamic Expansion Methods ====================
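Taken together, patches 3, 5, and 7 converge on one idiom: attributes that start life as None are annotated Optional[T], helpers that returned sets now return list[int], and every use site narrows the Optional with a messaged assert so mypy — which patch 2 re-enables with ignore_errors = false over transfer_queue.* — accepts the call that follows. A minimal, self-contained sketch of that narrowing pattern; the class and attribute names are illustrative stand-ins, not the project's real API:

    from typing import Optional

    class LazyChannel:
        """Illustrative stand-in for a manager whose socket is created after __init__."""

        def __init__(self) -> None:
            self.socket: Optional[list[bytes]] = None  # created lazily, hence Optional

        def connect(self) -> None:
            self.socket = []  # placeholder for the real socket construction

        def send(self, frame: bytes) -> None:
            # mypy sees Optional[list[bytes]] here; the assert narrows it to
            # list[bytes] before use, exactly as the asserts added above do.
            assert self.socket is not None, "socket is not properly initialized"
            self.socket.append(frame)

    channel = LazyChannel()
    channel.connect()
    channel.send(b"ready")

The assert carries a human-readable message (patch 7's refinement of patch 3), so a mis-ordered call fails fast with a diagnostic instead of an AttributeError on None.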