SPROUT is a multi-crop, multi-task agricultural foundation model trained via diffusion denoising.
This repository currently supports three types of downstream tasks:
- Semantic Segmentation
- Monocular Depth Estimation
- Counting
Refer to https://github.com/illrayy/DODA to set up the environment.
Download the pretrained checkpoint and place it under `pretrained_models/`.
| Model | Params | Iterations | Download |
|---|---|---|---|
| SPROUT-S | 51M | 300k | Google Drive |
| SPROUT-B | 112M | 450k | Google Drive |
| SPROUT-L | 361M | 700k | Google Drive |
```
your_dataset/
├── train/
│   ├── images/   # RGB images (e.g., 001.png, 002.png, ...)
│   └── masks/    # Segmentation masks (same filenames as images)
├── val/
│   ├── images/
│   └── masks/
└── test/
    ├── images/
    └── masks/
```
- Images and masks must share identical filenames (e.g., `001.png` in both `images/` and `masks/`).
- Training images must be square (height == width).
- Masks should contain integer class IDs starting from 0. Use `255` as the ignore label.
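Before training, it can be worth verifying that every image has a matching mask. A minimal sketch of such a check (this helper is not part of the repository; it just encodes the filename-pairing rule above):

```python
def find_unpaired(image_names, mask_names):
    """Return filenames that appear in images/ or masks/ but not both.

    An empty result means every image has a same-named mask, as the
    dataset layout requires.
    """
    images, masks = set(image_names), set(mask_names)
    # Symmetric difference: files missing a counterpart on either side.
    return sorted(images ^ masks)
```

You would feed it the results of `os.listdir()` on the two folders of each split; a non-empty return value pinpoints the offending files.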
```
your_dataset/
├── train/
│   ├── images/   # RGB images
│   └── depth/    # Depth maps (same filenames as images)
└── val/
    ├── images/
    └── depth/
```
- Images and depth maps must share identical filenames.
- Depth maps should be 16-bit PNG files with values in millimeters (the code divides by 1000 to convert to meters).
- Depth values outside the valid range (`min_depth` to `max_depth`, in meters) are treated as invalid.
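The millimeter-to-meter conversion and range check described above can be sketched per pixel value like this (a simplified illustration, not the repository's actual loader; the default bounds here are placeholders):

```python
def png_depth_to_meters(raw_value, min_depth=0.001, max_depth=10.0):
    """Convert a 16-bit PNG depth value (millimeters) to meters.

    Returns None for values outside [min_depth, max_depth], mirroring
    how out-of-range depths are treated as invalid.
    """
    meters = raw_value / 1000.0  # 16-bit PNG stores millimeters
    if not (min_depth <= meters <= max_depth):
        return None
    return meters
```

In practice the conversion is applied to the whole depth array at once, with invalid pixels masked out of the loss.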
```
your_dataset/
├── train/
│   ├── images/     # RGB images
│   └── train.txt   # Annotation file
├── val/
│   ├── images/
│   └── val.txt
└── test/
    ├── images/
    └── test.txt
```
Annotation file format (space-separated, one sample per line):
```
image001.png 42
image002.png 15
image003.png 87
```
Each line contains the image filename and its corresponding numeric label.
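Reading this format is straightforward; a minimal parser for the space-separated annotation lines might look like the following (an illustrative sketch, not the repository's dataset class):

```python
def parse_annotations(lines):
    """Parse 'filename label' lines into (name, int_label) pairs.

    Blank lines are skipped; each remaining line must contain exactly
    one filename and one numeric label, separated by whitespace.
    """
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        name, label = line.split()
        samples.append((name, int(label)))
    return samples
```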
All training scripts use hardcoded hyperparameters in their `__main__` block. Before running, open the script and modify the following variables:
| Variable | Description |
|---|---|
| `dataset_path` | Path to your dataset root directory |
| `num_classes` | Number of segmentation classes (segmentation only) |
| `max_number` | Maximum regression target value (regression only) |
| `min_depth` / `max_depth` | Valid depth range in meters (depth only) |
| `weight` | Path to pretrained checkpoint |
| `configs` | Path to model config YAML |
| `input_shape` | Input resolution (default: 256 for seg/depth, 384 for regression) |
| `time_step` | Diffusion timestep (default: 50 for seg/depth, 10 for regression) |
| `lr` | Learning rate |
| `batch_size` | Batch size |
| `total_iters` | Total training iterations |
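For orientation, the edited variables might look like this (variable names come from the table above; every value, path, and filename below is an illustrative assumption, not the scripts' actual defaults):

```python
# Illustrative hyperparameter assignments, as they would appear in a
# training script's __main__ block. Paths and values are placeholders.
dataset_path = "your_dataset"
num_classes = 3                               # segmentation only
weight = "pretrained_models/sprout_b.ckpt"    # hypothetical filename
configs = "configs/sprout_b.yaml"             # hypothetical filename
input_shape = 256                             # 384 for regression
time_step = 50                                # 10 for regression
lr = 1e-4                                     # illustrative
batch_size = 8                                # illustrative
total_iters = 300000                          # illustrative
```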
Edit `train_segmentation.py` and set your parameters, then run:

```shell
python train_segmentation.py
```

- Loss: Focal Loss + Lovász Loss
- LR schedule: Cosine with warmup
- Validation: Sliding-window mIoU evaluation
- Checkpoints: Saved to `logs/segmentation/`
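A cosine schedule with linear warmup is a standard formulation; one common version looks like this (a sketch of the general technique, which may differ in detail from the script's schedule):

```python
import math

def cosine_warmup_lr(step, total_iters, base_lr, warmup_iters=1000):
    """Learning rate at a given step: linear ramp for warmup_iters
    steps, then a cosine decay from base_lr down to zero."""
    if step < warmup_iters:
        return base_lr * step / warmup_iters
    progress = (step - warmup_iters) / max(1, total_iters - warmup_iters)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```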
Edit `train_depth_estimation.py` and set your parameters, then run:

```shell
python train_depth_estimation.py
```

- Loss: SiLog Loss
- Validation: RMSE and other depth metrics with horizontal flip TTA
- Checkpoints: Saved to `logs/depth/`
Edit `train_regression.py` and set your parameters, then run:

```shell
python train_regression.py
```

- Loss: MSE Loss
- Output mapping: `sigmoid(output) * max_number`
- Validation: R², MAE, MSE, RMSE, MAPE
- Checkpoints: Saved to `logs/regression/`
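The `sigmoid(output) * max_number` mapping squashes the unbounded network output into `(0, max_number)`, so predicted counts can never exceed the configured maximum. A one-line sketch of the mapping:

```python
import math

def map_output(raw_output, max_number):
    """Map a raw (unbounded) network output to a count in
    (0, max_number) via sigmoid(raw_output) * max_number."""
    return max_number / (1.0 + math.exp(-raw_output))
```

For example, a raw output of 0 maps to exactly half of `max_number`, and very negative or very positive outputs saturate toward 0 and `max_number` respectively.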
Edit `segmentation_ms_inference.py` to set `dataset_path`, `weight`, `num_classes`, etc., then run:

```shell
python segmentation_ms_inference.py
```

- Uses multi-scale inference at scales `[0.75, 1.0, 1.25]` with a sliding window (`window_size=256`, `step_size=128`).
- Applies horizontal flip TTA (test-time augmentation).
- Reports mIoU on the test set.
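With `window_size=256` and `step_size=128`, the sliding window covers each axis with 50% overlap, clamping the last window to the image border. A sketch of how the window offsets along one axis could be enumerated (illustrative only, not the repository's implementation):

```python
def sliding_windows(size, window_size=256, step_size=128):
    """Top-left offsets of sliding windows along one axis.

    Windows advance by step_size; the final window is clamped so it
    ends exactly at the image border instead of running past it.
    """
    positions = []
    pos = 0
    while True:
        if pos + window_size >= size:
            positions.append(max(0, size - window_size))
            break
        positions.append(pos)
        pos += step_size
    return positions
```

Predictions from overlapping windows (and from each scale and flip) are typically averaged per pixel before the final argmax.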
Edit `regression_inference.py` to set `dataset_path`, `weight`, `max_number`, etc., then run:

```shell
python regression_inference.py
```

- Applies horizontal flip TTA.
- Reports R², MAE, MSE, and RMSE on the test set.
Edit `visiualize_depth.py` to set `weight`, `min_depth`, `max_depth`, etc., then run:

```shell
python visiualize_depth.py
```

- Place input images in `visualization/depth/input/`.
- Colorized depth maps (rainbow colormap) are saved to `visualization/depth/output/{model_name}/`.