diff --git a/CHANGELOG.md b/CHANGELOG.md index 28abd24..13f7d4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. - Speed up localization model loading and inference ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz]) - Speed up SMILES Transformer model inference ([#17](https://github.com/microsoft/retrochimera/pull/17)) ([@kmaziarz]) +- Run submodel inference in parallel ([#18](https://github.com/microsoft/retrochimera/pull/18)) ([@kmaziarz]) - Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz]) ### Added diff --git a/retrochimera/inference/retrochimera.py b/retrochimera/inference/retrochimera.py index 5b3a545..7af87b7 100644 --- a/retrochimera/inference/retrochimera.py +++ b/retrochimera/inference/retrochimera.py @@ -1,6 +1,7 @@ import itertools import json from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Optional, Sequence, Union @@ -11,8 +12,11 @@ from syntheseus.reaction_prediction.inference_base import ExternalBackwardReactionModel from retrochimera import inference +from retrochimera.utils.logging import get_logger from retrochimera.utils.misc import lookup_by_name +logger = get_logger(__name__) + class RetroChimeraModel(ExternalBackwardReactionModel): """Wrapper for the RetroChimera model.""" @@ -22,6 +26,7 @@ def __init__( *args, model_dir: Optional[Union[str, Path]] = None, probability_from_score_temperature: float = 8.0, + call_submodels_in_parallel: bool = True, **kwargs, ) -> None: """Initializes the RetroChimera model wrapper. @@ -54,6 +59,23 @@ def __init__( self._init_from_dir(model_dir=model_dir, model_data=model_data, model_kwargs=model_kwargs) self.probability_from_score_temperature = probability_from_score_temperature + self._call_submodels_in_parallel = call_submodels_in_parallel + + if not self._models: + raise RuntimeError(f"No submodels found in {model_dir}") + + # Set up CUDA streams to allow executing submodels in parallel. + self._cached_streams: Optional[list[torch.cuda.Stream]] = None + self._cached_executor: Optional[ThreadPoolExecutor] = None + + if call_submodels_in_parallel and len(self._models) > 1: + if self.device.startswith("cuda"): + self._cached_streams = [torch.cuda.Stream() for _ in self._models] + self._cached_executor = ThreadPoolExecutor(max_workers=len(self._models)) + else: + logger.warning( + f"Submodels will run sequentially as chosen device is not CUDA ({self.device})" + ) def _load_model_data_from_dir(self, model_dir: Path) -> dict[str, tuple[str, list[float]]]: with open(model_dir / "models.json") as f: @@ -88,9 +110,21 @@ def get_parameters(self): def _get_reactions( self, inputs: list[Molecule], num_results: int ) -> list[Sequence[SingleProductReaction]]: - model_batch_results: list[list[Sequence[SingleProductReaction]]] = [ - model(inputs=inputs, num_results=num_results) for model in self._models - ] + if self._cached_streams is not None and self._cached_executor is not None: + + def _run(model, stream): + with torch.cuda.stream(stream): + out = model(inputs=inputs, num_results=num_results) + stream.synchronize() + return out + + model_batch_results = list( + self._cached_executor.map(_run, self._models, self._cached_streams) + ) + else: + model_batch_results = [ + model(inputs=inputs, num_results=num_results) for model in self._models + ] return [ combine_results(