Merged
8 changes: 5 additions & 3 deletions .github/workflows/python-package.yml
@@ -31,16 +31,18 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        python -m pip install flake8 pytest build pytest_asyncio pytest-mock openyuanrong-datasystem
-        python -m build --wheel
         pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
+        pip install -e ".[test,build,yuanrong]"
     - name: Lint with flake8
       run: |
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Build wheel and test installed distribution
+      run: |
+        python -m build --wheel
+        pip install dist/*.whl --force-reinstall
     - name: Test with pytest
       run: |
         pytest
7 changes: 0 additions & 7 deletions .github/workflows/sanity.yml
@@ -38,13 +38,6 @@ jobs:
       uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        python -m pip install build
-        python -m build --wheel
-        pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-        pip install dist/*.whl
     - name: Run license test
       run: |
         python3 tests/sanity/check_license.py --directories .
13 changes: 9 additions & 4 deletions pyproject.toml
@@ -87,10 +87,7 @@ filterwarnings = [

 [[tool.mypy.overrides]]
 module = [
-    "transfer_queue.data_system.*",
-    "transfer_queue.utils.utils.*",
-    "transfer_queue.utils.zmq_utils.*",
-    "transfer_queue.utils.serial_utils.*",
+    "transfer_queue.*",
 ]
 ignore_errors = false

@@ -108,9 +105,17 @@ version = {file = "transfer_queue/version/version"}
 dependencies = {file = "requirements.txt"}

 [project.optional-dependencies]
+build = [
+    "build"
+]
 test = [
     "pytest>=7.0.0",
     "pytest-asyncio>=0.20.0",
+    "flake8",
+    "pytest-mock",
 ]
+yuanrong = [
+    "openyuanrong-datasystem"
+]

 # If you need to mimic `package_dir={'': '.'}`:
12 changes: 6 additions & 6 deletions tests/test_controller.py
@@ -75,7 +75,7 @@ def test_controller_with_single_partition(self, ray_setup):
             ProductionStatus.NOT_PRODUCED
         )
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set(range(gbs * num_n_samples))
+        assert partition_index_range == list(range(gbs * num_n_samples))

         print("✓ Initial get metadata correct")

@@ -194,7 +194,7 @@ def test_controller_with_single_partition(self, ray_setup):
         ray.get(tq_controller.clear_partition.remote(partition_id))
         partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
-        assert partition_index_range == set()
+        assert partition_index_range == []
         assert partition is None
         print("✓ Clear partition correct")

@@ -307,7 +307,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
+        assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range))

         # Update production status
         dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in val_metadata.global_indexes}
@@ -359,11 +359,11 @@ def test_controller_with_multi_partitions(self, ray_setup):

         assert not partition_index_range_1_after_clear
         assert partition_1_after_clear is None
-        assert partition_index_range_1_after_clear == set()
+        assert partition_index_range_1_after_clear == []

         partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
-        assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
+        assert partition_index_range_2 == [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
         assert torch.all(
             partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
         )
@@ -387,7 +387,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
             [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
         ) == int(ProductionStatus.NOT_PRODUCED)
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
-        assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
+        assert partition_index_range == list(range(32)) + list(range(48, 80))
         print("✓ Correctly assign partition_3")

     def test_controller_clear_meta(self, ray_setup):
14 changes: 10 additions & 4 deletions tests/test_kv_storage_manager.py
@@ -57,7 +57,13 @@ def get_meta(data, global_indexes=None):
 @pytest.fixture
 def test_data():
     """Fixture providing test configuration, data, and metadata."""
-    cfg = {"client_name": "YuanrongStorageClient", "host": "127.0.0.1", "port": 31501, "device_id": 0}
+    cfg = {
+        "controller_info": MagicMock(),
+        "client_name": "YuanrongStorageClient",
+        "host": "127.0.0.1",
+        "port": 31501,
+        "device_id": 0,
+    }
     global_indexes = [8, 9, 10]

     data = TensorDict(
@@ -288,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
     mock_storage_client.put.return_value = mock_custom_meta

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

@@ -338,7 +344,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
     mock_storage_client.put.return_value = None

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

@@ -361,7 +367,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
     mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]

     # Create manager with mocked dependencies
-    config = {"client_name": "MockClient"}
+    config = {"controller_info": MagicMock(), "client_name": "MockClient"}
     with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
         manager = KVStorageManager(config)

5 changes: 3 additions & 2 deletions transfer_queue/client.py
@@ -757,6 +757,7 @@ async def async_get_partition_list(
         )

         try:
+            assert socket is not None
             await socket.send_multipart(request_msg.serialize())
             response_serialized = await socket.recv_multipart()
             response_msg = ZMQMessage.deserialize(response_serialized)
@@ -991,10 +992,10 @@ def process_zmq_server_info(
     >>> info_dict = process_zmq_server_info(handlers)"""
     # Handle single handler object case
     if not isinstance(handlers, dict):
-        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
     else:
         # Handle dictionary case
         server_info = {}
         for name, handler in handlers.items():
-            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
+            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]
         return server_info
38 changes: 19 additions & 19 deletions transfer_queue/controller.py
@@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
         if not partition_indexes:
             self.partition_to_indexes.pop(partition_id, None)

-    def get_indexes_for_partition(self, partition_id) -> set[int]:
+    def get_indexes_for_partition(self, partition_id) -> list[int]:
         """
         Get all global_indexes for the specified partition.

         Args:
             partition_id: Partition ID

         Returns:
-            set: Set of global_indexes for this partition
+            list: List of global_indexes for this partition
         """
-        return self.partition_to_indexes.get(partition_id, set()).copy()
+        return list(self.partition_to_indexes.get(partition_id, set()).copy())


 @dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:

     # Production status tensor - dynamically expandable
     # Values: 0 = not produced, 1 = ready for consumption
-    production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
+    production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)

     # Consumption status per task - task_name -> consumption_tensor
     # Each tensor tracks which samples have been consumed by that task
@@ -251,16 +251,16 @@ def total_fields_num(self) -> int:
     @property
     def allocated_fields_num(self) -> int:
         """Current number of allocated columns in the tensor."""
-        return self.production_status.shape[1] if self.production_status is not None else 0
+        return self.production_status.shape[1]

     @property
     def allocated_samples_num(self) -> int:
         """Current number of allocated rows in the tensor."""
-        return self.production_status.shape[0] if self.production_status is not None else 0
+        return self.production_status.shape[0]

     # ==================== Dynamic Expansion Methods ====================

-    def ensure_samples_capacity(self, required_samples: int) -> bool:
+    def ensure_samples_capacity(self, required_samples: int) -> None:
         """
         Ensure the production status tensor has enough rows for the required samples.
         Dynamically expands if needed using unified minimum expansion size.
@@ -291,17 +291,14 @@ def ensure_samples_capacity(self, required_samples: int) -> bool:
f"to {new_samples} samples (added {min_expansion} samples)"
)

def ensure_fields_capacity(self, required_fields: int):
def ensure_fields_capacity(self, required_fields: int) -> None:
"""
Ensure the production status tensor has enough columns for the required fields.
Dynamically expands if needed using unified minimum expansion size.

Args:
required_fields: Minimum number of fields needed
"""
if self.production_status is None:
# Will be initialized when samples are added
return

current_fields = self.production_status.shape[1]
if required_fields > current_fields:
Expand Down Expand Up @@ -498,7 +495,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
return partition_global_index, consumption_status

# ==================== Production Status Interface ====================
def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
def get_production_status_for_fields(
self, field_names: list[str], mask: bool = False
) -> tuple[Optional[Tensor], Optional[Tensor]]:
"""
Check if all samples for specified fields are fully produced and ready.

@@ -511,13 +510,13 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
             - Partition global index tensor
             - Production status tensor for the specified task. 1 for ready, 0 for not ready.
         """
-        if self.production_status is None or field_names is None or len(field_names) == 0:
-            return False
+        if field_names is None or len(field_names) == 0:
+            return None, None

         # Check if all requested fields are registered
         for field_name in field_names:
             if field_name not in self.field_name_mapping:
-                return False
+                return None, None

         # Create column mask for requested fields
         col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
@@ -548,8 +547,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
         Returns:
             List of sample indices that are ready for consumption
         """
-        if self.production_status is None:
-            return []

         # Check if all requested fields are registered
         for field_name in field_names:
@@ -837,15 +834,15 @@ def list_partitions(self) -> list[str]:

     # ==================== Partition Index Management API ====================

-    def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
+    def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
         """
         Get all indexes for a specific partition.

         Args:
             partition: Partition identifier

         Returns:
-            Set of indexes allocated to the partition
+            List of indexes allocated to the partition
         """
         return self.index_manager.get_indexes_for_partition(partition)

@@ -980,6 +977,9 @@ def get_metadata(
if mode == "fetch":
# Find ready samples within current data partition and package into BatchMeta when reading

if batch_size is None:
raise ValueError("must provide batch_size in fetch mode")

start_time = time.time()
while True:
# ready_for_consume_indexes: samples where all required fields are produced
Expand Down
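Note the contract change earlier in this file: `get_production_status_for_fields` now returns `(None, None)` instead of `False` for unknown or empty field lists, so callers unpack the tuple before branching. A hypothetical caller-side sketch (the variable names and the `status == 1` readiness check are assumptions based on the docstring, not code from this PR):

# `partition` is a DataPartitionStatus instance
indexes, status = partition.get_production_status_for_fields(["prompt_ids", "attention_mask"])
if indexes is None:
    ready_indexes = []  # fields not registered yet (or empty request): nothing to consume
else:
    ready_indexes = indexes[status == 1].tolist()  # 1 marks fully produced samples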
2 changes: 2 additions & 0 deletions transfer_queue/dataloader/streaming_dataset.py
@@ -170,6 +170,8 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
         if self._tq_client is None:
             self._create_client()

+        assert self._tq_client is not None, "Failed to create TransferQueue client"
+
         # TODO: need to consider async scenario where the samples in partition is dynamically increasing
         while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
             try:
2 changes: 1 addition & 1 deletion transfer_queue/metadata.py
@@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
"""Get the entire custom meta dictionary"""
return copy.deepcopy(self._custom_meta)

def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None):
def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]):
"""Update custom meta with a new dictionary"""
if new_custom_meta:
self._custom_meta.update(new_custom_meta)
Expand Down
8 changes: 8 additions & 0 deletions transfer_queue/storage/clients/base.py
@@ -23,6 +23,14 @@ class TransferQueueStorageKVClient(ABC):
     Subclasses must implement the core methods: put, get, and clear.
     """

+    def __init__(self, config: dict[str, Any]):
+        """
+        Initialize the storage client with configuration.
+        Args:
+            config (dict[str, Any]): Configuration dictionary for the storage client.
+        """
+        self.config = config
+
     @abstractmethod
     def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
         """
5 changes: 3 additions & 2 deletions transfer_queue/storage/clients/factory.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 from transfer_queue.storage.clients.base import TransferQueueStorageKVClient


@@ -23,7 +24,7 @@ class StorageClientFactory:
"""

# Class variable: maps client names to their corresponding classes
_registry: dict[str, TransferQueueStorageKVClient] = {}
_registry: dict[str, type[TransferQueueStorageKVClient]] = {}

@classmethod
def register(cls, client_type: str):
@@ -35,7 +36,7 @@ def register(cls, client_type: str):
             Callable: The decorator function that returns the original class
         """

-        def decorator(client_class: TransferQueueStorageKVClient) -> TransferQueueStorageKVClient:
+        def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]:
             cls._registry[client_type] = client_class
             return client_class

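Taken together, the `base.py` and `factory.py` changes make the registry map client names to classes rather than instances, with construction deferred to the factory (the tests above patch `StorageClientFactory.create` for exactly this reason). A minimal sketch of how a client would plug in under the new typing; `DictStorageClient` is hypothetical, and the `get`/`clear` signatures are assumed from the base-class docstring:

from typing import Any, Optional

from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
from transfer_queue.storage.clients.factory import StorageClientFactory


@StorageClientFactory.register("DictStorageClient")  # stores the class itself in _registry
class DictStorageClient(TransferQueueStorageKVClient):
    """Hypothetical in-memory client, for illustration only."""

    def __init__(self, config: dict[str, Any]):
        super().__init__(config)  # the new base __init__ keeps the config on the instance
        self._store: dict[str, Any] = {}

    def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
        self._store.update(zip(keys, values))
        return None  # no custom meta to report

    def get(self, keys: list[str]) -> list[Any]:
        return [self._store[k] for k in keys]

    def clear(self, keys: list[str]) -> None:
        for k in keys:
            self._store.pop(k, None)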