Unofficial training code for SegGPT.

- This implementation is based on my understanding of the paper and on reverse-engineering the official implementation.
- I have tested this code by fine-tuning on the OEM dataset and got promising results. However, there may still be bugs or mistakes in the code; feel free to raise an issue.
- Fine-tuning from the provided checkpoint requires a lot of GPU memory (at least 24GB), as this trains the whole ViT-16 backbone. Consider using a smaller batch size or a smaller model overall; one generic workaround is sketched after this list. I might implement LoRA-based training to support smaller VRAM budgets in the future if this repo gains enough traction.
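
One generic way to reduce memory (not something this repo implements) is to freeze most of the backbone and fine-tune only the last few transformer blocks. A minimal PyTorch sketch, assuming a timm-style ViT that exposes its layers as `model.blocks` (the attribute name is an assumption; adapt it to this repo's model):

```python
import torch

def freeze_early_blocks(model: torch.nn.Module, n_trainable: int = 4) -> None:
    """Freeze everything except the last `n_trainable` transformer blocks."""
    for p in model.parameters():
        p.requires_grad_(False)
    # `model.blocks` is assumed here (timm-style ViT); rename as needed.
    for block in model.blocks[-n_trainable:]:
        for p in block.parameters():
            p.requires_grad_(True)
```

Passing only the trainable parameters to the optimizer (e.g. `filter(lambda p: p.requires_grad, model.parameters())`) also keeps the optimizer state small.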
This code is developed with Python 3.9. Install the required packages by running:

```
pip install -r requirements.txt
```

Alternatively, create a new conda environment and install the required packages by running:

```
conda env create -f env.yml
```

Set up your dataset directory as follows:
```
<root_dataset_path>
├── images
│   ├── image1.tif
│   ├── image2.tif
│   ...
└── labels
    ├── image1.tif
    ├── image2.tif
    ...
```
Note:
- Images and labels must have the same name and extension (or you can modify `data.py` to suit your needs).
- The extension does not have to be `.tif`, as long as the file can be loaded with the `PIL` library.
- The label is a single-channel image where each pixel value represents the class of that pixel; a quick format check is sketched after this list.
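
If you are unsure whether your labels match the expected format, here is a minimal sketch that checks one file (the path is a placeholder; `PIL` and `numpy` are the only dependencies):

```python
import numpy as np
from PIL import Image

def check_label(path: str) -> None:
    """Verify that a label file is single-channel and list its class values."""
    arr = np.array(Image.open(path))
    # A single-channel label loads as a 2-D (H, W) array.
    assert arr.ndim == 2, f"expected a single-channel label, got shape {arr.shape}"
    print(f"{path}: shape={arr.shape}, classes={np.unique(arr).tolist()}")

check_label("<root_dataset_path>/labels/image1.tif")
```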
Create a `.json` config file. You can use the provided `configs/base.json` as a template. Then, run:

```
python train.py --config <path_to_json_config>
```

Training uses the DDP strategy and utilizes all available GPUs by default. You can restrict which GPUs are used by setting the `CUDA_VISIBLE_DEVICES` environment variable.
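
For example, to train on only the first two GPUs with the provided template config:

```
CUDA_VISIBLE_DEVICES=0,1 python train.py --config configs/base.json
```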
You can also launch TensorBoard to monitor the training progress:

```
tensorboard --logdir logs
```

In the paper, the authors mention using a learnable tensor for in-context tuning. You can find my implementation of this in `model.py`; a sketch of the general idea follows.
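
The core idea is to keep the model frozen and optimize an image-shaped tensor that plays the role of the prompt. Below is a minimal, self-contained sketch of that idea, not the code in `model.py`; the 448×448 resolution and all names are assumptions:

```python
import torch
import torch.nn as nn

class LearnablePrompt(nn.Module):
    """A learnable tensor that replaces a fixed prompt image for in-context tuning.

    Shapes are illustrative (a SegGPT-style 448x448 input is assumed here);
    see model.py for the actual implementation used in this repo.
    """

    def __init__(self, channels: int = 3, height: int = 448, width: int = 448):
        super().__init__()
        # Small random init; the task loss drives this tensor toward a
        # prompt that best conditions the frozen model.
        self.prompt = nn.Parameter(torch.randn(1, channels, height, width) * 0.02)

    def forward(self, batch_size: int) -> torch.Tensor:
        # Broadcast the single learned prompt across the batch.
        return self.prompt.expand(batch_size, -1, -1, -1)

# During in-context tuning, only the prompt gets gradients:
# model.requires_grad_(False)
# optimizer = torch.optim.AdamW(prompt.parameters(), lr=1e-3)
```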