dpgmm is a library implementing a high-performance MCMC sampler for Dirichlet Process Gaussian Mixture Models (DPGMMs). Built on PyTorch and accelerated with Triton kernels, it is designed to handle high-dimensional data efficiently.
High Performance: Optimized Gibbs sampling leveraging GPU acceleration via PyTorch and Triton kernels.
Data Generation: Built-in utilities to generate high-dimensional synthetic datasets for validation.
Observability: Native integration with Weights & Biases for experiment tracking.
Metrics: Comprehensive tools for computing assignment log-likelihood, data complexity, and entanglement between data dimensions.
Modern Stack: Developed using modern Python tooling (uv, ruff, pytest).
The package is available on PyPI:
pip install dpgmm

For a Triton-accelerated experience, you need to install Triton manually. For now, for development on HPC clusters, we recommend building this package from source using uv or pip.
Initialize the generator and the Gibbs sampler.
import torch
from dpgmm.datasets import GaussianDataGenerator
from dpgmm.samplers import FullCovarianceCollapsedGibbsSampler, DiagCovarianceCollapsedGibbsSampler
# 1. Generate synthetic data
data_generator = GaussianDataGenerator(cov_type="full")
data_payload = data_generator.generate(n_points=256, data_dim=2, num_components=4)
data_tensor = torch.as_tensor(data_payload["data"])
# 2. Initialize the Sampler
sampler = FullCovarianceCollapsedGibbsSampler(
init_strategy="init_data_stats",
max_clusters_num=10,
batch_size=1
)
# 3. Fit the model
result = sampler.fit(iterations_num=100, data=data_tensor)
# Access results
cluster_params = result["cluster_params"]
cluster_assignment = result["cluster_assignment"]
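As a quick sanity check on the fitted assignments (a sketch; it assumes `cluster_assignment` is a 1-D tensor of integer labels, and uses a stand-in tensor here so it runs standalone), you can count the points per cluster:

```python
import torch

# stand-in for result["cluster_assignment"]
assignment = torch.tensor([0, 0, 1, 2, 2, 2])

# unique labels and how many points landed in each cluster
labels, counts = torch.unique(assignment, return_counts=True)
for label, count in zip(labels.tolist(), counts.tolist()):
    print(f"cluster {label}: {count} points")
```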
alpha = result["alpha"]

Visualize the clusters, covariance matrices, and assignments.
from dpgmm.visualisation import ClusterParamsVisualizer
data_visualizer = ClusterParamsVisualizer()
data_visualizer.plot_params_full_covariance(
data_payload["data"],
centers=cluster_params["mean"],
cov_chol=cluster_params["cov_chol"],
assignment=cluster_assignment,
trace_alpha=alpha,
)

You can save checkpoints during training and resume from them later.
# To save during training, specify an out_dir
sampler.fit(iterations_num=25, data=data_tensor, out_dir="out/save_and_load")
# To resume, pass the path to the snapshot directory in kwargs
additional_kwargs = {"restore_snapshot_pkl_path": "out/save_and_load/"}
sampler_restored = FullCovarianceCollapsedGibbsSampler(
init_strategy="init_data_stats",
max_clusters_num=10,
batch_size=1,
**additional_kwargs,
)

Estimate entropy from sampling versus data to gauge model fit.
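For intuition about what such an estimator computes (a self-contained sketch using `torch.distributions`, independent of the dpgmm API): a Monte Carlo entropy estimate draws samples from a density p and averages -log p(x), which converges to the closed-form entropy.

```python
import math
import torch
from torch.distributions import Normal

# Monte Carlo entropy of a standard normal: H = 0.5 * log(2*pi*e) ≈ 1.4189
p = Normal(0.0, 1.0)
torch.manual_seed(0)
samples = p.sample((100_000,))

h_mc = -p.log_prob(samples).mean().item()          # sample average of -log p(x)
h_exact = 0.5 * math.log(2 * math.pi * math.e)     # closed-form entropy
print(h_mc, h_exact)
```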
from dpgmm.metrics import ComplexityFromTraceEstimator
estimator = ComplexityFromTraceEstimator(
trace_path="/path/to/results/cgs_19.pkl",
data_trace_path="/path/to/results/cgs_0.pkl",
samples_num=100_000,
)
entropy_sampled = estimator.estimate_entropy_with_sampling()
entropy_data = estimator.estimate_entropy_on_data(data_tensor)
print(f"Entropy from sampling: {entropy_sampled}")
print(f"Entropy on data: {entropy_data}")

Calculate the KL divergence between the joint distribution and the product of its marginals to measure feature entanglement.
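For intuition (a self-contained sketch, not dpgmm code): for a bivariate Gaussian with correlation rho, the KL divergence between the joint and the product of its marginals is the mutual information -0.5 * log(1 - rho^2), which you can verify with `torch.distributions`:

```python
import math
import torch
from torch.distributions import MultivariateNormal, kl_divergence

rho = 0.8
cov = torch.tensor([[1.0, rho], [rho, 1.0]])
joint = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
# product of marginals: same means and unit variances, zero correlation
marginals_prod = MultivariateNormal(torch.zeros(2), covariance_matrix=torch.eye(2))

dkl = kl_divergence(joint, marginals_prod).item()
closed_form = -0.5 * math.log(1 - rho**2)
print(dkl, closed_form)  # both ≈ 0.5108
```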
from dpgmm.metrics import EntanglementFromTraceEstimator
estimator = EntanglementFromTraceEstimator(
trace_path="/path/to/results/cgs_99.pkl",
samples_num=100_000
)
dkl_joint_prod = estimator.calculate_joint_and_prod_dkl()
dkl_symmetric = estimator.calculate_symmetric_dkl()
print(f"KL(Joint || Marginals Prod): {dkl_joint_prod:.4f}")
print(f"Symmetric KL: {dkl_symmetric:.4f}")

The sampler supports W&B out of the box for tracking loss curves, cluster evolution, and system metrics. To enable experiment tracking, just make sure to export the WANDB_API_KEY environment variable.
export WANDB_API_KEY=your_key_here

Thanks to Triton kernels, dpgmm achieves significant speedups over standard implementations, especially in high-dimensional experiments. The following table showcases the average iteration time (in seconds) for varying data dimensionalities:
| Data dim | PyTorch CPU [s] | Optimized GPU [s] | Speedup |
|---|---|---|---|
| 128 | | | |
| 256 | | | |
| 512 | | | |
| 1024 | | | |
| 2048 | | | |
| 4096 | | | |
| 8192 | | | |
This project uses uv for dependency management and Task (go-task) for orchestrating development workflows.
# Install dependencies
uv sync
# Install pre-commit hooks
uv run pre-commit install

A Taskfile.yml is provided to simplify common development commands:
# Run linter and formatter (Ruff)
uv run task lint
# Run security audits (Bandit & Safety)
uv run task audit
# Check code complexity (Xenon)
uv run task complexity
# Run all quality and security checks
uv run task check-all
# Build documentation
uv run task build-docs
# Run all tests
uv run task run-tests
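For reference, a minimal Taskfile.yml of roughly this shape could back those commands (illustrative only; the task names are from the commands above, but the actual definitions live in the repo's Taskfile.yml):

```yaml
version: '3'

tasks:
  lint:
    desc: Run linter and formatter (Ruff)
    cmds:
      - ruff check .
      - ruff format --check .

  run-tests:
    desc: Run the pytest suite
    cmds:
      - pytest
```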