Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time
import os
from detectron2.data import build_detection_test_loader
from data import build_lmdb_recognizer_train_loader, build_lmdb_recognizer_test_loader

from detectron2.engine.defaults import DefaultTrainer
from detectron2.utils import comm
Expand All @@ -24,7 +25,7 @@
from detectron2.modeling import build_model


from data import DatasetMapper, build_detection_train_loader
from data import DatasetMapper, build_detection_train_loader, lmdb_dataset
from torchtools.optim import RangerLars
from solver import WarmupCosineAnnealingLR
from detectron2.solver import build_lr_scheduler, build_optimizer
Expand Down Expand Up @@ -106,10 +107,14 @@ def build_model(cls, cfg):

@classmethod
def build_test_loader(cls, cfg, dataset_name):
    """Return the evaluation data loader for ``dataset_name``.

    CRNN-type recognizer datasets are served from LMDB; every other
    dataset goes through the standard detectron2 test loader with this
    project's DatasetMapper (in inference mode).
    """
    if cfg.DATASETS.TYPE == "CRNN":
        return build_lmdb_recognizer_test_loader(cfg)
    mapper = DatasetMapper(cfg, False)
    return build_detection_test_loader(cfg, dataset_name, mapper=mapper)

@classmethod
def build_train_loader(cls, cfg):
    """Return the training data loader selected by ``cfg.DATASETS.TYPE``.

    CRNN-type recognizer datasets are served from LMDB; every other
    dataset goes through the standard detectron2 train loader with this
    project's DatasetMapper (in training mode).
    """
    if cfg.DATASETS.TYPE == "CRNN":
        return build_lmdb_recognizer_train_loader(cfg)
    mapper = DatasetMapper(cfg, True)
    return build_detection_train_loader(cfg, mapper=mapper)

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .dataset_mapper import DatasetMapper
from .transforms import *
from .dataset import *
from .build import build_detection_train_loader
from .build import build_detection_train_loader, build_lmdb_recognizer_train_loader, build_lmdb_recognizer_test_loader
from .dataset import lmdb_dataset
28 changes: 28 additions & 0 deletions data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler

from .dataset import lmdb_dataset

"""
This file contains the default logic to build a dataloader for training or testing.
"""
Expand All @@ -32,6 +34,8 @@
"get_detection_dataset_dicts",
"load_proposals_into_dataset",
"print_instances_class_histogram",
"build_lmdb_recognizer_train_loader",
"build_lmdb_recognizer_test_loader",
]


Expand Down Expand Up @@ -339,6 +343,30 @@ def build_detection_test_loader(cfg, dataset_name, mapper=None):
)
return data_loader

def build_lmdb_recognizer_train_loader(cfg):
    """Build the LMDB-backed training DataLoader for the CRNN recognizer.

    Args:
        cfg: Config node; reads DATASETS.TRAIN_ROOT, DATASETS.RANDOM_SAMPLE,
            SOLVER.IMS_PER_BATCH, SOLVER.WORKERS and INPUT.IMG_H/IMG_W/KEEP_RATIO.

    Returns:
        torch.utils.data.DataLoader over the LMDB training dataset.
    """
    train_dataset = lmdb_dataset.lmdbDataset(root=cfg.DATASETS.TRAIN_ROOT)
    batch_size = cfg.SOLVER.IMS_PER_BATCH

    sampler = None
    if cfg.DATASETS.RANDOM_SAMPLE:
        # randomSequentialSampler lives in the lmdb_dataset module (it is not
        # a method of the dataset instance); calling it on train_dataset would
        # raise AttributeError.
        sampler = lmdb_dataset.randomSequentialSampler(train_dataset, batch_size)

    # PyTorch DataLoader raises ValueError if shuffle=True is combined with a
    # custom sampler, so only shuffle when no sampler is configured.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=int(cfg.SOLVER.WORKERS),
        collate_fn=lmdb_dataset.alignCollate(
            imgH=cfg.INPUT.IMG_H,
            imgW=cfg.INPUT.IMG_W,
            keep_ratio=cfg.INPUT.KEEP_RATIO))

    return train_loader

