Utilities for embedding experiments with cross-platform array support
embedding_tools provides a backend-agnostic interface for working with embeddings across NumPy, MLX (Apple Silicon), and PyTorch. It includes memory management, configuration versioning, and similarity search utilities optimized for machine learning research.
- 🔄 Backend Abstraction: Seamlessly switch between NumPy, MLX, and PyTorch
- 💾 Memory Management: Track and limit memory usage with EmbeddingStore
- 🔍 Similarity Search: Built-in cosine similarity and nearest neighbor search
- 📦 Dimension Slicing: Efficient truncation for Matryoshka embeddings
- 🔐 Configuration Versioning: SHA-256 hashing for reproducible experiments
- 🍎 Apple Silicon Optimized: Native MLX support for M-series Macs
# Core (NumPy only)
pip install embedding_tools
# With MLX (Apple Silicon)
pip install embedding_tools[mlx]
# With PyTorch
pip install embedding_tools[torch]
# With JAX (GPU/TPU support)
pip install embedding_tools[jax]
# Everything
pip install embedding_tools[all]

To contribute or use the latest development version:
git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools
pip install -e ".[dev]"

from embedding_tools import get_backend
# Auto-detect best available backend
backend = get_backend() # Uses MLX > JAX > PyTorch > NumPy
# Or specify explicitly
backend = get_backend('numpy') # CPU
backend = get_backend('mlx') # Apple Silicon GPU (fastest on Mac)
backend = get_backend('jax') # GPU/TPU with JIT compilation
backend = get_backend('torch') # PyTorch (CUDA/MPS/CPU)
# Create arrays
embeddings = backend.create_array([[1, 2, 3], [4, 5, 6]])
# Compute similarities
query = backend.create_array([1, 2, 3])
sims = backend.cosine_similarity(query, embeddings)
# Slice to lower dimensions (for Matryoshka embeddings)
truncated = backend.slice_last_dim(embeddings, dim=2)

from embedding_tools import EmbeddingStore
import numpy as np
# Create store with memory limit
store = EmbeddingStore(backend='mlx', max_memory_gb=10.0)
# Add embeddings
embeddings_1024d = np.random.randn(10000, 1024).astype(np.float32)
store.add_embeddings(embeddings_1024d, dimension=1024)
# Slice to lower dimensions (Matryoshka)
embeddings_128d = store.slice_to_dimension(source_dim=1024, target_dim=128)
# Similarity search
query = np.random.randn(1024).astype(np.float32)
similarities, indices = store.compute_similarity(
query,
dimension=1024,
top_k=10
)
# Check memory usage
info = store.get_memory_info()
print(f"Total memory: {info['total_gb']:.2f} GB")

from embedding_tools import compute_config_hash, compute_param_hash
# Hash a configuration dict
config = {
'model': 'sentence-transformers/all-MiniLM-L6-v2',
'dimension': 384,
'batch_size': 32
}
hash_val = compute_config_hash(config) # Returns 16-char hex string
# Or use keyword arguments
hash_val = compute_param_hash(
model='all-MiniLM-L6-v2',
dimension=384,
batch_size=32
)
# Use for automatic cache invalidation
cache_key = f"embeddings_{hash_val}.npz"

| Backend | Hardware | Speed | Memory | Installation |
|---|---|---|---|---|
| NumPy | CPU | 1x | System RAM | pip install embedding_tools |
| MLX | Apple Silicon GPU | 3-5x | Unified memory | pip install embedding_tools[mlx] |
| JAX | GPU/TPU (Metal/CUDA/ROCm) | 5-10x* | GPU VRAM | pip install embedding_tools[jax] |
| PyTorch | CUDA/MPS/CPU | 2-4x | GPU VRAM | pip install embedding_tools[torch] |
*Speed with JIT compilation on repeated operations
Device Options for PyTorch:
- device='cuda': NVIDIA GPUs (Linux/Windows)
- device='mps': Apple Silicon GPU (macOS)
- device='cpu': CPU fallback (all platforms)
Device Options for JAX:
- device='gpu': GPU acceleration (Metal/CUDA/ROCm)
- device='cpu': CPU fallback
- device=None: Auto-detection (recommended)
# Explicit device configuration
from embedding_tools import get_backend, EmbeddingStore
# PyTorch: CUDA for NVIDIA GPUs (Linux production)
backend = get_backend('torch', device='cuda')
store = EmbeddingStore(backend='torch', max_memory_gb=40.0, device='cuda')
# PyTorch: MPS for Apple Silicon
backend = get_backend('torch', device='mps')
store = EmbeddingStore(backend='torch', max_memory_gb=20.0, device='mps')
# JAX: GPU acceleration (auto-detects Metal/CUDA/ROCm)
backend = get_backend('jax', device='gpu')
store = EmbeddingStore(backend='jax', max_memory_gb=20.0, device='gpu')
# Auto-detection (recommended)
backend = get_backend('torch') # Automatically picks best device
backend = get_backend('jax')    # Automatically picks best device

