Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,12 @@ def read(self, filename):
return np.load(filename)

if self.datatype == "mrc":
with mrcfile.open(filename) as f:
return np.array(f.data)
try:
with mrcfile.open(filename) as f:
return np.array(f.data)
except ValueError as exc:
msg = f"File {filename} is corrupted."
raise ValueError(msg) from exc

else:
msg = "Currently we only support mrcfile and numpy arrays."
Expand Down
1 change: 1 addition & 0 deletions tests/corrupt.mrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This file is corrupt to test the exception throw.
11 changes: 11 additions & 0 deletions tests/test_disk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ORIG_DIR = Path.cwd()
TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
TEST_DATA_NPY = Path(testdata_npy.__file__).parent
TEST_CORRUPT = Path(__file__).parent / "corrupt.mrc"
DISK_PIPELINE = "disk"
DATASET_SIZE_ALL = None
DATASET_SIZE_SOME = 3
Expand Down Expand Up @@ -366,3 +367,13 @@ def test_drop_last():
)
assert loader_train_false.drop_last
assert loader_val_false.drop_last


def test_corrupt_mrcfile():
"""
Test that corrupt mrcfiles are not loaded and throw an exception.
"""
test_dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_dataset, DiskDataset)
with pytest.raises(Exception, match=r".* corrupted."):
test_dataset.read(TEST_CORRUPT)