MuViT: Multi-Resolution Vision Transformers for Learning Across Scales in Microscopy

Official implementation of MuViT (CVPR 2026), a vision transformer architecture for large-scale microscopy analysis that fuses true multi-resolution observations of the same image within a single encoder. For technical details, check the preprint.

This repository contains the implementation of the MuViT architecture, along with the multi-resolution Masked Autoencoder (MAE) pre-training framework.

Overview

Figure: MuViT overview

Modern microscopy routinely produces gigapixel images containing structures across multiple spatial scales, from fine cellular morphology to broader tissue organization. A central challenge in analyzing these images is that models must trade off effective context against spatial resolution. Standard CNNs or ViTs typically operate on single-resolution crops, building hierarchical feature pyramids from a single view.

To tackle this, MuViT jointly processes crops of the same image at different physical resolutions within a unified encoder. All patches are embedded into a shared world-coordinate system via rotary positional embeddings (RoPE), ensuring that the same physical location receives the same positional encoding across scales. This enables cross-resolution attention, allowing integration of wide-field context (e.g. anatomical) with high-resolution detail (e.g. cellular) for dense prediction tasks like segmentation.
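The world-coordinate idea can be illustrated with a minimal 1D rotary-embedding sketch (this is not MuViT's actual implementation; the function and parameters below are purely illustrative): because the rotation angles depend only on the physical position, a patch centred at a given world coordinate receives the same encoding regardless of which resolution level's crop it came from.

```python
import torch

def rope_rotation(world_pos: float, dim: int = 8, base: float = 100.0) -> torch.Tensor:
    """Rotary angles for a single 1D world coordinate.

    Each pair of embedding channels is rotated by world_pos * freq_k,
    so the encoding depends only on the physical position, not on the
    crop or resolution level the patch came from.
    """
    freqs = base ** (-torch.arange(0, dim, 2).float() / dim)  # (dim/2,) frequencies
    angles = world_pos * freqs
    return torch.cat([angles.cos(), angles.sin()])

# A patch centred at world coordinate 96 gets the same positional
# encoding whether it sits in a fine (level-1) or coarse (level-8)
# crop, because both crops index into the same world-coordinate system.
enc_from_fine = rope_rotation(96.0)
enc_from_coarse = rope_rotation(96.0)
assert torch.allclose(enc_from_fine, enc_from_coarse)
```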

Furthermore, MuViT extends the Masked Autoencoder (MAE) pre-training framework to a multi-resolution setting to learn powerful representations from unlabeled large-scale data. This produces highly informative, scale-consistent features that substantially accelerate convergence and improve sample efficiency on downstream tasks.
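The core MAE mechanic, randomly masking most patch tokens and training the model to reconstruct them, can be sketched in a few lines. This is a generic single-resolution illustration of MAE-style masking, not MuViT's multi-resolution sampler:

```python
import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of tokens, as in standard MAE pre-training.

    tokens: (B, N, D) patch embeddings; returns the kept tokens and the
    random permutation needed to restore token order for the decoder.
    """
    B, N, D = tokens.shape
    n_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)            # one random score per token
    ids_shuffle = noise.argsort(dim=1)  # random permutation of token indices
    ids_keep = ids_shuffle[:, :n_keep]  # tokens the encoder actually sees
    kept = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return kept, ids_shuffle

tokens = torch.randn(2, 16, 32)
kept, _ = random_masking(tokens)
assert kept.shape == (2, 4, 32)  # 75% of the 16 tokens are masked out
```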

Installation

The software is compatible with Python versions 3.11, 3.12 and 3.13. It can be easily installed using pip:

pip install muvit

Note: MuViT has been tested on Linux and macOS systems. On macOS, GPU-accelerated MAE pre-training is disabled because the MPS backend does not support Dirichlet sampling, but the encoder can still be used with MPS for feature extraction and downstream tasks. We highly recommend pre-training on a Linux machine with a CUDA-capable GPU.

Usage

Creating a MuViT dataset

All PyTorch datasets to be used with MuViT should inherit from muvit.data.MuViTDataset, which runs sanity checks (e.g. on the output format) to ensure consistency. It requires implementing the following methods and properties (see the implementation of the MuViTDataset class for more details):

from muvit.data import MuViTDataset

class MyMuViTDataset(MuViTDataset):
    def __init__(self):
        pass

    def __len__(self) -> int:
        # number of samples in the dataset
        return 42  # change accordingly

    @property
    def n_channels(self) -> int:
        # number of channels in the input images
        return 1  # change accordingly

    @property
    def levels(self) -> tuple[int, ...]:
        # resolution levels, in ascending order
        return (1, 8, 32)  # change accordingly

    @property
    def ndim(self) -> int:
        # number of spatial dimensions
        return 2  # change accordingly

    def __getitem__(self, idx: int) -> dict:
        # should return a dictionary like
        return {
            "img": img,    # torch tensor of shape (L, C, Y, X)
            "bbox": bbox,  # torch tensor of shape (L, 2, Nd), where Nd is the number of spatial dimensions (e.g. 2)
        }

Bounding box format

The bbox (bounding box) tensor defines the exact physical extent (field of view) of each image crop within a shared world-coordinate system, which we define as the highest-resolution pixel space. For a single dataset sample, it must have the shape $(L, 2, N_d)$, where $L$ is the number of resolution levels and $N_d$ is the number of spatial dimensions (e.g., 2). The second dimension, always of size 2, represents the boundaries of the crop: index 0 holds the minimum coordinates (top-left, i.e., [y_min, x_min]) and index 1 holds the maximum coordinates (bottom-right, i.e., [y_max, x_max]). Providing these as accurately as possible is crucial, as MuViT relies on them to geometrically align the different resolutions.
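As a concrete sketch, bounding boxes for crops co-centred at the same world position at levels (1, 8, 32) might be built as follows. The crop size and centre below are made-up values; the real ones depend on your data:

```python
import torch

levels = (1, 8, 32)  # downsampling factors, ascending
crop = 128           # crop side length in pixels, identical at every level
center = torch.tensor([4096.0, 4096.0])  # (y, x) in level-1 world coordinates

# At level l, a crop of `crop` pixels spans crop * l world-coordinate units,
# so coarser levels cover a wider field of view around the same centre.
bbox = torch.stack([
    torch.stack([center - crop * l / 2, center + crop * l / 2])
    for l in levels
])  # shape (L, 2, Nd) = (3, 2, 2)

assert bbox.shape == (len(levels), 2, 2)
# level-1 crop spans 128x128 world units; level-32 crop spans 4096x4096
assert (bbox[0, 1] - bbox[0, 0]).tolist() == [128.0, 128.0]
assert (bbox[2, 1] - bbox[2, 0]).tolist() == [4096.0, 4096.0]
```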

Multiscale MAE pre-training

To pre-train an MAE model on your dataset, instantiate the MuViTMAE2d class and pass the dataloaders to its .fit method. Most parameters are customizable (e.g. number of layers, patch size, etc.); for more information, check the implementation of the MuViTMAE2d class. We use PyTorch Lightning to handle the training logic. For example:

import torch

from muvit.data import MuViTDataset
from muvit.mae import MuViTMAE2d

class MyMuViTDataset(MuViTDataset):
    # implement the dataset as shown above
    pass

train_ds = MyMuViTDataset(args1)
val_ds = MyMuViTDataset(args2)

model = MuViTMAE2d(
    in_channels=train_ds.n_channels,
    levels=train_ds.levels,
    patch_size=8,
    num_layers=12,
    num_layers_decoder=4,
    ... # other parameters
)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False)
model.fit(train_dl, val_dl, output="/path/to/pretrained", num_epochs=100, ...)

Using a pre-trained encoder

After pre-training the MAE model, you can use the encoder for downstream tasks or feature extraction. To get the encoder from the pre-trained MAE model, load it using our helper function and access it via the encoder attribute:

from muvit.mae import MuViTMAE2d

encoder = MuViTMAE2d.from_folder("/path/to/pretrained").encoder

which returns a MuViTEncoder PyTorch module that can be plugged into any downstream pipeline. The encoder expects an input tensor of shape $(B,L,C,Y,X)$ (where $L$ denotes the number of resolution levels) along with the world coordinates, given as a "bounding-box" tensor of shape $(B,L,2,2)$ (for 2D). Note that omitting an explicit bounding box may cause undefined behaviour. The output of the encoder is a tensor of shape $(B,N,D)$, where $N$ is the number of tokens and $D$ is the embedding dimension.

The method compute_features() of an encoder will run a forward pass on a given multi-scale tensor and corresponding bounding boxes and return the features in a spatially structured format $(B,L,D,H',W')$ where $H'=\frac{H}{P}$ and $W'=\frac{W}{P}$, with $P$ being the patch size.
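As a quick shape check (plain arithmetic, with made-up batch size, level count, and embedding dimension): for square 128-pixel crops and patch size 8, the spatial grid is $128/8 = 16$ per side, so compute_features() would return a $(B, L, D, 16, 16)$ tensor.

```python
# Shape bookkeeping for compute_features(), assuming square 128-pixel
# crops, patch size 8, batch size 2, three levels, embedding dim 384.
B, L, D = 2, 3, 384
H = W = 128
P = 8
H_prime, W_prime = H // P, W // P  # spatial grid of patch features
feature_shape = (B, L, D, H_prime, W_prime)
assert feature_shape == (2, 3, 384, 16, 16)
```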

Citation

If you use this code for your research, please cite the following article:

@misc{dominguezmantes2026muvit,
      title={MuViT: Multi-Resolution Vision Transformers for Learning Across Scales in Microscopy}, 
      author={Albert Dominguez Mantes and Gioele La Manno and Martin Weigert},
      year={2026},
      eprint={2602.24222},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2602.24222}, 
}
