Deep learning models and training configuration for semantic segmentation in marine ecology applications, specifically kelp and mussel detection from aerial/drone imagery. This is the training component of the Kelp-o-Matic ecosystem.
This repository provides PyTorch Lightning-based training infrastructure for computer vision models used in marine ecology research. The models perform semantic segmentation on aerial imagery to detect kelp forests and mussel beds.
Key Features:
- Configurable training pipelines via YAML configs
- Support for multiple architectures (UNet++, SegFormer, UperNet, etc.) via segmentation-models-pytorch
- Integration with Weights & Biases for experiment tracking
- Model export to ONNX and TorchScript for production deployment
- Built-in data augmentation and preprocessing with Albumentations
- Efficient preprocessing pipeline for GeoTIFF imagery
- Python 3.12+
- uv package manager
1. Clone the repository:

   ```bash
   git clone <repository-url>
   cd hakai-ml-train
   ```

2. Install dependencies with uv:

   ```bash
   # Installs dependencies and creates a virtual environment
   uv sync

   # For development (includes Jupyter, pre-commit, etc.)
   uv sync --all-groups
   ```

3. Activate the virtual environment:

   ```bash
   source .venv/bin/activate   # On macOS/Linux
   # or
   .venv\Scripts\activate      # On Windows
   ```

4. Set up pre-commit hooks (optional but recommended):

   ```bash
   pre-commit install
   ```
The training pipeline expects preprocessed image chips in NPZ format. Follow these steps to prepare your dataset from GeoTIFF imagery.
Your raw data should follow this structure:
```
raw_data/
├── train/
│   ├── images/
│   │   ├── mosaic_01.tif
│   │   ├── mosaic_02.tif
│   │   └── ...
│   └── labels/
│       ├── mosaic_01.tif
│       ├── mosaic_02.tif
│       └── ...
├── val/
│   ├── images/
│   │   └── ...
│   └── labels/
│       └── ...
└── test/
    ├── images/
    │   └── ...
    └── labels/
        └── ...
```
Requirements:
- Images: GeoTIFF files (.tif) with proper georeferencing
- Labels: GeoTIFF files with the same spatial extent and resolution as the images (see the alignment check below)
- Naming: Image and label files must have matching filenames
- Pixel values:
  - Labels should use integer class values (e.g., 0=background, 1=kelp, 2=nereo)
  - Images can be uint8 (0-255) for RGB or uint16 for multispectral
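Because mismatched extents or resolutions are a common source of silent errors, it can be worth verifying image/label alignment before chipping. A minimal sketch using rasterio (not a script in this repo; the file paths are illustrative):

```python
import rasterio

# Hypothetical image/label pair; names must match across the two folders
image_path = "raw_data/train/images/mosaic_01.tif"
label_path = "raw_data/train/labels/mosaic_01.tif"

with rasterio.open(image_path) as img, rasterio.open(label_path) as lbl:
    # Same CRS, pixel size, and spatial extent are required
    assert img.crs == lbl.crs, "CRS mismatch"
    assert img.res == lbl.res, "Resolution mismatch"
    assert img.bounds == lbl.bounds, "Extent mismatch"
    print(f"{image_path}: {img.count} bands, dtype={img.dtypes[0]}")
```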
Use the make_chip_dataset.py script to tile your GeoTIFF mosaics into smaller chips:
```bash
python -m src.prepare.make_chip_dataset <raw_data_dir> <output_dir> \
    --size 224 \
    --stride 224 \
    --num_bands 3 \
    --remap 0 -100 1 2
```

Parameters:

- `<raw_data_dir>`: Path to directory containing train/val/test folders
- `<output_dir>`: Where to save preprocessed NPZ chips
- `--size`: Size of square chips (default: 224)
- `--stride`: Stride for chip extraction (default: 224, equals `--size` for no overlap)
- `--num_bands`: Number of image bands to keep (3 for RGB, 4 for RGBI, etc.)
- `--remap`: Label value remapping as a list where index = old value, value = new value (see the sketch after this list)
  - Format: `new_0 new_1 new_2 new_3 ...` (position in list = old label value)
  - Example: `0 1 0 0 -100` means: 0→0 (bg), 1→1 (keep), 2→0 (remap to bg), 3→0 (remap to bg), 4→-100 (ignore)
  - Example: `0 -100 1 2` means: 0→0 (bg), 1→-100 (ignore), 2→1 (class 1), 3→2 (class 2)
  - Use `-100` for pixels to ignore during training
- `--dtype`: Data type for image values (default: uint8)
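The remap list is applied positionally. A minimal numpy sketch of the equivalent lookup (this mirrors the documented semantics; the script's internals may differ):

```python
import numpy as np

# --remap 0 -100 1 2  =>  old label value i maps to remap[i]
remap = np.array([0, -100, 1, 2])

labels = np.array([[0, 1], [2, 3]])   # toy label chip
remapped = remap[labels]              # fancy-indexing lookup, index = old value
print(remapped)
# [[   0 -100]
#  [   1    2]]
```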
Example (Binary Kelp Detection):

```bash
# Assuming labels: 0=background, 1=noise, 2=kelp
# Remap to: 0→0 (bg), 1→-100 (ignore noise), 2→1 (kelp)
python -m src.prepare.make_chip_dataset \
    /data/kelp_raw \
    /data/kelp_chips_224 \
    --size 224 \
    --stride 224 \
    --num_bands 3 \
    --remap 0 -100 1
```

Example (Multi-class Kelp Species):
```bash
# Assuming labels: 0=background, 1=noise, 2=macrocystis, 3=nereocystis
# Remap to: 0→0 (bg), 1→-100 (ignore noise), 2→1 (macro), 3→2 (nereo)
python -m src.prepare.make_chip_dataset \
    /data/kelp_species_raw \
    /data/kelp_species_chips_1024 \
    --size 1024 \
    --stride 1024 \
    --num_bands 3 \
    --remap 0 -100 1 2
```

This creates NPZ files containing compressed image and label arrays in:
```
output_dir/
├── train/
│   ├── mosaic_01_0.npz
│   ├── mosaic_01_1.npz
│   └── ...
├── val/
│   └── ...
└── test/
    └── ...
```
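The array key names inside each NPZ are not documented here; a quick way to inspect a chip with plain numpy (the chip filename is taken from the tree above):

```python
import numpy as np

chip = np.load("/data/kelp_chips_224/train/mosaic_01_0.npz")
print(chip.files)  # lists the stored array names (image/label keys)

for name in chip.files:
    arr = chip[name]
    print(name, arr.shape, arr.dtype, arr.min(), arr.max())
```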
Remove unwanted chips to balance your dataset:
Remove background-only tiles:

```bash
python -m src.prepare.remove_bg_only_tiles /data/kelp_chips_224/train
python -m src.prepare.remove_bg_only_tiles /data/kelp_chips_224/val
python -m src.prepare.remove_bg_only_tiles /data/kelp_chips_224/test
```

This removes chips where all pixels are background (label ≤ 0), which is useful to reduce class imbalance.
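The filtering criterion is simple enough to reproduce for a dry run. A sketch of the check, assuming a `"label"` key in each NPZ (the key name is an assumption, not confirmed by this repo; check `chip.files` first):

```python
import numpy as np
from pathlib import Path

chip_dir = Path("/data/kelp_chips_224/train")
for npz_path in chip_dir.glob("*.npz"):
    label = np.load(npz_path)["label"]  # ASSUMED key name
    if (label <= 0).all():              # every pixel is background or ignore
        print(f"would remove {npz_path.name}")
```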
Remove tiles with nodata areas:
```bash
python -m src.prepare.remove_tiles_with_nodata_areas /data/kelp_chips_224/train --num_channels 3
python -m src.prepare.remove_tiles_with_nodata_areas /data/kelp_chips_224/val --num_channels 3
python -m src.prepare.remove_tiles_with_nodata_areas /data/kelp_chips_224/test --num_channels 3
```

This removes chips containing all-black pixels (assumed to be nodata areas from mosaicking).
For custom normalization, compute channel statistics from your training data:
```bash
python -m src.prepare.channel_stats /data/kelp_chips_224/train --max_pixel_val 255.0
```

This outputs the mean and std for each channel, saved to /data/channel_stats.npz. Use these values in your config's normalization transform.
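Conceptually, these statistics are just per-channel means and standard deviations over all training pixels, scaled by `--max_pixel_val`. A rough sketch of the computation, assuming an `"image"` key and channels-last chips (both assumptions; the actual script may stream instead of concatenating everything in memory):

```python
import numpy as np
from pathlib import Path

chips = []
for npz_path in Path("/data/kelp_chips_224/train").glob("*.npz"):
    chips.append(np.load(npz_path)["image"])  # ASSUMED key and HWC layout

# Flatten every chip to (pixels, channels) and scale by --max_pixel_val
pixels = np.concatenate([c.reshape(-1, c.shape[-1]) for c in chips], axis=0)
pixels = pixels.astype(np.float64) / 255.0

print("mean:", pixels.mean(axis=0))
print("std: ", pixels.std(axis=0))
```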
Once you have prepared your dataset, training is straightforward using the PyTorch Lightning CLI:
```bash
python trainer.py fit --config configs/kelp-rgb/segformer_b3.yaml
```

Training is controlled via YAML configuration files in the configs/ directory. The configs use the PyTorch Lightning CLI format and are organized by dataset type:

- `kelp-rgb/`: RGB kelp detection models
- `kelp-rgbi/`: 4-channel RGBI kelp detection models
- `kelp-ps8b/`: 8-band PlanetScope multispectral kelp models
- `mussels-rgb/`: RGB mussel detection models
- `mussels-goosenecks-rgb/`: RGB mussel and gooseneck barnacle multi-class models
Here's an annotated example configuration (configs/kelp-rgb/segformer_b3.yaml):
```yaml
seed_everything: 42

# Model configuration
model:
  class_path: "src.models.smp.SMPMulticlassSegmentationModel"
  init_args:
    architecture: "Segformer"  # Architecture: UnetPlusPlus, DeepLabV3Plus, FPN, MAnet, Segformer, etc.
    backbone: "mit_b3"         # Encoder backbone (see segmentation-models-pytorch docs)
    model_opts:
      encoder_weights: imagenet  # Pretrained weights
    in_channels: 3   # Number of input channels
    num_classes: 3   # Number of output classes (including background)
    ignore_index: &ignore_index -100  # Label value to ignore during training
    lr: 3e-4   # Learning rate
    wd: 0.01   # Weight decay
    b1: 0.9    # Adam beta1
    b2: 0.95   # Adam beta2
    loss: "LovaszLoss"  # Loss function: DiceLoss, LovaszLoss, FocalLoss, etc.
    loss_opts:
      mode: "multiclass"  # "binary" or "multiclass"
      ignore_index: *ignore_index
      from_logits: true

# Data configuration
data:
  class_path: "src.data.DataModule"
  init_args:
    # UPDATE THESE PATHS TO YOUR PREPROCESSED DATA
    train_chip_dir: "/path/to/your/chips/train"
    val_chip_dir: "/path/to/your/chips/val"
    test_chip_dir: "/path/to/your/chips/test"
    batch_size: 3
    num_workers: 8
    pin_memory: true
    persistent_workers: true

    # Training augmentations (serialized Albumentations pipeline)
    train_transforms:
      __version__: 2.0.9
      transform:
        __class_fullname__: Compose
        transforms:
          - __class_fullname__: D4
            p: 1.0
          - __class_fullname__: OneOf
            p: 0.5
            transforms:
              - __class_fullname__: RandomBrightnessContrast
                brightness_limit: [-0.1, 0.1]
                contrast_limit: [-0.1, 0.1]
                p: 1.0
              - __class_fullname__: HueSaturationValue
                hue_shift_limit: [-5.0, 5.0]
                sat_shift_limit: [-10.0, 10.0]
                val_shift_limit: [-15.0, 15.0]
                p: 1.0
          # ... more augmentations ...
          - __class_fullname__: Normalize
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            p: 1.0
          - __class_fullname__: ToTensorV2
            p: 1.0

    # Validation/test transforms (minimal, just normalization)
    test_transforms:
      __version__: 2.0.9
      transform:
        __class_fullname__: Compose
        transforms:
          - __class_fullname__: Normalize
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            p: 1.0
          - __class_fullname__: ToTensorV2
            p: 1.0

# Trainer configuration
trainer:
  accelerator: auto       # "auto", "gpu", "cpu"
  devices: auto           # Number of GPUs (auto = all available)
  precision: bf16-mixed   # "32", "16-mixed", "bf16-mixed"
  log_every_n_steps: 50
  max_epochs: 500
  accumulate_grad_batches: 8  # Effective batch size = batch_size * accumulate_grad_batches
  gradient_clip_val: 0.5
  default_root_dir: checkpoints

  # Weights & Biases logging
  logger:
    - class_path: lightning.pytorch.loggers.WandbLogger
      init_args:
        entity: hakai          # W&B entity name
        project: kom-kelp-rgb  # W&B project name
        name: segformer_b3     # Run name
        group: Jul2025         # Group related runs
        log_model: true        # Upload checkpoints to W&B
        tags:
          - kelp
          - Jul2025

  # Callbacks
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: kelp_rgb_segformer_b3_epoch-{epoch:02d}_val-iou-{val/iou_epoch:.4f}
        monitor: val/iou_epoch  # Metric to monitor
        mode: max               # "max" or "min"
        save_last: True
        save_top_k: 2           # Save top 2 checkpoints
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
```
1. Choose or create a config file in the `configs/` directory based on your task.

2. Edit the config file to update these key parameters:

   ```yaml
   data:
     init_args:
       train_chip_dir: "/path/to/your/preprocessed/chips/train"
       val_chip_dir: "/path/to/your/preprocessed/chips/val"
       test_chip_dir: "/path/to/your/preprocessed/chips/test"
       batch_size: 3  # Adjust based on GPU memory

   trainer:
     logger:
       - class_path: lightning.pytorch.loggers.WandbLogger
         init_args:
           project: "your-project-name"
           name: "your-run-name"
   ```

3. Start training:

   ```bash
   python trainer.py fit --config configs/kelp-rgb/segformer_b3.yaml
   ```

4. Resume from a checkpoint (if training was interrupted):

   ```bash
   python trainer.py fit --config configs/kelp-rgb/segformer_b3.yaml \
       --ckpt_path checkpoints/last.ckpt
   ```

5. Test a trained model:

   ```bash
   python trainer.py test --config configs/kelp-rgb/segformer_b3.yaml \
       --ckpt_path checkpoints/best_checkpoint.ckpt
   ```
Model Architectures (via the `architecture` parameter):

- `Unet`, `UnetPlusPlus` - Classic U-Net variants with skip connections
- `DeepLabV3`, `DeepLabV3Plus` - Atrous convolution-based
- `FPN` - Feature Pyramid Network
- `PSPNet` - Pyramid Scene Parsing Network
- `MAnet` - Multi-scale Attention Network
- `Segformer` - Transformer-based (efficient for high resolution)
- `UperNet` - Unified Perceptual Parsing
Encoder Backbones (via the `backbone` parameter):

- ResNet: `resnet18`, `resnet34`, `resnet50`, `resnet101`
- EfficientNet: `efficientnet-b0` through `efficientnet-b7`
- MobileNet: `mobilenet_v2`
- SegFormer: `mit_b0` through `mit_b5`
- Swin Transformer: `swin_base_patch4_window7_224`
- See the segmentation-models-pytorch docs for the full list
Loss Functions (via the `loss` parameter; a standalone sketch combining these options follows this list):

- `DiceLoss` - Good for imbalanced datasets
- `LovaszLoss` - Optimizes IoU directly
- `FocalLoss` - Focuses on hard examples
- `JaccardLoss` - IoU loss
- `TverskyLoss` - Generalization of Dice
- `FocalDiceComboLoss` - Combination of Focal and Dice
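These options map directly onto segmentation-models-pytorch. A minimal standalone sketch of the same architecture/backbone/loss combination the example config names (`create_model` and `LovaszLoss` are real smp APIs; the exact wiring inside `SMPMulticlassSegmentationModel` may differ, and Segformer support assumes a recent smp release):

```python
import torch
import segmentation_models_pytorch as smp

# Same combination as the example config's architecture/backbone
model = smp.create_model(
    arch="Segformer",            # config's `architecture`
    encoder_name="mit_b3",       # config's `backbone`
    encoder_weights="imagenet",  # pretrained encoder weights
    in_channels=3,
    classes=3,                   # output classes, including background
)

# Loss matching the config's loss/loss_opts
criterion = smp.losses.LovaszLoss(mode="multiclass", ignore_index=-100)

x = torch.randn(2, 3, 224, 224)          # batch of RGB chips
y = torch.randint(0, 3, (2, 224, 224))   # integer class labels
logits = model(x)                        # (2, 3, 224, 224) raw logits
loss = criterion(logits, y)
```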
Binary segmentation (background vs target):

```yaml
model:
  class_path: "src.models.smp.SMPBinarySegmentationModel"
  init_args:
    num_classes: 1
    loss: "LovaszLoss"
    loss_opts:
      mode: "binary"
```

Multi-class segmentation (background + multiple classes):
```yaml
model:
  class_path: "src.models.smp.SMPMulticlassSegmentationModel"
  init_args:
    num_classes: 3  # e.g., background, macrocystis, nereocystis
    class_names: ["bg", "macro", "nereo"]
    loss: "LovaszLoss"
    loss_opts:
      mode: "multiclass"
```

After training, export PyTorch Lightning checkpoints to production-ready formats (ONNX or TorchScript) for deployment in Kelp-o-Matic.
The recommended format for production deployment:
```bash
python -m src.deploy.kom_onnx <config_path> <ckpt_path> <output_path> [--opset 11]
```

Example:
```bash
python -m src.deploy.kom_onnx \
    configs/kelp-rgb/segformer_b3.yaml \
    checkpoints/best_model.ckpt \
    models/kelp_segformer_b3.onnx \
    --opset 14
```

Parameters:

- `config_path`: Path to the YAML config used for training
- `ckpt_path`: Path to the PyTorch Lightning checkpoint (.ckpt file)
- `output_path`: Where to save the ONNX model
- `--opset`: ONNX opset version (default: 11; 14 is recommended for newer models)
The exported ONNX model:
- Strips the Lightning wrapper and extracts just the segmentation model
- Supports dynamic batch size and spatial dimensions
- Outputs raw logits (no activation function applied)
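A quick way to sanity-check an export is to run it through ONNX Runtime. A minimal sketch (the input tensor name is read from the model rather than assumed; the file path comes from the example above):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("models/kelp_segformer_b3.onnx")
input_name = session.get_inputs()[0].name

# Dynamic batch and spatial dims: any HxW should work for this export.
# Apply the same normalization used in training before real inference.
x = np.random.rand(1, 3, 512, 512).astype(np.float32)
(logits,) = session.run(None, {input_name: x})

# Raw logits out; argmax over the class axis gives the predicted mask
mask = logits.argmax(axis=1)
print(mask.shape)  # (1, 512, 512)
```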
For backwards compatibility with older Kelp-o-Matic versions:
Legacy RGB Kelp Models:

```bash
python -m src.deploy.kom_onnx_legacy_kelp_rgb \
    configs/kelp-rgb/kom_baseline.yaml \
    checkpoints/legacy_model.ckpt \
    models/kelp_legacy_rgb.onnx
```

Legacy RGBI Kelp Models:
```bash
python -m src.deploy.kom_onnx_legacy_kelp_rgbi \
    configs/kelp-rgbi/kom_baseline.yaml \
    checkpoints/legacy_model.ckpt \
    models/kelp_legacy_rgbi.onnx
```

TorchScript is an alternative deployment format (less portable, but it may be faster in pure PyTorch environments):
```bash
python -m src.deploy.kom_torchscript \
    configs/kelp-rgb/segformer_b3.yaml \
    checkpoints/best_model.ckpt \
    models/kelp_segformer_b3.pt
```

Exported models are typically:
- Uploaded to AWS S3 in the kelp-o-matic bucket for production use
- Integrated into the Kelp-o-Matic inference pipeline
- Used via ONNX Runtime for efficient cross-platform inference
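Loading and running a TorchScript export requires only PyTorch, with no model code or config. A minimal sketch (the file name comes from the TorchScript example above):

```python
import torch

# Load the exported TorchScript module directly
model = torch.jit.load("models/kelp_segformer_b3.pt")
model.eval()

with torch.inference_mode():
    # Apply the same normalization used in training before real inference
    x = torch.randn(1, 3, 512, 512)
    logits = model(x)              # raw logits, same as the ONNX export
    mask = logits.argmax(dim=1)
print(mask.shape)  # torch.Size([1, 512, 512])
```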
Training checkpoints and experiment logs remain in Weights & Biases under the hakai entity.
Format and lint with Ruff:

```bash
# Check for issues
ruff check .

# Auto-fix issues
ruff check --fix .

# Format code
ruff format .
```

Run pre-commit hooks on all files:

```bash
pre-commit run --all-files
```

Core Components:
- `trainer.py` - Lightning CLI entry point
- `src/models/smp.py` - Lightning module wrappers (SMPBinarySegmentationModel, SMPMulticlassSegmentationModel)
- `src/data.py` - DataModule and dataset classes
- `src/losses.py` - Loss function registry
- `src/transforms.py` - Augmentation pipeline helpers
Data Preparation:
- `src/prepare/make_chip_dataset.py` - Tile GeoTIFF mosaics into chips
- `src/prepare/remove_bg_only_tiles.py` - Filter background-only chips
- `src/prepare/remove_tiles_with_nodata_areas.py` - Filter chips with nodata
- `src/prepare/channel_stats.py` - Compute normalization statistics
Model Export:
- `src/deploy/kom_onnx.py` - Export to ONNX format
- `src/deploy/kom_torchscript.py` - Export to TorchScript
- `src/deploy/kom_onnx_legacy_*.py` - Export legacy model formats
Configuration:
- `configs/` - Training configuration files organized by dataset
When creating new configuration files:
1. Use YAML anchors for parameter reuse:

   ```yaml
   ignore_index: &ignore_index -100
   model:
     init_args:
       ignore_index: *ignore_index
   ```

2. Include descriptive metadata in W&B logging:
   - Use clear project names (e.g., `kom-kelp-rgb`)
   - Include the dataset version or date in the `group` field
   - Add relevant tags for filtering experiments

3. Adjust batch size and gradient accumulation based on GPU memory:
   - Effective batch size = `batch_size` × `accumulate_grad_batches`
   - For large models, use a smaller `batch_size` with a larger `accumulate_grad_batches`

4. Match image normalization to your preprocessing (see the loading sketch below):
   - RGB: Use ImageNet stats `[0.485, 0.456, 0.406]` / `[0.229, 0.224, 0.225]`
   - Multispectral: Compute custom stats with `channel_stats.py`
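Since `channel_stats.py` writes an NPZ file, custom stats can be wired into an Albumentations `Normalize` transform. A sketch of the idea (the `"mean"`/`"std"` key names inside `channel_stats.npz` are an assumption; check the script's actual output):

```python
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Load the stats computed by src.prepare.channel_stats
# NOTE: the "mean" and "std" key names are assumed, not confirmed here
stats = np.load("/data/channel_stats.npz")
mean, std = stats["mean"].tolist(), stats["std"].tolist()

# Build the minimal val/test pipeline from the example config
transform = A.Compose([
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
    ToTensorV2(),
])
```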
Training metrics, hyperparameters, and model checkpoints are logged to Weights & Biases.
Access: Contact Taylor Denouden for access to the hakai entity.
Organization:
- Project names follow the pattern `kom-{dataset}-{modality}` (e.g., `kom-kelp-rgb`)
- Checkpoints are uploaded as W&B artifacts when `log_model: true`
- Metrics tracked: IoU, accuracy, precision, recall, F1, loss
Viewing Results:
- Navigate to wandb.ai/hakai
- Select your project to view runs and compare experiments
- Download checkpoints from the Artifacts tab
- PyTorch Lightning - Training orchestration and multi-GPU support
- segmentation-models-pytorch - Model architectures and pretrained backbones
- Albumentations - Data augmentation
- TorchGeo - Geospatial data loading for preprocessing
- Weights & Biases - Experiment tracking
- ONNX - Model export for production
- uv - Fast Python package management
- Habitat-Mapper - Production inference pipeline CLI (formerly kelp-o-matic)
- Hakai Institute - Marine ecology research organization