
feat: inference optimization — new API, perf improvements, bug fixes #1

Open
isayev wants to merge 2 commits into main from inference-opt

Conversation

isayev commented on Feb 22, 2026

Summary

  • New megalodon.inference package: public generate_conformers() API accepting a list of SMILES and returning a typed ConformerGenerationResult; includes SMILES validation, identity-preserving featurization, FFD atom-count bin-packing, and variable n_confs per molecule
  • ~8 inference-path perf improvements: pre-built time tensors, cached time embeddings, precomputed attention mask, float additive bias attn_mask (Flash Attention dispatch), skip ETKDG on diffusion inputs, eliminated null-variable softmax no-ops
  • Bug fixes: ModuleDict null-interpolant check, shallow `copy()` replaced with `.clone()` for PyG Data, deprecated DataLoader import path, eliminated an ~8 s pytorch-lightning transitive import, backward-compat CLI fixes, `_Name` preservation from SDF inputs

Changes

| File | Change |
| --- | --- |
| `src/megalodon/inference/` | New package (7 files: `__init__`, validation, featurization, generation, result, batching) |
| `src/megalodon/models/module.py` | Float additive bias attn_mask; pre-built time tensors; null-interpolant fix; one-hot pre-encoding; attn_mask precompute |
| `src/megalodon/dynamics/fn_model.py` | Cache time embeddings; register freqs buffer; remove duplicate layer; Z-branch warning comment |
| `scripts/sample_conformers.py` | Refactored to use the generate_conformers() API |
| `scripts/benchmark_inference.py` | New timing + accuracy benchmark script |

Benchmark (NVIDIA L40S, 20 drug-like molecules)

| Mode | Throughput |
| --- | --- |
| batch_size=4 | 2.2 conf/s |
| batch_size=16 | 5.7 conf/s (2.6× faster) |
| n_confs=5 × 20 mols | 7.6 conf/s |
| Variable n_confs | 8.2 conf/s |

Accuracy on 100 generated conformers: 100/100 have valid 3D coordinates, the correct atom count, and a SMILES round-trip match.

Test Plan

  • Run `python scripts/benchmark_inference.py --config ... --ckpt ... --dataset_root ...` (full suite)
  • Run `python scripts/sample_conformers.py --input "c1ccccc1" --config ... --ckpt ... --output out.sdf --n_confs 5`
  • Run `streamlit run app/app.py` and verify the app still functions

## New megalodon.inference package

Public API (`from megalodon.inference import ...`):
- `validate_smiles(smiles)` — validates SMILES before featurization; rejects
  salts, unsupported elements (LoQI vocab: 17 atoms), radicals
- `generate_conformers(smiles_list, model, cfg, n_confs, batch_size=48,
  max_atoms_per_batch)` — batched conformer generation returning a structured
  `ConformerGenerationResult` with per-SMILES conformer lists, error records,
  and `.to_sdf()` serialization; default batch_size=48 (sweep-validated optimum
  on L40S: 8.7 conf/s at 83% SM utilization, 1.2 GB peak)
- `ffd_pack_indices / pack_batches` — First-Fit-Decreasing atom-count bin-
  packing to minimise padding waste on heterogeneous molecule sets
- `ConformerGenerationResult / MoleculeProcessingError` — typed result objects
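The First-Fit-Decreasing idea behind the packing helpers can be sketched as follows. This is an illustrative stand-in, not the package's actual implementation; the function name and signature mirror the API above but are assumed:

```python
def ffd_pack_indices(atom_counts, max_atoms_per_batch):
    """First-Fit-Decreasing bin-packing: visit molecules largest-first and
    place each into the first batch whose atom budget still has room.
    Returns a list of batches, each a list of original molecule indices."""
    order = sorted(range(len(atom_counts)),
                   key=lambda i: atom_counts[i], reverse=True)
    batches, loads = [], []          # parallel lists: members, summed atoms
    for i in order:
        n = atom_counts[i]
        for b, load in enumerate(loads):
            if load + n <= max_atoms_per_batch:
                batches[b].append(i)  # first batch that fits wins
                loads[b] += n
                break
        else:                         # no existing batch fits: open a new one
            batches.append([i])
            loads.append(n)
    return batches

# Heterogeneous molecule sizes pack into few, nearly-full batches:
batches = ffd_pack_indices([30, 10, 20, 40, 5], max_atoms_per_batch=50)
```

Sorting descending first is what makes first-fit effective: large molecules claim batches early, and small ones backfill the leftover capacity, which keeps per-batch padding (to the largest member) low.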

## Performance improvements (src/megalodon/models/ and dynamics/)

- Pre-build time tensors before diffusion loop (eliminates per-step alloc)
- Register `freqs` as buffer in TimestepEmbedder (CPU→GPU transfer gone)
- Cache time embeddings per discrete timestep in MegaFNV3Conf (24/25 MLP
  calls eliminated per sample)
- Precompute attention mask once before diffusion loop (25 recomputations
  eliminated)
- Pre-encode null variable one-hots before sample loop (25× redundant
  F.one_hot calls eliminated)
- Skip softmax for discrete_null pass-through logits in
  `separate_discrete_variables`
- Convert attn_mask to float additive bias (0.0 / -inf) enabling efficient
  Flash Attention dispatch
- Skip ETKDG 3D embedding in app inference path (coordinates are overwritten
  by diffusion prior anyway)
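The boolean-to-additive-bias mask conversion can be sketched in NumPy (a stand-in for the tensor code; per the PR, the float form is what lets the attention op dispatch to the fused Flash Attention kernel, since the bias is simply added to the pre-softmax scores):

```python
import numpy as np

def to_additive_bias(bool_mask, dtype=np.float32):
    """Convert a boolean keep-mask (True = attend) into a float additive
    bias: 0.0 where attention is allowed, -inf where it is masked out."""
    bias = np.zeros(bool_mask.shape, dtype=dtype)
    bias[~bool_mask] = -np.inf
    return bias

# After softmax, -inf positions receive exactly zero attention weight:
scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[True, True, False]])
masked = scores + to_additive_bias(mask)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
```

The additive form is equivalent to boolean masking (exp(-inf) = 0) but avoids a data-dependent branch inside the attention kernel, which is what fused implementations require.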

## Bug fixes

- batch_preprocessor argument typo in sample_conformers.py
- Duplicate `lin_edge1` layer definition in fn_model.py
- Stray `torch.max` expressions with discarded results in fn_model.py
- `ModuleDict[key] is None` check (was `.get()` / `not in`, both wrong for
  nn.ModuleDict with None values)
- `DataLoader` import moved from deprecated `torch_geometric.data` to
  `torch_geometric.loader`
- `copy(base_data)` → `base_data.clone()` in featurization (shallow copy
  shared tensor storage, causing in-place mutation bugs)
- Inline `_ATOM_ENCODER` to avoid 8s pytorch-lightning transitive import
- `Chem.SetUseLegacyStereoPerception(True)` in package `__init__.py` to match
  training-time stereo assignment
- Restore `--skip_eval` CLI arg as no-op for backward compat
- Preserve `_Name` from SDF mol inputs in pickle output IDs
- Add warning comment on Z-branch float-mask incompatibility in fn_model.py
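The shallow-copy bug above is the generic Python pitfall that `copy.copy` duplicates the container but not the payload it holds. A minimal stand-in for the PyG `Data` case (plain lists instead of tensors; the `Data` class here is illustrative, not PyG's):

```python
import copy

class Data:                      # minimal stand-in for a PyG-style container
    def __init__(self, pos):
        self.pos = pos           # mutable payload (a tensor in the real code)
    def clone(self):             # copies the payload too, like Data.clone()
        return Data(list(self.pos))

base = Data([0.0, 0.0, 0.0])

shallow = copy.copy(base)        # new container, but SAME pos object
shallow.pos[0] = 9.9             # ...so this mutates base as well
assert base.pos[0] == 9.9        # in-place edit leaked into base

base.pos[0] = 0.0
cloned = base.clone()            # independent copy of the payload
cloned.pos[0] = 9.9
assert base.pos[0] == 0.0        # base unaffected
```

With real tensors the leak is the same but quieter: the copied `Data` shares tensor storage, so any in-place op on the copy silently corrupts the original featurization.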

## Scripts / tooling

- `scripts/sample_conformers.py` refactored to use `generate_conformers()` API
- `scripts/benchmark_inference.py` — timing + accuracy benchmark (20 curated
  drug-like molecules, batch sweep, FFD vs fixed, SMILES round-trip check)
- `scripts/sustained_perf_test.py` — large-scale sustained-load test using
  real ChEMBL3D test-set SMILES with stratified size sampling
- `scripts/batch_size_sweep.py` — batch-size sweep with live nvidia-smi GPU
  SM% / memory-BW% sampling; identifies throughput knee and efficiency optimum
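The live GPU sampling in the sweep script could look roughly like this. The helper names are hypothetical, but `nvidia-smi --query-gpu=utilization.gpu,utilization.memory --format=csv,noheader,nounits` is the real CLI interface being polled:

```python
import subprocess

QUERY = ["nvidia-smi",
         "--query-gpu=utilization.gpu,utilization.memory",
         "--format=csv,noheader,nounits"]

def parse_sample(line):
    """Parse one CSV line of nvidia-smi output into (SM %, memory-BW %)."""
    sm, mem = (int(field.strip()) for field in line.split(","))
    return sm, mem

def sample_gpu():
    """Take one utilization sample from the first visible GPU."""
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_sample(out.stdout.splitlines()[0])
```

Polling this in a background thread while the benchmark runs gives the SM%/memory-BW% traces used to locate the throughput knee.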
