Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 69 additions & 15 deletions chemap/fingerprint_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Union
import numpy as np
import scipy.sparse as sp
from joblib import Parallel, delayed
from rdkit import Chem
from sklearn.base import BaseEstimator, TransformerMixin
from tqdm import tqdm


Expand Down Expand Up @@ -74,13 +76,29 @@ class FingerprintConfig:
class SklearnTransformer(Protocol):
"""Protocol for sklearn-like fingerprint transformers (including scikit-fingerprints)."""

def fit(self, X: Any, y: Any = None) -> "SklearnTransformer":
...

def transform(self, X: Sequence[str]) -> Any:
...

def get_params(self, deep: bool = False) -> Dict[str, Any]:
...


class RobustMolTransformer(BaseEstimator, TransformerMixin):
def __init__(self, n_jobs=-1):
self.n_jobs = n_jobs

def fit(self, X, y=None):
return self

def transform(self, X):
results = Parallel(n_jobs=self.n_jobs)(
delayed(_mol_from_smiles_robust)(s) for s in X
)
return results

# -----------------------------
# Public entry point
# -----------------------------
Expand All @@ -91,6 +109,7 @@ def compute_fingerprints(
config: FingerprintConfig = FingerprintConfig(),
*,
show_progress: bool = False,
n_jobs: int = -1,
) -> FingerprintResult:
"""
Compute fingerprints for a sequence of SMILES.
Expand All @@ -114,10 +133,10 @@ def compute_fingerprints(
_quick_smiles_check(smiles)

if _looks_like_rdkit_fpgen(fpgen):
return _compute_rdkit(smiles, fpgen, config, show_progress=show_progress)
return _compute_rdkit(smiles, fpgen, config, show_progress=show_progress, n_jobs=n_jobs)

if _looks_like_sklearn_transformer(fpgen):
return _compute_sklearn(smiles, fpgen, config, show_progress=show_progress)
return _compute_sklearn(smiles, fpgen, config, show_progress=show_progress, n_jobs=n_jobs)

raise TypeError(
"Unsupported fpgen. Expected an RDKit rdFingerprintGenerator-like object "
Expand Down Expand Up @@ -254,6 +273,28 @@ def _mol_from_smiles_robust(smiles: str) -> Optional["Chem.Mol"]:
return mol


def _compute_mols_parallel(smiles: Sequence[str], n_jobs: int, show_progress: bool) -> List[Optional["Chem.Mol"]]:
"""
Compute RDKit molecules from SMILES in parallel.
"""
if n_jobs == 1:
return [
_mol_from_smiles_robust(s) for s in tqdm(smiles, disable=not show_progress, desc="Generating molecules")
]

results = Parallel(n_jobs=n_jobs, batch_size="auto")(
delayed(_mol_from_smiles_robust)(s)
for s in tqdm(
smiles,
total=len(smiles),
desc="Generating molecules (Parallel)",
disable=not show_progress
)
)

return results


def _infer_fp_size_folded(fpgen: Any, mol: "Chem.Mol", count: bool) -> int:
"""
Infer folded vector length for RDKit generator from a molecule.
Expand All @@ -271,14 +312,15 @@ def _compute_rdkit(
cfg: FingerprintConfig,
*,
show_progress: bool,
n_jobs: int,
) -> FingerprintResult:
if not cfg.folded:
return _rdkit_unfolded(smiles, fpgen, cfg, show_progress=show_progress)
return _rdkit_unfolded(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs)

if cfg.return_csr:
return _rdkit_folded_csr(smiles, fpgen, cfg, show_progress=show_progress)
return _rdkit_folded_csr(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs)

return _rdkit_folded_dense(smiles, fpgen, cfg, show_progress=show_progress)
return _rdkit_folded_dense(smiles, fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs)


def _rdkit_unfolded(
Expand All @@ -287,17 +329,19 @@ def _rdkit_unfolded(
cfg: FingerprintConfig,
*,
show_progress: bool,
n_jobs: int,
) -> FingerprintResult:
"""
Unfolded output for RDKit: use fpgen.GetSparseCountFingerprint(mol) to obtain feature IDs.

- count=False: List[np.ndarray[int64]] feature IDs
- count=True : List[(keys:int64, vals:float32)] feature IDs + counts (optionally scaled/weighted)
"""
mols = _compute_mols_parallel(smiles, n_jobs, show_progress)

if cfg.count:
out: UnfoldedCount = []
for s in tqdm(smiles, disable=(not show_progress)):
mol = _mol_from_smiles_robust(s)
for s, mol in zip(smiles, mols):
if mol is None:
_handle_invalid(cfg.invalid_policy, s)
if cfg.invalid_policy == "keep":
Expand All @@ -314,8 +358,7 @@ def _rdkit_unfolded(
return out

out: UnfoldedBinary = []
for s in tqdm(smiles, disable=(not show_progress)):
mol = _mol_from_smiles_robust(s)
for s, mol in zip(smiles, mols):
if mol is None:
_handle_invalid(cfg.invalid_policy, s)
if cfg.invalid_policy == "keep":
Expand All @@ -335,16 +378,17 @@ def _rdkit_folded_dense(
cfg: FingerprintConfig,
*,
show_progress: bool,
n_jobs: int,
) -> np.ndarray:
"""
Dense folded output (N, D) float32 for RDKit generators.
"""
mols = _compute_mols_parallel(smiles, n_jobs, show_progress)
rows: List[np.ndarray] = []
n_features: Optional[int] = None
pending_invalid: List[int] = [] # indices in `rows` that need backfill after we learn D

for s in tqdm(smiles, disable=(not show_progress)):
mol = _mol_from_smiles_robust(s)
for s, mol in zip(smiles, mols):
if mol is None:
_handle_invalid(cfg.invalid_policy, s)
if cfg.invalid_policy == "keep":
Expand Down Expand Up @@ -385,6 +429,7 @@ def _rdkit_folded_csr(
cfg: FingerprintConfig,
*,
show_progress: bool,
n_jobs: int,
) -> sp.csr_matrix:
"""
Folded CSR output for RDKit generators.
Expand All @@ -396,6 +441,7 @@ def _rdkit_folded_csr(
- keep: row is kept as all-zeros (output aligned to input)
- raise: raises ValueError
"""
mols = _compute_mols_parallel(smiles, n_jobs, show_progress)
n_features: Optional[int] = None

idx_chunks: List[np.ndarray] = []
Expand All @@ -406,8 +452,7 @@ def _rdkit_folded_csr(
if cfg.folded_weights is not None:
w = np.asarray(cfg.folded_weights, dtype=np.float32).ravel()

for s in tqdm(smiles, disable=(not show_progress)):
mol = _mol_from_smiles_robust(s)
for s, mol in zip(smiles, mols):
if mol is None:
_handle_invalid(cfg.invalid_policy, s)

Expand Down Expand Up @@ -482,6 +527,7 @@ def _skfp_configure_output(
cfg: FingerprintConfig,
*,
show_progress: bool,
n_jobs: int,
) -> SklearnTransformer:
"""
Configure scikit-fingerprints/sklearn transformer to match (folded, return_csr).
Expand All @@ -496,6 +542,9 @@ def _skfp_configure_output(
if "verbose" in params:
updates["verbose"] = 1 if show_progress else 0

if "n_jobs" in params:
updates["n_jobs"] = n_jobs

if not cfg.folded:
if "variant" not in params:
raise NotImplementedError(
Expand Down Expand Up @@ -526,9 +575,14 @@ def _compute_sklearn(
cfg: FingerprintConfig,
*,
show_progress: bool = False,
n_jobs: int,
) -> FingerprintResult:
fp = _skfp_configure_output(fpgen, cfg, show_progress=show_progress)
X = fp.transform(smiles)
fp = _skfp_configure_output(fpgen, cfg, show_progress=show_progress, n_jobs=n_jobs)
mol_transformer = RobustMolTransformer(n_jobs=n_jobs)
mols = mol_transformer.transform(smiles)
valid_mols = [m for m in mols if m is not None]
fp.fit(valid_mols)
X = fp.transform(valid_mols)

if not cfg.folded:
# unfolded output
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"scikit-fingerprints>=1.15.0",
"tqdm>=4.67.1",
"pooch>=1.8.2",
"joblib>=1.3.2",
]

[dependency-groups]
Expand Down
16 changes: 11 additions & 5 deletions tests/test_fingerprint_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest
import scipy.sparse as sp
from sklearn.base import BaseEstimator, TransformerMixin
from chemap import FingerprintConfig, compute_fingerprints


Expand Down Expand Up @@ -95,7 +96,7 @@ class DummyUnsupported:
# Clone-safe sklearn/scikit-fingerprints transformer fakes
# -----------------------------------------------------------------------------

class FakeTransformer:
class FakeTransformer(BaseEstimator, TransformerMixin):
"""
Clone-safe sklearn/scikit-fingerprints-like transformer fake.

Expand All @@ -120,15 +121,20 @@ def __init__(
variant: str | None = None,
mode: str = "onehot",
n_features: int = 6,
n_jobs: int = 1,
):
self._params = {
"sparse": sparse,
"verbose": verbose,
"variant": variant,
"mode": mode,
"n_features": int(n_features),
"n_jobs": n_jobs,
}

def fit(self, X, y=None):
return self

def get_params(self, deep: bool = False):
return dict(self._params)

Expand Down Expand Up @@ -398,7 +404,7 @@ def test_sklearn_folded_dense_scaling_log_applies_when_count_true():
)
cfg = FingerprintConfig(count=True, folded=True, return_csr=False, scaling="log")

X = compute_fingerprints(["A", "B"], fp, cfg)
X = compute_fingerprints(["C", "CC"], fp, cfg)
expected = np.log1p(np.array([[0, 2], [3, 0]], dtype=np.float32)).astype(np.float32)
np.testing.assert_allclose(X, expected, rtol=1e-6, atol=1e-6)

Expand All @@ -412,7 +418,7 @@ def test_sklearn_folded_dense_weights_applies():
w = np.array([1.0, 10.0, 0.5], dtype=np.float32)
cfg = FingerprintConfig(count=True, folded=True, return_csr=False, folded_weights=w)

X = compute_fingerprints(["A"], fp, cfg)
X = compute_fingerprints(["C"], fp, cfg)
np.testing.assert_allclose(X[0], np.array([1, 20, 1.5], dtype=np.float32), rtol=1e-6, atol=1e-6)


Expand All @@ -433,7 +439,7 @@ def test_sklearn_unfolded_sets_variant_raw_bits_and_returns_unfolded_binary():
)
cfg = FingerprintConfig(count=False, folded=False)

out = compute_fingerprints(["A"], fp, cfg)
out = compute_fingerprints(["C"], fp, cfg)
assert isinstance(out, list)
assert out[0].dtype == np.int64
assert list(out[0]) == [1, 4]
Expand All @@ -448,7 +454,7 @@ def test_sklearn_unfolded_count_scaling_and_unfolded_weights():
)
cfg = FingerprintConfig(count=True, folded=False, scaling="log", unfolded_weights={2: 10.0})

out = compute_fingerprints(["A"], fp, cfg)
out = compute_fingerprints(["C"], fp, cfg)
keys, vals = out[0]
assert list(keys) == [2, 5]
expected = np.log1p(np.array([4.0, 2.0], dtype=np.float32)) * np.array([10.0, 1.0], dtype=np.float32)
Expand Down