Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

- Speed up localization model loading and inference ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz])
- Speed up SMILES Transformer model inference ([#17](https://github.com/microsoft/retrochimera/pull/17)) ([@kmaziarz])
- Run submodel inference in parallel ([#18](https://github.com/microsoft/retrochimera/pull/18)) ([@kmaziarz])
- Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz])

### Added
Expand Down
40 changes: 37 additions & 3 deletions retrochimera/inference/retrochimera.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import json
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional, Sequence, Union

Expand All @@ -11,8 +12,11 @@
from syntheseus.reaction_prediction.inference_base import ExternalBackwardReactionModel

from retrochimera import inference
from retrochimera.utils.logging import get_logger
from retrochimera.utils.misc import lookup_by_name

logger = get_logger(__name__)


class RetroChimeraModel(ExternalBackwardReactionModel):
"""Wrapper for the RetroChimera model."""
Expand All @@ -22,6 +26,7 @@ def __init__(
*args,
model_dir: Optional[Union[str, Path]] = None,
probability_from_score_temperature: float = 8.0,
call_submodels_in_parallel: bool = True,
**kwargs,
) -> None:
"""Initializes the RetroChimera model wrapper.
Expand Down Expand Up @@ -54,6 +59,23 @@ def __init__(

self._init_from_dir(model_dir=model_dir, model_data=model_data, model_kwargs=model_kwargs)
self.probability_from_score_temperature = probability_from_score_temperature
self._call_submodels_in_parallel = call_submodels_in_parallel

if not self._models:
raise RuntimeError(f"No submodels found in {model_dir}")

# Set up CUDA streams to allow executing submodels in parallel.
self._cached_streams: Optional[list[torch.cuda.Stream]] = None
self._cached_executor: Optional[ThreadPoolExecutor] = None

if call_submodels_in_parallel and len(self._models) > 1:
if self.device.startswith("cuda"):
self._cached_streams = [torch.cuda.Stream() for _ in self._models]
self._cached_executor = ThreadPoolExecutor(max_workers=len(self._models))
else:
logger.warning(
f"Submodels will run sequentially as chosen device is not CUDA ({self.device})"
)

def _load_model_data_from_dir(self, model_dir: Path) -> dict[str, tuple[str, list[float]]]:
with open(model_dir / "models.json") as f:
Expand Down Expand Up @@ -88,9 +110,21 @@ def get_parameters(self):
def _get_reactions(
self, inputs: list[Molecule], num_results: int
) -> list[Sequence[SingleProductReaction]]:
model_batch_results: list[list[Sequence[SingleProductReaction]]] = [
model(inputs=inputs, num_results=num_results) for model in self._models
]
if self._cached_streams is not None and self._cached_executor is not None:

# Worker for the thread pool: runs a single submodel with `stream` as the
# current CUDA stream. `inputs` and `num_results` are not parameters; they are
# captured from the enclosing `_get_reactions` call, so every submodel sees the
# same batch and the same requested number of results.
def _run(model, stream):
# Make `stream` the current CUDA stream for the duration of the model call.
with torch.cuda.stream(stream):
out = model(inputs=inputs, num_results=num_results)
# Wait for everything queued on this stream to finish before handing the
# result back to the calling thread.
stream.synchronize()
# NOTE(review): PyTorch grad modes (`no_grad` / `inference_mode`) are
# thread-local, so a context entered by the caller would not apply inside this
# worker thread - confirm each submodel sets its own grad mode.
return out

model_batch_results = list(
self._cached_executor.map(_run, self._models, self._cached_streams)
)
else:
model_batch_results = [
model(inputs=inputs, num_results=num_results) for model in self._models
]

return [
combine_results(
Expand Down
Loading