diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index bd74cff..72c0032 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -31,16 +31,18 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        python -m pip install flake8 pytest build pytest_asyncio pytest-mock openyuanrong-datasystem
-        python -m build --wheel
         pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
+        pip install -e ".[test,build,yuanrong]"
     - name: Lint with flake8
       run: |
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Build wheel and test installed distribution
+      run: |
+        python -m build --wheel
+        pip install dist/*.whl --force-reinstall
     - name: Test with pytest
       run: |
         pytest
\ No newline at end of file
diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml
index 3c689b2..29e7375 100644
--- a/.github/workflows/sanity.yml
+++ b/.github/workflows/sanity.yml
@@ -38,13 +38,6 @@ jobs:
       uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        python -m pip install build
-        python -m build --wheel
-        pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
     - name: Run license test
       run: |
         python3 tests/sanity/check_license.py --directories .
diff --git a/pyproject.toml b/pyproject.toml
index f6243e0..f853970 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,10 +87,7 @@ filterwarnings = [

 [[tool.mypy.overrides]]
 module = [
-    "transfer_queue.data_system.*",
-    "transfer_queue.utils.utils.*",
-    "transfer_queue.utils.zmq_utils.*",
-    "transfer_queue.utils.serial_utils.*",
+    "transfer_queue.*",
 ]
 ignore_errors = false
@@ -108,9 +105,17 @@ version = {file = "transfer_queue/version/version"}
 dependencies = {file = "requirements.txt"}

 [project.optional-dependencies]
+build = [
+    "build"
+]
 test = [
     "pytest>=7.0.0",
     "pytest-asyncio>=0.20.0",
+    "flake8",
+    "pytest-mock",
+]
+yuanrong = [
+    "openyuanrong-datasystem"
 ]

 # If you need to mimic `package_dir={'': '.'}`:
diff --git a/tests/test_controller.py b/tests/test_controller.py
index 3d1379e..6684526 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -75,7 +75,7 @@ def test_controller_with_single_partition(self, ray_setup):
             ProductionStatus.NOT_PRODUCED
         )
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set(range(gbs * num_n_samples))
+        assert partition_index_range == list(range(gbs * num_n_samples))

         print("✓ Initial get metadata correct")
@@ -194,7 +194,7 @@ def test_controller_with_single_partition(self, ray_setup):
         ray.get(tq_controller.clear_partition.remote(partition_id))
         partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set()
+        assert partition_index_range == []
         assert partition is None

         print("✓ Clear partition correct")
@@ -307,7 +307,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
+        assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range))

         # Update production status
         dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in val_metadata.global_indexes}
@@ -359,11 +359,11 @@ def test_controller_with_multi_partitions(self, ray_setup):
         assert not partition_index_range_1_after_clear
         assert partition_1_after_clear is None
-        assert partition_index_range_1_after_clear == set()
+        assert partition_index_range_1_after_clear == []

         partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
+        assert partition_index_range_2 == [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
         assert torch.all(
             partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
         )
@@ -387,7 +387,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
-        assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
+        assert partition_index_range == list(range(32)) + list(range(48, 80))
         print("✓ Correctly assign partition_3")

     def test_controller_clear_meta(self, ray_setup):
diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py
index f65a360..5cfb6fe 100644
--- a/tests/test_kv_storage_manager.py
+++ b/tests/test_kv_storage_manager.py
@@ -57,7 +57,13 @@ def get_meta(data, global_indexes=None):
 @pytest.fixture
 def test_data():
     """Fixture providing test configuration, data, and metadata."""
-    cfg = {"client_name": "YuanrongStorageClient", "host": "127.0.0.1", "port": 31501, "device_id": 0}
+    cfg = {
+        "controller_info": MagicMock(),
+        "client_name": "YuanrongStorageClient",
+        "host": "127.0.0.1",
+        "port": 31501,
+        "device_id": 0,
+    }

     global_indexes = [8, 9, 10]
     data = TensorDict(
@@ -288,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
     mock_storage_client.put.return_value = mock_custom_meta

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}

     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)
@@ -338,7 +344,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
     mock_storage_client.put.return_value = None

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}

     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)
@@ -361,7 +367,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
     mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}

     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 3bcfcc5..8d06c15 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -757,6 +757,7 @@ async def async_get_partition_list(
         )

         try:
+            assert socket is not None
             await socket.send_multipart(request_msg.serialize())
             response_serialized = await socket.recv_multipart()
             response_msg = ZMQMessage.deserialize(response_serialized)
@@ -991,10 +992,10 @@ def process_zmq_server_info(
         >>> info_dict = process_zmq_server_info(handlers)"""
     # Handle single handler object case
     if not isinstance(handlers, dict):
-        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
     else:
         # Handle dictionary case
         server_info = {}
         for name, handler in handlers.items():
-            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
         return server_info
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 2b6767f..eda9267 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -188,7 +188,7 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
         if not partition_indexes:
             self.partition_to_indexes.pop(partition_id, None)

-    def get_indexes_for_partition(self, partition_id) -> set[int]:
+    def get_indexes_for_partition(self, partition_id) -> list[int]:
         """
         Get all global_indexes for the specified partition.
@@ -196,9 +196,9 @@ def get_indexes_for_partition(self, partition_id) -> set[int]:
             partition_id: Partition ID

         Returns:
-            set: Set of global_indexes for this partition
+            list: List of global_indexes for this partition
         """
-        return self.partition_to_indexes.get(partition_id, set()).copy()
+        return list(self.partition_to_indexes.get(partition_id, set()).copy())


 @dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:

     # Production status tensor - dynamically expandable
     # Values: 0 = not produced, 1 = ready for consumption
-    production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
+    production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)

     # Consumption status per task - task_name -> consumption_tensor
     # Each tensor tracks which samples have been consumed by that task
@@ -251,16 +251,16 @@ def total_fields_num(self) -> int:
     @property
     def allocated_fields_num(self) -> int:
         """Current number of allocated columns in the tensor."""
-        return self.production_status.shape[1] if self.production_status is not None else 0
+        return self.production_status.shape[1]

     @property
     def allocated_samples_num(self) -> int:
         """Current number of allocated rows in the tensor."""
-        return self.production_status.shape[0] if self.production_status is not None else 0
+        return self.production_status.shape[0]

     # ==================== Dynamic Expansion Methods ====================

-    def ensure_samples_capacity(self, required_samples: int) -> bool:
+    def ensure_samples_capacity(self, required_samples: int) -> None:
         """
         Ensure the production status tensor has enough rows for the required samples.
         Dynamically expands if needed using unified minimum expansion size.
@@ -291,7 +291,7 @@ def ensure_samples_capacity(self, required_samples: int) -> bool:
                 f"to {new_samples} samples (added {min_expansion} samples)"
             )

-    def ensure_fields_capacity(self, required_fields: int):
+    def ensure_fields_capacity(self, required_fields: int) -> None:
         """
         Ensure the production status tensor has enough columns for the required fields.
         Dynamically expands if needed using unified minimum expansion size.
@@ -299,9 +299,6 @@ def ensure_fields_capacity(self, required_fields: int):
         Args:
             required_fields: Minimum number of fields needed
         """
-        if self.production_status is None:
-            # Will be initialized when samples are added
-            return

         current_fields = self.production_status.shape[1]
         if required_fields > current_fields:
@@ -498,7 +495,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
         return partition_global_index, consumption_status

     # ==================== Production Status Interface ====================
-    def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
+    def get_production_status_for_fields(
+        self, field_names: list[str], mask: bool = False
+    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
         """
         Check if all samples for specified fields are fully produced and ready.

@@ -511,13 +510,13 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
             - Partition global index tensor
             - Production status tensor for the specified task. 1 for ready, 0 for not ready.
""" - if self.production_status is None or field_names is None or len(field_names) == 0: - return False + if field_names is None or len(field_names) == 0: + return None, None # Check if all requested fields are registered for field_name in field_names: if field_name not in self.field_name_mapping: - return False + return None, None # Create column mask for requested fields col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool) @@ -548,8 +547,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: Returns: List of sample indices that are ready for consumption """ - if self.production_status is None: - return [] # Check if all requested fields are registered for field_name in field_names: @@ -837,7 +834,7 @@ def list_partitions(self) -> list[str]: # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition: DataPartitionStatus) -> set: + def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]: """ Get all indexes for a specific partition. @@ -845,7 +842,7 @@ def get_partition_index_range(self, partition: DataPartitionStatus) -> set: partition: Partition identifier Returns: - Set of indexes allocated to the partition + List of indexes allocated to the partition """ return self.index_manager.get_indexes_for_partition(partition) @@ -980,6 +977,9 @@ def get_metadata( if mode == "fetch": # Find ready samples within current data partition and package into BatchMeta when reading + if batch_size is None: + raise ValueError("must provide batch_size in fetch mode") + start_time = time.time() while True: # ready_for_consume_indexes: samples where all required fields are produced diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 510a497..77d7186 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -170,6 +170,8 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: if self._tq_client is None: self._create_client() + assert self._tq_client is not None, "Failed to create TransferQueue client" + # TODO: need to consider async scenario where the samples in partition is dynamically increasing while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): try: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index af3f321..003fd38 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """Get the entire custom meta dictionary""" return copy.deepcopy(self._custom_meta) - def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None): + def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]): """Update custom meta with a new dictionary""" if new_custom_meta: self._custom_meta.update(new_custom_meta) diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index d84b381..1457b36 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -23,6 +23,14 @@ class TransferQueueStorageKVClient(ABC): Subclasses must implement the core methods: put, get, and clear. """ + def __init__(self, config: dict[str, Any]): + """ + Initialize the storage client with configuration. + Args: + config (dict[str, Any]): Configuration dictionary for the storage client. 
+ """ + self.config = config + @abstractmethod def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """ diff --git a/transfer_queue/storage/clients/factory.py b/transfer_queue/storage/clients/factory.py index 43c5d6d..1d9bda5 100644 --- a/transfer_queue/storage/clients/factory.py +++ b/transfer_queue/storage/clients/factory.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from transfer_queue.storage.clients.base import TransferQueueStorageKVClient @@ -23,7 +24,7 @@ class StorageClientFactory: """ # Class variable: maps client names to their corresponding classes - _registry: dict[str, TransferQueueStorageKVClient] = {} + _registry: dict[str, type[TransferQueueStorageKVClient]] = {} @classmethod def register(cls, client_type: str): @@ -35,7 +36,7 @@ def register(cls, client_type: str): Callable: The decorator function that returns the original class """ - def decorator(client_class: TransferQueueStorageKVClient) -> TransferQueueStorageKVClient: + def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]: cls._registry[client_type] = client_class return client_class diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index c233472..08d35cc 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -191,6 +191,7 @@ def mset_zcopy(self, keys: list[str], objs: list[Any]): keys (list[str]): List of string keys under which the objects will be stored. objs (list[Any]): List of Python objects to store (e.g., tensors, strings). """ + assert self._cpu_ds_client is not None, "CPU DS client is not available" items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs] packed_sizes = [calc_packed_size(items) for items in items_list] status, buffers = self._cpu_ds_client.mcreate(keys, packed_sizes) @@ -208,6 +209,7 @@ def mget_zcopy(self, keys: list[str]) -> list[Any]: Returns: list[Any]: List of deserialized objects corresponding to the input keys. 
""" + assert self._cpu_ds_client is not None, "CPU DS client is not available" status, buffers = self._cpu_ds_client.get_buffers(keys, timeout_ms=500) return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers] @@ -241,6 +243,7 @@ def _batch_put(self, keys: list[str], values: list[Any]): cpu_values.append(pickle.dumps(value)) # put NPU data + assert self._npu_ds_client is not None, "NPU DS client is not available" for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT): batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] batch_values = npu_values[i : i + NPU_DS_CLIENT_KEYS_LIMIT] @@ -253,6 +256,7 @@ def _batch_put(self, keys: list[str], values: list[Any]): self._npu_ds_client.dev_mset(batch_keys, batch_values) # put CPU data + assert self._cpu_ds_client is not None, "CPU DS client is not available" for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT): batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] batch_values = cpu_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT] @@ -320,6 +324,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: results = [None] * len(keys) # Fetch NPU tensors + assert self._npu_ds_client is not None, "NPU DS client is not available" for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT): batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] batch_shapes = npu_shapes[i : i + NPU_DS_CLIENT_KEYS_LIMIT] @@ -347,6 +352,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: cpu_indices.extend([batch_indices[j] for j, k in enumerate(batch_keys) if k in failed_set]) # Fetch CPU/general objects (including NPU fallbacks) + assert self._cpu_ds_client is not None, "CPU DS client is not available" for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT): batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT] @@ -398,6 +404,8 @@ def _batch_clear(self, keys: list[str]): keys (List[str]): Keys to delete. 
""" if self.npu_ds_client_is_available(): + assert self._npu_ds_client is not None, "NPU DS client is not available" + assert self._cpu_ds_client is not None, "CPU DS client is not available" # Try to delete all keys via npu client for i in range(0, len(keys), NPU_DS_CLIENT_KEYS_LIMIT): batch = keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] @@ -408,6 +416,7 @@ def _batch_clear(self, keys: list[str]): sub_batch = keys[j : j + CPU_DS_CLIENT_KEYS_LIMIT] self._cpu_ds_client.delete(sub_batch) else: + assert self._cpu_ds_client is not None, "CPU DS client is not available" for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT): batch = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] self._cpu_ds_client.delete(batch) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 942ea68..94854ce 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -20,7 +20,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, Optional from uuid import uuid4 import ray @@ -60,12 +60,14 @@ class TransferQueueStorageManager(ABC): def __init__(self, config: dict[str, Any]): self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}" self.config = config - self.controller_info = config.get("controller_info", None) # type: ZMQServerInfo + controller_info = config.get("controller_info") + assert controller_info is not None, "controller_info is required" + self.controller_info: ZMQServerInfo = controller_info - self.data_status_update_socket = None - self.controller_handshake_socket = None + self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None + self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None - self.zmq_context = None + self.zmq_context: Optional[zmq.Context[Any]] = None self._connect_to_controller() def _connect_to_controller(self) -> None: @@ -88,6 +90,7 @@ def _connect_to_controller(self) -> None: zmq.DEALER, identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(), ) + assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket")) # do handshake with controller @@ -106,6 +109,7 @@ def _do_handshake_with_controller(self) -> None: # Create zmq poller for handshake confirmation between controller and storage manager poller = zmq.Poller() + assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized" self.controller_handshake_socket.connect(self.controller_info.to_addr("handshake_socket")) logger.debug( f"[{self.storage_manager_id}]: Handshake connection from storage manager id #{self.storage_manager_id} " @@ -170,6 +174,7 @@ def _do_handshake_with_controller(self) -> None: def _send_handshake_requests(self) -> None: """Send handshake request to controller.""" + assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized" request_msg = ZMQMessage.create( request_type=ZMQRequestType.HANDSHAKE, sender_id=self.storage_manager_id, @@ -191,7 +196,7 @@ async def notify_data_update( global_indexes: list[int], dtypes: dict[int, dict[str, Any]], shapes: dict[int, dict[str, Any]], - custom_meta: dict[int, dict[str, Any]] = None, + custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> None: """ Notify controller that new data is ready. 
@@ -213,6 +218,7 @@ async def notify_data_update(
         # Create zmq poller for notifying data update information
         poller = zmq.Poller()
         # Note: data_status_update_socket is already connected during initialization
+        assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized"

         try:
             poller.register(self.data_status_update_socket, zmq.POLLIN)
@@ -352,7 +358,7 @@ def __init__(self, config: dict[str, Any]):
             raise ValueError("Missing client_name in config")
         super().__init__(config)
         self.storage_client = StorageClientFactory.create(client_name, config)
-        self._multi_threads_executor = None
+        self._multi_threads_executor: Optional[ThreadPoolExecutor] = None

         # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected.
         self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor)
@@ -390,7 +396,7 @@ def _generate_values(data: TensorDict) -> list[Tensor]:
         return [row_data for field in sorted(data.keys()) for row_data in data[field]]

     @staticmethod
-    def _shutdown_executor(thread_executor: ThreadPoolExecutor):
+    def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None:
         """
         A static method to ensure no strong reference to 'self' is held within the finalizer's callback,
         enabling proper garbage collection.
@@ -421,6 +427,7 @@ def _get_executor(self) -> ThreadPoolExecutor:
                 max_workers=self._num_threads, thread_name_prefix="KVStorageManager"
             )

+        assert self._multi_threads_executor is not None
         return self._multi_threads_executor

     def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Tensor]) -> TensorDict:
@@ -519,6 +526,7 @@ def _get_shape_type_custom_meta_list(metadata: BatchMeta):
         for field_name in sorted(metadata.field_names):
             for index in range(len(metadata)):
                 field = metadata.samples[index].get_field_by_name(field_name)
+                assert field is not None, f"Field {field_name} not found in sample {index}"
                 shapes.append(field.shape)
                 dtypes.append(field.dtype)
                 global_index = metadata.global_indexes[index]
@@ -547,8 +555,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
         loop = asyncio.get_event_loop()
         custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)

-        per_field_dtypes = {}
-        per_field_shapes = {}
+        per_field_dtypes: dict[int, dict[str, Any]] = {}
+        per_field_shapes: dict[int, dict[str, Any]] = {}

         # Initialize the data structure for each global index
         for global_idx in metadata.global_indexes:
@@ -567,7 +575,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
         )

         # Prepare per-field custom_meta if available
-        per_field_custom_meta = {}
+        per_field_custom_meta: dict[int, dict[str, Any]] = {}
         if custom_meta:
             if len(custom_meta) != len(keys):
                 raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})")
diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py
index 92ec897..250438c 100644
--- a/transfer_queue/storage/managers/simple_backend_manager.py
+++ b/transfer_queue/storage/managers/simple_backend_manager.py
@@ -57,7 +57,7 @@ def __init__(self, config: dict[str, Any]):
         super().__init__(config)
         self.config = config

-        server_infos = config.get("storage_unit_infos", None)  # type: ZMQServerInfo | dict[str, ZMQServerInfo]
+        server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None)

         if server_infos is None:
             raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.")
@@ -198,8 +198,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:

         # Gather per-field dtype and shape information for each field
         # global_indexes, local_indexes, and field_data correspond one-to-one
-        per_field_dtypes = {}
-        per_field_shapes = {}
+        per_field_dtypes: dict[int, dict[str, Any]] = {}
+        per_field_shapes: dict[int, dict[str, Any]] = {}

         # Initialize the data structure for each global index
         for global_idx in metadata.global_indexes:
@@ -440,7 +440,7 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict)
         """

         # We use dict here instead of TensorDict to avoid unnecessary TensorDict overhead
-        results = {}
+        results: dict[str, Any] = {}

         batch_indexes = storage_meta_group.get_batch_indexes()

         if not batch_indexes: