Official implementation of MuViT (CVPR 2026), a vision transformer architecture for large-scale microscopy analysis that fuses true multi-resolution observations of the same image within a single encoder. For technical details, please check the preprint.
This repository contains the implementation of the MuViT architecture, along with the multi-resolution Masked Autoencoder (MAE) pre-training framework.
Modern microscopy routinely produces gigapixel images containing structures across multiple spatial scales, from fine cellular morphology to broader tissue organization. A central challenge in analyzing these images is that models must trade off effective context against spatial resolution. Standard CNNs or ViTs typically operate on single-resolution crops, building hierarchical feature pyramids from a single view.
To tackle this, MuViT jointly processes crops of the same image at different physical resolutions within a unified encoder. All patches are embedded into a shared world-coordinate system via rotary positional embeddings (RoPE), ensuring that the same physical location receives the same positional encoding across scales. This enables cross-resolution attention, allowing integration of wide-field context (e.g. anatomical) with high-resolution detail (e.g. cellular) for dense prediction tasks like segmentation.
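To make the shared world-coordinate idea concrete, the toy sketch below (an illustration, not MuViT code; the crop extents, patch counts, and helper name are made up) computes 1D patch-center positions in world coordinates for two crops at different resolutions. Because positions are expressed in absolute world pixels rather than per-crop indices, relative offsets between a wide-field patch and a high-resolution patch remain physically meaningful, which is what a RoPE-style relative encoding exploits.

```python
import numpy as np

def patch_centers_world(bbox_min, bbox_max, n_patches):
    """World-space center coordinates of a 1D row of patches spanning [bbox_min, bbox_max)."""
    patch_extent = (bbox_max - bbox_min) / n_patches
    return bbox_min + (np.arange(n_patches) + 0.5) * patch_extent

# High-resolution crop: world span [64, 96), 4 patches of 8 world pixels each
hi = patch_centers_world(64, 96, 4)   # -> [68, 76, 84, 92]
# Low-resolution crop of the same region: world span [0, 256), 4 patches of 64 world pixels each
lo = patch_centers_world(0, 256, 4)   # -> [32, 96, 160, 224]
```

Both coordinate sets live in the same world-pixel frame, so attention between the two crops can reason about true physical distances.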
Furthermore, MuViT extends the Masked Autoencoder (MAE) pre-training framework to a multi-resolution setting to learn powerful representations from unlabeled large-scale data. This produces highly informative, scale-consistent features that substantially accelerate convergence and improve sample efficiency on downstream tasks.
The software is compatible with Python 3.11, 3.12, and 3.13. It can be easily installed using pip:

```shell
pip install muvit
```

Note: MuViT has been tested on Linux and macOS. On macOS, GPU-accelerated MAE pre-training is disabled due to torch's lack of support for Dirichlet sampling on that backend, but the encoder can still be used with MPS for feature extraction and downstream tasks. We highly recommend pre-training on a Linux machine with a CUDA-capable GPU.
All PyTorch datasets to be used for MuViT should inherit from muvit.data.MuViTDataset, which will run sanity checks on e.g. the output format to ensure consistency. It requires implementing the following methods and properties (check the implementation of the MuViTDataset class for more details):
```python
from muvit.data import MuViTDataset


class MyMuViTDataset(MuViTDataset):
    def __init__(self):
        pass

    def __len__(self) -> int:
        # number of samples in the dataset
        return 42  # change accordingly

    @property
    def n_channels(self) -> int:
        # number of channels in the input images
        return 1  # change accordingly

    @property
    def levels(self) -> tuple[int, ...]:
        # resolution levels (in ascending order)
        return (1, 8, 32)  # change accordingly

    @property
    def ndim(self) -> int:
        # number of spatial dimensions
        return 2  # change accordingly

    def __getitem__(self, idx) -> dict:
        # should return a dictionary like
        return {
            "img": img,    # torch tensor of shape (L, C, Y, X)
            "bbox": bbox,  # torch tensor of shape (L, 2, Nd), where Nd is the number of spatial dimensions (e.g. 2)
        }
```

The bbox (bounding box) tensor defines the exact physical extent (field of view) of each image crop within a shared world-coordinate system, which we define as the highest-resolution pixel space. For a single dataset sample, it must have shape (L, 2, Nd): for each level, index 0 along the second axis contains the minimum coordinates (top-left, i.e., [y_min, x_min]) and index 1 contains the maximum coordinates (bottom-right, i.e., [y_max, x_max]). Providing these as accurately as possible is crucial, as MuViT relies on them to geometrically align the different resolutions.
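For concreteness, here is one way to build a valid bbox tensor (a sketch; the crop size, levels, and center point are arbitrary example values): three 64×64-pixel crops at levels (1, 8, 32), all centered on the same world location, so the level-l crop spans 64·l world pixels per axis.

```python
import torch

levels = (1, 8, 32)   # example resolution levels
crop_size = 64        # pixels per crop, at every level
center = torch.tensor([2048.0, 1536.0])  # arbitrary world-space (y, x) center

# bbox has shape (L, 2, Nd): min and max world coordinates per level
bbox = torch.stack([
    torch.stack([
        center - crop_size * level / 2,  # [y_min, x_min]
        center + crop_size * level / 2,  # [y_max, x_max]
    ])
    for level in levels
])

print(bbox.shape)  # torch.Size([3, 2, 2])
```

Note that every level's box is expressed in highest-resolution (level-1) pixel coordinates, so the level-32 crop covers a 2048×2048 world-pixel field of view.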
To pre-train an MAE model on your dataset, simply instantiate the MuViTMAE2d class and pass the dataloaders to its .fit method. Most of the parameters are customizable (e.g. number of layers, patch size, etc.); for more information, please check the implementation of the MuViTMAE2d class. We use PyTorch Lightning to handle the training logic.
For example:
```python
import torch

from muvit.data import MuViTDataset
from muvit.mae import MuViTMAE2d


class MyMuViTDataset(MuViTDataset):
    # implement the dataset as shown above
    pass


train_ds = MyMuViTDataset(args1)
val_ds = MyMuViTDataset(args2)

model = MuViTMAE2d(
    in_channels=train_ds.n_channels,
    levels=train_ds.levels,
    patch_size=8,
    num_layers=12,
    num_layers_decoder=4,
    # ... other parameters
)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False)

model.fit(train_dl, val_dl, output="/path/to/pretrained", num_epochs=100)  # ... other parameters
```

After pre-training the MAE model, you can use the encoder for downstream tasks or feature extraction. To get the encoder from the MAE pre-trained model, you can simply load it using our helper function and access it via the encoder attribute:
```python
from muvit.mae import MuViTMAE2d

encoder = MuViTMAE2d.from_folder("/path/to/pretrained").encoder
```

This returns a MuViTEncoder PyTorch module that is pluggable into any downstream pipeline. The encoder expects inputs in the same multi-resolution format as the dataset output described above.
The compute_features() method of an encoder runs a forward pass on a given multi-scale tensor and its corresponding bounding boxes and returns the features in a spatially structured format.
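One way to picture a "spatially structured format" is as one feature grid per resolution level. The toy sketch below (dummy features, not the actual compute_features() implementation; all sizes are made-up examples) shows the kind of reshaping involved: a flat sequence of patch tokens split back into (H/p, W/p) feature maps, keyed by level.

```python
import numpy as np

patch_size = 8
dim = 192                       # feature dimension (arbitrary example)
crop = 64                       # crop side length in pixels, per level
levels = (1, 8, 32)

grid = crop // patch_size       # 8 patches per side at every level
tokens_per_level = grid * grid  # 64 tokens per level

# dummy flat token sequence of shape (n_levels * tokens_per_level, dim)
tokens = np.random.default_rng(0).normal(size=(len(levels) * tokens_per_level, dim))

# split back into one (grid, grid, dim) feature map per level
feature_maps = {
    level: tokens[i * tokens_per_level:(i + 1) * tokens_per_level].reshape(grid, grid, dim)
    for i, level in enumerate(levels)
}

print({level: fm.shape for level, fm in feature_maps.items()})
```

Per-level grids like these can be fed directly into dense prediction heads (e.g. for segmentation).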
If you use this code for your research, please cite the following article:
```bibtex
@misc{dominguezmantes2026muvit,
  title={MuViT: Multi-Resolution Vision Transformers for Learning Across Scales in Microscopy},
  author={Albert Dominguez Mantes and Gioele La Manno and Martin Weigert},
  year={2026},
  eprint={2602.24222},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2602.24222},
}
```