Run validation tests after installation:
pytest tests/test_installation.py -v

Or run directly:
python tests/test_installation.py

Expected output:
============================================================
embedding_tools Installation Validation Summary
============================================================
Version: 0.1.0
NumPy backend: ✓ Available
MLX backend: ✓ Available
Auto-detected backend: MLXBackend
All core functionality tests passed!
============================================================
# Clone repository
git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools
# Install in development mode
pip install -e ".[dev]"
# Run tests
pytest tests/ -v
# Format code
black .
isort .
# Lint
flake8 embedding_tools/

Get array backend instance.
Parameters:
- backend_name (str, optional): 'numpy', 'mlx', 'jax', or 'torch'. Auto-detects if None.
- device (str, optional): Device specification for JAX/PyTorch backends. Auto-detects if None.
Returns: ArrayBackend instance
- create_array(data, dtype=None) - Create array from data
- zeros(shape, dtype=None) - Create zero-filled array
- ones(shape, dtype=None) - Create one-filled array
- random_normal(shape, mean=0.0, std=1.0) - Random normal array
- dot(a, b) - Dot product
- cosine_similarity(a, b) - Cosine similarity matrix
- normalize(a, axis=-1) - L2 normalization
- concatenate(arrays, axis=0) - Concatenate arrays
- stack(arrays, axis=0) - Stack arrays
- slice_last_dim(array, dim) - Slice to dimension
- to_numpy(array) - Convert to NumPy
- from_numpy(array) - Convert from NumPy
- save(array, filepath) - Save to file
- load(filepath) - Load from file
- get_memory_usage(array) - Memory in bytes
- get_shape(array) - Array shape
- get_dtype(array) - Array dtype
In-memory embedding storage with memory limits.
Methods:
- add_embeddings(embeddings, dimension, text_ids=None, labels=None, metadata=None) - Add embeddings
- get_embeddings(dimension) - Retrieve embeddings
- slice_to_dimension(source_dim, target_dim) - Matryoshka slicing
- compute_similarity(query_emb, dimension, top_k=None) - Similarity search
- get_available_dimensions() - List stored dimensions
- get_total_memory_usage() - Total memory in bytes
- get_memory_info() - Detailed memory statistics
- save_to_disk(directory) - Save all embeddings
- load_from_disk(directory) - Load all embeddings
Compute SHA-256 hash of configuration dictionary.
Parameters:
config(dict): Configuration dictionary
Returns: 16-character hex string
Convenience function for hashing keyword arguments.
Returns: 16-character hex string
from embedding_tools import EmbeddingStore, get_backend
backend = get_backend('mlx')
store = EmbeddingStore(backend='mlx', max_memory_gb=20)
# Train model to produce 1024D embeddings
full_embeddings = model.encode(documents) # (N, 1024)
store.add_embeddings(full_embeddings, dimension=1024)
# Get truncated versions for different use cases
embeddings_512 = store.slice_to_dimension(1024, 512) # Moderate accuracy
embeddings_128 = store.slice_to_dimension(1024, 128) # Fast search
embeddings_32 = store.slice_to_dimension(1024, 32) # Ultra-fast
# Compare at different dimensions
for dim in [32, 128, 512, 1024]:
sims, indices = store.compute_similarity(query, dim, top_k=10)
print(f"{dim}D recall@10: {compute_recall(indices, ground_truth)}")

from embedding_tools import get_backend
# Development on Mac (uses MLX for speed)
if platform.system() == 'Darwin':
backend = get_backend('mlx')
# Production on Linux (uses NumPy or CUDA)
else:
backend = get_backend('numpy')
# Same code works everywhere
embeddings = backend.create_array(data)
similarities = backend.cosine_similarity(query, embeddings)

from embedding_tools import compute_param_hash
import os
# Compute hash of experiment parameters
exp_hash = compute_param_hash(
model='all-MiniLM-L6-v2',
chunk_size=512,
overlap=50,
dimension=384
)
# Check if results exist
results_file = f'results_{exp_hash}.json'
if os.path.exists(results_file):
print("Loading cached results...")
results = load_results(results_file)
else:
print("Running new experiment...")
results = run_experiment()
save_results(results, results_file)

MIT License - see LICENSE file for details.
Contributions welcome! Please read CONTRIBUTING.md for guidelines.
If you use embedding_tools in your research, please cite:
@software{embedding_tools2025,
title = {embedding_tools: Utilities for embedding experiments},
author = {Nitin Borwankar},
year = {2025},
url = {https://github.com/nborwankar/embedding_tools}
}