diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index af59b4a8c7..280743af4d 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -37,6 +37,9 @@ def forward( self, inputs: list[np.ndarray], ) -> list[list[int] | list[float]]: + if len(inputs) == 0: + return [[], [], []] + # Dimension check if any(input.ndim != 3 for input in inputs): raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index fc0bbad94d..e5d9add27e 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -137,6 +137,7 @@ def test_crop_orientation_model(mock_text_box): # 270 degrees is equivalent to -90 degrees assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + assert classifier([]) == [[], [], []] # Test with disabled predictor classifier = classification.crop_orientation_predictor( @@ -147,6 +148,7 @@ def test_crop_orientation_model(mock_text_box): [0, 0, 0, 0], [1.0, 1.0, 1.0, 1.0], ] + assert classifier([]) == [[], [], []] # Test custom model loading classifier = classification.crop_orientation_predictor( @@ -182,6 +184,7 @@ def test_page_orientation_model(mock_payslip): # 270 degrees is equivalent to -90 degrees assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + assert classifier([]) == [[], [], []] # Test with disabled predictor classifier = classification.page_orientation_predictor( @@ -192,6 +195,7 @@ def test_page_orientation_model(mock_payslip): [0, 0, 0, 0], [1.0, 1.0, 1.0, 1.0], ] + assert classifier([]) == [[], [], []] # Test custom model loading classifier = classification.page_orientation_predictor(