def build_lmdb_recognizer_test_loader(cfg):
    """Build the LMDB-backed evaluation DataLoader for the CRNN recognizer.

    Args:
        cfg: Config node; reads DATASETS.TEST_ROOT, SOLVER.IMS_PER_BATCH,
            SOLVER.WORKERS and INPUT.IMG_W/IMG_H.

    Returns:
        torch.utils.data.DataLoader over the LMDB test dataset.
    """
    # Resize to the same geometry the training collate uses (INPUT.IMG_W/IMG_H)
    # instead of a hard-coded 100x32, so train/eval preprocessing stays
    # consistent when the configured input size differs.
    test_dataset = lmdb_dataset.lmdbDataset(
        root=cfg.DATASETS.TEST_ROOT,
        transform=lmdb_dataset.resizeNormalize((cfg.INPUT.IMG_W, cfg.INPUT.IMG_H)))

    batch_size = cfg.SOLVER.IMS_PER_BATCH
    # Evaluation iterates the whole set once; keep the order deterministic.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=int(cfg.SOLVER.WORKERS))

    return test_loader

def trivial_batch_collator(batch):
"""
Expand Down
211 changes: 211 additions & 0 deletions data/dataset/dataset_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import functools
import collections
from typing import Any, Dict, Optional
from data.dataset.datasets_tools.misc import is_seq_of, deprecated_api_warning
import platform
import numpy as np
import torch
import random
import warnings
from functools import partial
from torch.utils.data import DataLoader
from data.dataset.datasets_tools.parallel import collate

from data.dataset.datasets_tools.registry import Registry, build_from_cfg

from data.dataset.runner.dist_utils import get_dist_info
from data.dataset.runner.version_utils import TORCH_VERSION, digit_version

from data.dataset.samples import (ClassAwareSampler, DistributedGroupSampler,
DistributedSampler, GroupSampler, InfiniteBatchSampler,
InfiniteGroupBatchSampler)

DATASETS = Registry('dataset')

def build_dataset(cfg, default_args=None):
    """Recursively build a dataset from a config dict (or sequence of dicts).

    Handles the mmdet-style wrapper types (``ConcatDataset``,
    ``RepeatDataset``, ``ClassBalancedDataset``, ``MultiImageMixDataset``),
    multi-annotation-file configs, and falls back to the ``DATASETS``
    registry for plain dataset configs.

    Args:
        cfg (dict | list | tuple): Dataset config, or a sequence of configs
            that is concatenated into a single dataset.
        default_args (dict, optional): Default kwargs forwarded to the
            dataset constructor.

    Returns:
        torch.utils.data.Dataset: The constructed (possibly wrapped) dataset.
    """
    # Imported lazily to avoid a circular import with the wrapper module.
    # Use the package-qualified path, consistent with every other import in
    # this file: a bare `from dataset_wrapper import ...` is not a valid
    # relative import in Python 3 and only works if that directory happens
    # to be on sys.path.
    from data.dataset.dataset_wrapper import (ClassBalancedDataset, ConcatDataset,
                                              MultiImageMixDataset, RepeatDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif cfg['type'] == 'MultiImageMixDataset':
        cp_cfg = copy.deepcopy(cfg)
        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
        cp_cfg.pop('type')
        dataset = MultiImageMixDataset(**cp_cfg)
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        # Several annotation files for one dataset type: build one dataset
        # per file and concatenate them.
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset

def _concat_dataset(cfg, default_args=None):
    """Build one dataset per annotation file in ``cfg`` and concatenate them.

    ``cfg['ann_file']`` is a list/tuple; ``img_prefix``, ``seg_prefix`` and
    ``proposal_file`` may each be a parallel list/tuple (indexed per dataset)
    or a single shared value.

    Args:
        cfg (dict): Dataset config whose ``ann_file`` is a sequence.
        default_args (dict, optional): Default kwargs forwarded to each
            dataset constructor.

    Returns:
        ConcatDataset: All per-file datasets wrapped into one.
    """
    # Package-qualified lazy import (see build_dataset for rationale).
    from data.dataset.dataset_wrapper import ConcatDataset
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    seg_prefixes = cfg.get('seg_prefix', None)
    proposal_files = cfg.get('proposal_file', None)
    separate_eval = cfg.get('separate_eval', True)

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        data_cfg = copy.deepcopy(cfg)
        # pop 'separate_eval' since it is not a valid key for common datasets.
        if 'separate_eval' in data_cfg:
            data_cfg.pop('separate_eval')
        data_cfg['ann_file'] = ann_files[i]
        # Per-dataset prefixes/files are only indexed when given as sequences;
        # a single shared value is left untouched on the copied config.
        if isinstance(img_prefixes, (list, tuple)):
            data_cfg['img_prefix'] = img_prefixes[i]
        if isinstance(seg_prefixes, (list, tuple)):
            data_cfg['seg_prefix'] = seg_prefixes[i]
        if isinstance(proposal_files, (list, tuple)):
            data_cfg['proposal_file'] = proposal_files[i]
        datasets.append(build_dataset(data_cfg, default_args))

    return ConcatDataset(datasets, separate_eval)




def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     runner_type='EpochBasedRunner',
                     persistent_workers=False,
                     class_aware_sampler=None,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int, Optional): Seed to be used. Default: None.
        runner_type (str): Type of runner. Default: `EpochBasedRunner`
        persistent_workers (bool): If True, the data loader will not shutdown
            the worker processes after a dataset has been consumed once.
            This allows to maintain the workers `Dataset` instances alive.
            This argument is only valid when PyTorch>=1.7.0. Default: False.
        class_aware_sampler (dict): Whether to use `ClassAwareSampler`
            during training. Default: None.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()

    if dist:
        # When model is :obj:`DistributedDataParallel`,
        # `batch_size` of :obj:`dataloader` is the
        # number of training samples on each GPU.
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        # When model is obj:`DataParallel`
        # the batch size is samples on all the GPUS
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    if runner_type == 'IterBasedRunner':
        # this is a batch sampler, which can yield
        # a mini-batch indices each time.
        # it can be used in both `DataParallel` and
        # `DistributedDataParallel`
        if shuffle:
            batch_sampler = InfiniteGroupBatchSampler(
                dataset, batch_size, world_size, rank, seed=seed)
        else:
            batch_sampler = InfiniteBatchSampler(
                dataset,
                batch_size,
                world_size,
                rank,
                seed=seed,
                shuffle=False)
        # The batch sampler already yields whole batches, so the DataLoader
        # itself must not batch again (batch_size must stay at its default 1
        # and no per-element sampler may be passed).
        batch_size = 1
        sampler = None
    else:
        if class_aware_sampler is not None:
            # ClassAwareSampler can be used in both distributed and
            # non-distributed training.
            num_sample_class = class_aware_sampler.get('num_sample_class', 1)
            sampler = ClassAwareSampler(
                dataset,
                samples_per_gpu,
                world_size,
                rank,
                seed=seed,
                num_sample_class=num_sample_class)
        elif dist:
            # DistributedGroupSampler will definitely shuffle the data to
            # satisfy that images on each GPU are in the same group
            if shuffle:
                sampler = DistributedGroupSampler(
                    dataset, samples_per_gpu, world_size, rank, seed=seed)
            else:
                sampler = DistributedSampler(
                    dataset, world_size, rank, shuffle=False, seed=seed)
        else:
            # Non-distributed: group-aware shuffling, or plain sequential
            # order (sampler=None) when shuffle is disabled.
            sampler = GroupSampler(dataset,
                                   samples_per_gpu) if shuffle else None
        batch_sampler = None

    # Give every worker a distinct, reproducible seed (see worker_init_fn);
    # skipped entirely when no seed is requested.
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    # `persistent_workers` only exists in PyTorch >= 1.7.0; on older versions
    # a truthy request is dropped with a warning instead of crashing.
    if (TORCH_VERSION != 'parrots'
            and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
        kwargs['persistent_workers'] = persistent_workers
    elif persistent_workers is True:
        warnings.warn('persistent_workers is invalid because your pytorch '
                      'version is lower than 1.7.0')

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=kwargs.pop('pin_memory', False),
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader

def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the python, NumPy and torch RNGs for one DataLoader worker.

    Each worker receives the unique seed
    ``num_workers * rank + worker_id + seed`` so that no two workers
    (across all ranks) share a random stream, while runs with the same
    user seed stay reproducible.
    """
    worker_seed = seed + rank * num_workers + worker_id
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(worker_seed)
Loading