diff --git a/hi-ml-multimodal/src/health_multimodal/image/model/pretrained.py b/hi-ml-multimodal/src/health_multimodal/image/model/pretrained.py index 45d736efb..0ef67f2c3 100644 --- a/hi-ml-multimodal/src/health_multimodal/image/model/pretrained.py +++ b/hi-ml-multimodal/src/health_multimodal/image/model/pretrained.py @@ -83,4 +83,39 @@ def get_biovil_t_image_encoder() -> ImageModel: joint_feature_size=JOINT_FEATURE_SIZE, pretrained_model_path=biovilt_checkpoint_path, ) + + +def get_biovil_t_linear_image_classifier(biovilt_checkpoint_path: str) -> ImageModel: + """ + Download weights from Hugging Face and instantiate the image model. + + The model is initialized with a linear classifier on top of the + BiomedVLP-BioViL-T image encoder. + + Binary classification tasks in order: + ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Edema', 'Consolidation', + 'Pneumonia', 'Pneumothorax', 'Pleural Effusion', 'No Finding'] + + :param biovilt_checkpoint_path: Path to the checkpoint file. + + Example: + >>> checkpoint_path = "..." + >>> image_model = get_biovil_t_linear_image_classifier(checkpoint_path) + >>> image_model(torch.Tensor(batch_size, 3, 448, 448)).class_logits.shape + torch.Size([batch_size, num_classes, num_tasks]) + """ + + num_classes = 2 + num_tasks = 8 + + model_type = ImageEncoderType.RESNET50_MULTI_IMAGE + image_model = ImageModel( + img_encoder_type=model_type, + joint_feature_size=JOINT_FEATURE_SIZE, + pretrained_model_path=biovilt_checkpoint_path, + num_classes=num_classes, + num_tasks=num_tasks, + classifier_hidden_dim=None, + ) + return image_model diff --git a/hi-ml-multimodal/test_multimodal/image/model/test_pretrained.py b/hi-ml-multimodal/test_multimodal/image/model/test_pretrained.py new file mode 100644 index 000000000..515886cdb --- /dev/null +++ b/hi-ml-multimodal/test_multimodal/image/model/test_pretrained.py @@ -0,0 +1,42 @@ +import pytest +import torch + +from health_multimodal.image.model.model import ImageModel +from health_multimodal.image.model.pretrained import get_biovil_t_linear_image_classifier + + +@pytest.fixture +def dummy_input_tensor() -> torch.Tensor: + batch_size = 2 + return torch.randn(batch_size, 3, 448, 448) + + +@pytest.fixture +def biovil_t_linear_image_classifier() -> ImageModel: + # Set the path to None to initialize the model weights with random values + biovil_t_checkpoint_path = None + return get_biovil_t_linear_image_classifier(biovil_t_checkpoint_path) + + +def test_get_biovil_t_linear_image_classifier_shape( + biovil_t_linear_image_classifier: ImageModel, dummy_input_tensor: torch.Tensor +) -> None: + num_classes = 2 + num_tasks = 8 + + output = biovil_t_linear_image_classifier(dummy_input_tensor) + expected_shape = torch.Size([dummy_input_tensor.shape[0], num_classes, num_tasks]) + + assert ( + output.class_logits.shape == expected_shape + ), f"Unexpected output shape. Expected {expected_shape} but got {output.class_logits.shape}" + + +def test_get_biovil_t_linear_image_classifier_inference( + biovil_t_linear_image_classifier: ImageModel, dummy_input_tensor: torch.Tensor +) -> None: + with torch.no_grad(): + try: + _ = biovil_t_linear_image_classifier(dummy_input_tensor) + except Exception as e: + pytest.fail(f"Model inference failed: {e}")