diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py
index a994895..95ab944 100644
--- a/src/caked/dataloader.py
+++ b/src/caked/dataloader.py
@@ -312,8 +312,13 @@ def read(self, filename):
             return np.load(filename)
 
         if self.datatype == "mrc":
-            with mrcfile.open(filename) as f:
-                return np.array(f.data)
+            try:
+                with mrcfile.open(filename) as f:
+                    return np.array(f.data)
+            except ValueError as exc:
+                # Name the offending file in the re-raised error.
+                msg = f"File {filename} is corrupted."
+                raise ValueError(msg) from exc
 
         else:
             msg = "Currently we only support mrcfile and numpy arrays."
diff --git a/tests/corrupt.mrc b/tests/corrupt.mrc
new file mode 100644
index 0000000..f771049
--- /dev/null
+++ b/tests/corrupt.mrc
@@ -0,0 +1 @@
+This file is corrupt to test the exception throw.
diff --git a/tests/test_disk_io.py b/tests/test_disk_io.py
index 68f438d..f9a4a7e 100644
--- a/tests/test_disk_io.py
+++ b/tests/test_disk_io.py
@@ -12,6 +12,7 @@
 ORIG_DIR = Path.cwd()
 TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
 TEST_DATA_NPY = Path(testdata_npy.__file__).parent
+TEST_CORRUPT = Path(__file__).parent / "corrupt.mrc"
 DISK_PIPELINE = "disk"
 DATASET_SIZE_ALL = None
 DATASET_SIZE_SOME = 3
@@ -366,3 +367,13 @@ def test_drop_last():
     )
     assert loader_train_false.drop_last
     assert loader_val_false.drop_last
+
+
+def test_corrupt_mrcfile():
+    """
+    Test that reading a corrupt mrcfile raises a ValueError naming the file.
+    """
+    test_dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert isinstance(test_dataset, DiskDataset)
+    with pytest.raises(ValueError, match=r"corrupt\.mrc is corrupted\."):
+        test_dataset.read(TEST_CORRUPT)