diff --git a/chemap/fingerprint_computation.py b/chemap/fingerprint_computation.py index 798233d..6c7229e 100644 --- a/chemap/fingerprint_computation.py +++ b/chemap/fingerprint_computation.py @@ -3,7 +3,9 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Union import numpy as np import scipy.sparse as sp +from joblib import Parallel, delayed from rdkit import Chem +from sklearn.base import BaseEstimator, TransformerMixin from tqdm import tqdm @@ -74,6 +76,9 @@ class FingerprintConfig: class SklearnTransformer(Protocol): """Protocol for sklearn-like fingerprint transformers (including scikit-fingerprints).""" + def fit(self, X: Any, y: Any = None) -> "SklearnTransformer": + ... + def transform(self, X: Sequence[str]) -> Any: ... @@ -81,6 +86,19 @@ def get_params(self, deep: bool = False) -> Dict[str, Any]: ... +class RobustMolTransformer(BaseEstimator, TransformerMixin): + def __init__(self, n_jobs=-1): + self.n_jobs = n_jobs + + def fit(self, X, y=None): + return self + + def transform(self, X): + results = Parallel(n_jobs=self.n_jobs)( + delayed(_mol_from_smiles_robust)(s) for s in X + ) + return results + # ----------------------------- # Public entry point # ----------------------------- @@ -91,6 +109,7 @@ def compute_fingerprints( config: FingerprintConfig = FingerprintConfig(), *, show_progress: bool = False, + n_jobs: int = -1, ) -> FingerprintResult: """ Compute fingerprints for a sequence of SMILES. @@ -114,10 +133,10 @@ def compute_fingerprints( _quick_smiles_check(smiles) if _looks_like_rdkit_fpgen(fpgen): - return _compute_rdkit(smiles, fpgen, config, show_progress=show_progress) + return _compute_rdkit(smiles, fpgen, config, show_progress=show_progress, n_jobs=n_jobs) if _looks_like_sklearn_transformer(fpgen): - return _compute_sklearn(smiles, fpgen, config, show_progress=show_progress) + return _compute_sklearn(smiles, fpgen, config, show_progress=show_progress, n_jobs=n_jobs) raise TypeError( "Unsupported fpgen. Expected an RDKit rdFingerprintGenerator-like object " @@ -254,6 +273,28 @@ def _mol_from_smiles_robust(smiles: str) -> Optional["Chem.Mol"]: return mol +def _compute_mols_parallel(smiles: Sequence[str], n_jobs: int, show_progress: bool) -> List[Optional["Chem.Mol"]]: + """ + Compute RDKit molecules from SMILES in parallel. + """ + if n_jobs == 1: + return [ + _mol_from_smiles_robust(s) for s in tqdm(smiles, disable=not show_progress, desc="Generating molecules") + ] + + results = Parallel(n_jobs=n_jobs, batch_size="auto")( + delayed(_mol_from_smiles_robust)(s) + for s in tqdm( + smiles, + total=len(smiles), + desc="Generating molecules (Parallel)", + disable=not show_progress + ) + ) + + return results + + def _infer_fp_size_folded(fpgen: Any, mol: "Chem.Mol", count: bool) -> int: """ Infer folded vector length for RDKit generator from a molecule. @@ -271,14 +312,15 @@ def _compute_rdkit( cfg: FingerprintConfig, *, show_progress: bool, + n_jobs: int, ) -> FingerprintResult: if not cfg.folded: - return _rdkit_unfolded(smiles, fpgen, cfg, show_progress=show_progress) + return _rdkit_unfolded(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs) if cfg.return_csr: - return _rdkit_folded_csr(smiles, fpgen, cfg, show_progress=show_progress) + return _rdkit_folded_csr(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs) - return _rdkit_folded_dense(smiles, fpgen, cfg, show_progress=show_progress) + return _rdkit_folded_dense(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs) def _rdkit_unfolded( @@ -287,6 +329,7 @@ def _rdkit_unfolded( cfg: FingerprintConfig, *, show_progress: bool, + n_jobs: int, ) -> FingerprintResult: """ Unfolded output for RDKit: use fpgen.GetSparseCountFingerprint(mol) to obtain feature IDs. @@ -294,10 +337,11 @@ def _rdkit_unfolded( - count=False: List[np.ndarray[int64]] feature IDs - count=True : List[(keys:int64, vals:float32)] feature IDs + counts (optionally scaled/weighted) """ + mols = _compute_mols_parallel(smiles, n_jobs, show_progress) + if cfg.count: out: UnfoldedCount = [] - for s in tqdm(smiles, disable=(not show_progress)): - mol = _mol_from_smiles_robust(s) + for s, mol in zip(smiles, mols): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -314,8 +358,7 @@ def _rdkit_unfolded( return out out: UnfoldedBinary = [] - for s in tqdm(smiles, disable=(not show_progress)): - mol = _mol_from_smiles_robust(s) + for s, mol in zip(smiles, mols): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -335,16 +378,17 @@ def _rdkit_folded_dense( cfg: FingerprintConfig, *, show_progress: bool, + n_jobs: int, ) -> np.ndarray: """ Dense folded output (N, D) float32 for RDKit generators. """ + mols = _compute_mols_parallel(smiles, n_jobs, show_progress) rows: List[np.ndarray] = [] n_features: Optional[int] = None pending_invalid: List[int] = [] # indices in `rows` that need backfill after we learn D - for s in tqdm(smiles, disable=(not show_progress)): - mol = _mol_from_smiles_robust(s) + for s, mol in zip(smiles, mols): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -385,6 +429,7 @@ def _rdkit_folded_csr( cfg: FingerprintConfig, *, show_progress: bool, + n_jobs: int, ) -> sp.csr_matrix: """ Folded CSR output for RDKit generators. @@ -396,6 +441,7 @@ def _rdkit_folded_csr( - keep: row is kept as all-zeros (output aligned to input) - raise: raises ValueError """ + mols = _compute_mols_parallel(smiles, n_jobs, show_progress) n_features: Optional[int] = None idx_chunks: List[np.ndarray] = [] @@ -406,8 +452,7 @@ def _rdkit_folded_csr( if cfg.folded_weights is not None: w = np.asarray(cfg.folded_weights, dtype=np.float32).ravel() - for s in tqdm(smiles, disable=(not show_progress)): - mol = _mol_from_smiles_robust(s) + for s, mol in zip(smiles, mols): if mol is None: _handle_invalid(cfg.invalid_policy, s) @@ -482,6 +527,7 @@ def _skfp_configure_output( cfg: FingerprintConfig, *, show_progress: bool, + n_jobs: int, ) -> SklearnTransformer: """ Configure scikit-fingerprints/sklearn transformer to match (folded, return_csr). @@ -496,6 +542,9 @@ def _skfp_configure_output( if "verbose" in params: updates["verbose"] = 1 if show_progress else 0 + if "n_jobs" in params: + updates["n_jobs"] = n_jobs + if not cfg.folded: if "variant" not in params: raise NotImplementedError( @@ -526,9 +575,14 @@ def _compute_sklearn( cfg: FingerprintConfig, *, show_progress: bool = False, + n_jobs: int, ) -> FingerprintResult: - fp = _skfp_configure_output(fpgen, cfg, show_progress=show_progress) - X = fp.transform(smiles) + fp = _skfp_configure_output(fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs) + mol_transformer = RobustMolTransformer(n_jobs=n_jobs) + mols = mol_transformer.transform(smiles) + valid_mols = [m for m in mols if m is not None] + fp.fit(valid_mols) + X = fp.transform(valid_mols) if not cfg.folded: # unfolded output diff --git a/pyproject.toml b/pyproject.toml index 81857e1..af2bc2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "scikit-fingerprints>=1.15.0", "tqdm>=4.67.1", "pooch>=1.8.2", + "joblib>=1.3.2", ] [dependency-groups] diff --git a/tests/test_fingerprint_computation.py b/tests/test_fingerprint_computation.py index dcc9809..93fc6a7 100644 --- a/tests/test_fingerprint_computation.py +++ b/tests/test_fingerprint_computation.py @@ -2,6 +2,7 @@ import numpy as np import pytest import scipy.sparse as sp +from sklearn.base import BaseEstimator, TransformerMixin from chemap import FingerprintConfig, compute_fingerprints @@ -95,7 +96,7 @@ class DummyUnsupported: # Clone-safe sklearn/scikit-fingerprints transformer fakes # ----------------------------------------------------------------------------- -class FakeTransformer: +class FakeTransformer(BaseEstimator, TransformerMixin): """ Clone-safe sklearn/scikit-fingerprints-like transformer fake. @@ -120,6 +121,7 @@ def __init__( variant: str | None = None, mode: str = "onehot", n_features: int = 6, + n_jobs: int = 1, ): self._params = { "sparse": sparse, @@ -127,8 +129,12 @@ def __init__( "variant": variant, "mode": mode, "n_features": int(n_features), + "n_jobs": n_jobs, } + def fit(self, X, y=None): + return self + def get_params(self, deep: bool = False): return dict(self._params) @@ -398,7 +404,7 @@ def test_sklearn_folded_dense_scaling_log_applies_when_count_true(): ) cfg = FingerprintConfig(count=True, folded=True, return_csr=False, scaling="log") - X = compute_fingerprints(["A", "B"], fp, cfg) + X = compute_fingerprints(["C", "CC"], fp, cfg) expected = np.log1p(np.array([[0, 2], [3, 0]], dtype=np.float32)).astype(np.float32) np.testing.assert_allclose(X, expected, rtol=1e-6, atol=1e-6) @@ -412,7 +418,7 @@ def test_sklearn_folded_dense_weights_applies(): w = np.array([1.0, 10.0, 0.5], dtype=np.float32) cfg = FingerprintConfig(count=True, folded=True, return_csr=False, folded_weights=w) - X = compute_fingerprints(["A"], fp, cfg) + X = compute_fingerprints(["C"], fp, cfg) np.testing.assert_allclose(X[0], np.array([1, 20, 1.5], dtype=np.float32), rtol=1e-6, atol=1e-6) @@ -433,7 +439,7 @@ def test_sklearn_unfolded_sets_variant_raw_bits_and_returns_unfolded_binary(): ) cfg = FingerprintConfig(count=False, folded=False) - out = compute_fingerprints(["A"], fp, cfg) + out = compute_fingerprints(["C"], fp, cfg) assert isinstance(out, list) assert out[0].dtype == np.int64 assert list(out[0]) == [1, 4] @@ -448,7 +454,7 @@ def test_sklearn_unfolded_count_scaling_and_unfolded_weights(): ) cfg = FingerprintConfig(count=True, folded=False, scaling="log", unfolded_weights={2: 10.0}) - out = compute_fingerprints(["A"], fp, cfg) + out = compute_fingerprints(["C"], fp, cfg) keys, vals = out[0] assert list(keys) == [2, 5] expected = np.log1p(np.array([4.0, 2.0], dtype=np.float32)) * np.array([10.0, 1.0], dtype=np.float32)