diff --git a/apis/train.py b/apis/train.py
index 6c69158..441ee59 100755
--- a/apis/train.py
+++ b/apis/train.py
@@ -12,6 +12,7 @@ import time
 import os
 
 from detectron2.data import build_detection_test_loader
+from data import build_lmdb_recognizer_train_loader, build_lmdb_recognizer_test_loader
 from detectron2.engine.defaults import DefaultTrainer
 from detectron2.utils import comm
 
@@ -24,7 +25,7 @@
 
 from detectron2.modeling import build_model
 
-from data import DatasetMapper, build_detection_train_loader
+from data import DatasetMapper, build_detection_train_loader, lmdb_dataset
 from torchtools.optim import RangerLars
 from solver import WarmupCosineAnnealingLR
 from detectron2.solver import build_lr_scheduler, build_optimizer
@@ -106,10 +107,14 @@ def build_model(cls, cfg):
 
     @classmethod
     def build_test_loader(cls, cfg, dataset_name):
+        if cfg.DATASETS.TYPE == "CRNN":
+            return build_lmdb_recognizer_test_loader(cfg)
         return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
 
     @classmethod
     def build_train_loader(cls, cfg):
+        if cfg.DATASETS.TYPE == "CRNN":
+            return build_lmdb_recognizer_train_loader(cfg)
         return build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
 
     @classmethod
diff --git a/data/__init__.py b/data/__init__.py
index 07398bc..45cfd49 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -3,4 +3,5 @@
 from .dataset_mapper import DatasetMapper
 from .transforms import *
 from .dataset import *
-from .build import build_detection_train_loader
\ No newline at end of file
+from .build import build_detection_train_loader, build_lmdb_recognizer_train_loader, build_lmdb_recognizer_test_loader
+from .dataset import lmdb_dataset
\ No newline at end of file
diff --git a/data/build.py b/data/build.py
index 0a30da5..4483228 100644
--- a/data/build.py
+++ b/data/build.py
@@ -21,6 +21,8 @@
 from detectron2.data.detection_utils import check_metadata_consistency
 from detectron2.data.samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
 
+from .dataset import lmdb_dataset
+
 """
 This file contains the default logic to build a dataloader for training or testing.
 """
@@ -32,6 +34,8 @@
     "get_detection_dataset_dicts",
     "load_proposals_into_dataset",
     "print_instances_class_histogram",
+    "build_lmdb_recognizer_train_loader",
+    "build_lmdb_recognizer_test_loader",
 ]
@@ -339,6 +343,30 @@ def build_detection_test_loader(cfg, dataset_name, mapper=None):
     )
     return data_loader
 
+def build_lmdb_recognizer_train_loader(cfg):
+    train_dataset = lmdb_dataset.lmdbDataset(root=cfg.DATASETS.TRAIN_ROOT)
+    sampler = None
+    batch_size = cfg.SOLVER.IMS_PER_BATCH
+    if cfg.DATASETS.RANDOM_SAMPLE:
+        # randomSequentialSampler is a module-level class in lmdb_dataset,
+        # not a method on the dataset instance.
+        sampler = lmdb_dataset.randomSequentialSampler(train_dataset, batch_size)
+
+    # DataLoader rejects shuffle=True combined with a custom sampler, so
+    # shuffle only when no sampler is configured.
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size,
+        shuffle=(sampler is None), sampler=sampler,
+        num_workers=int(cfg.SOLVER.WORKERS),
+        collate_fn=lmdb_dataset.alignCollate(imgH=cfg.INPUT.IMG_H, imgW=cfg.INPUT.IMG_W, keep_ratio=cfg.INPUT.KEEP_RATIO))
+
+    return train_loader
+
+def build_lmdb_recognizer_test_loader(cfg):
+    test_dataset = lmdb_dataset.lmdbDataset(
+        root=cfg.DATASETS.TEST_ROOT, transform=lmdb_dataset.resizeNormalize((100, 32)))
+
+    batch_size = cfg.SOLVER.IMS_PER_BATCH
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, shuffle=True, batch_size=batch_size, num_workers=int(cfg.SOLVER.WORKERS))
+
+    return test_loader
 
 def trivial_batch_collator(batch):
     """
diff --git a/data/dataset/dataset_builder.py b/data/dataset/dataset_builder.py
new file mode 100644
index 0000000..26ac526
--- /dev/null
+++ b/data/dataset/dataset_builder.py
@@ -0,0 +1,211 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+import functools
+import collections
+from typing import Any, Dict, Optional
+from data.dataset.datasets_tools.misc import is_seq_of, deprecated_api_warning
+import platform
+import numpy as np
+import torch
+import random
+import warnings
+from functools import partial
+from torch.utils.data import DataLoader
+from data.dataset.datasets_tools.parallel import collate
+
+from data.dataset.datasets_tools.registry import Registry, build_from_cfg
+
+from data.dataset.runner.dist_utils import get_dist_info
+from data.dataset.runner.version_utils import TORCH_VERSION, digit_version
+
+from data.dataset.samples import (ClassAwareSampler, DistributedGroupSampler,
+                                  DistributedSampler, GroupSampler, InfiniteBatchSampler,
+                                  InfiniteGroupBatchSampler)
+
+DATASETS = Registry('dataset')
+
+def build_dataset(cfg, default_args=None):
+    # Package-qualified import so this resolves outside data/dataset/ as well.
+    from data.dataset.dataset_wrapper import (ClassBalancedDataset, ConcatDataset,
+                                              MultiImageMixDataset, RepeatDataset)
+    if isinstance(cfg, (list, tuple)):
+        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    elif cfg['type'] == 'ConcatDataset':
+        dataset = ConcatDataset(
+            [build_dataset(c, default_args) for c in cfg['datasets']],
+            cfg.get('separate_eval', True))
+    elif cfg['type'] == 'RepeatDataset':
+        dataset = RepeatDataset(
+            build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif cfg['type'] == 'ClassBalancedDataset':
+        dataset = ClassBalancedDataset(
+            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
+    elif cfg['type'] == 'MultiImageMixDataset':
+        cp_cfg = copy.deepcopy(cfg)
+        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+        cp_cfg.pop('type')
+        dataset = MultiImageMixDataset(**cp_cfg)
+    elif isinstance(cfg.get('ann_file'), (list, tuple)):
+        dataset = _concat_dataset(cfg, default_args)
+    else:
+        dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+    return dataset
+
+def _concat_dataset(cfg, default_args=None):
+    from data.dataset.dataset_wrapper import ConcatDataset
+    ann_files =
cfg['ann_file'] + img_prefixes = cfg.get('img_prefix', None) + seg_prefixes = cfg.get('seg_prefix', None) + proposal_files = cfg.get('proposal_file', None) + separate_eval = cfg.get('separate_eval', True) + + datasets = [] + num_dset = len(ann_files) + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + # pop 'separate_eval' since it is not a valid key for common datasets. + if 'separate_eval' in data_cfg: + data_cfg.pop('separate_eval') + data_cfg['ann_file'] = ann_files[i] + if isinstance(img_prefixes, (list, tuple)): + data_cfg['img_prefix'] = img_prefixes[i] + if isinstance(seg_prefixes, (list, tuple)): + data_cfg['seg_prefix'] = seg_prefixes[i] + if isinstance(proposal_files, (list, tuple)): + data_cfg['proposal_file'] = proposal_files[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets, separate_eval) + + + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + runner_type='EpochBasedRunner', + persistent_workers=False, + class_aware_sampler=None, + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int, Optional): Seed to be used. Default: None. + runner_type (str): Type of runner. Default: `EpochBasedRunner` + persistent_workers (bool): If True, the data loader will not shutdown + the worker processes after a dataset has been consumed once. + This allows to maintain the workers `Dataset` instances alive. + This argument is only valid when PyTorch>=1.7.0. Default: False. + class_aware_sampler (dict): Whether to use `ClassAwareSampler` + during training. Default: None. + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + + if dist: + # When model is :obj:`DistributedDataParallel`, + # `batch_size` of :obj:`dataloader` is the + # number of training samples on each GPU. + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + # When model is obj:`DataParallel` + # the batch size is samples on all the GPUS + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + if runner_type == 'IterBasedRunner': + # this is a batch sampler, which can yield + # a mini-batch indices each time. + # it can be used in both `DataParallel` and + # `DistributedDataParallel` + if shuffle: + batch_sampler = InfiniteGroupBatchSampler( + dataset, batch_size, world_size, rank, seed=seed) + else: + batch_sampler = InfiniteBatchSampler( + dataset, + batch_size, + world_size, + rank, + seed=seed, + shuffle=False) + batch_size = 1 + sampler = None + else: + if class_aware_sampler is not None: + # ClassAwareSampler can be used in both distributed and + # non-distributed training. 
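+            # `class_aware_sampler` is read as a small config dict; only its
+            # `num_sample_class` entry is consumed here, e.g. (illustrative)
+            # class_aware_sampler=dict(num_sample_class=1).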
+ num_sample_class = class_aware_sampler.get('num_sample_class', 1) + sampler = ClassAwareSampler( + dataset, + samples_per_gpu, + world_size, + rank, + seed=seed, + num_sample_class=num_sample_class) + elif dist: + # DistributedGroupSampler will definitely shuffle the data to + # satisfy that images on each GPU are in the same group + if shuffle: + sampler = DistributedGroupSampler( + dataset, samples_per_gpu, world_size, rank, seed=seed) + else: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=False, seed=seed) + else: + sampler = GroupSampler(dataset, + samples_per_gpu) if shuffle else None + batch_sampler = None + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.7.0')): + kwargs['persistent_workers'] = persistent_workers + elif persistent_workers is True: + warnings.warn('persistent_workers is invalid because your pytorch ' + 'version is lower than 1.7.0') + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=kwargs.pop('pin_memory', False), + worker_init_fn=init_fn, + **kwargs) + + return data_loader + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) \ No newline at end of file diff --git a/data/dataset/dataset_wrapper.py b/data/dataset/dataset_wrapper.py new file mode 100644 index 0000000..570d1b8 --- /dev/null +++ b/data/dataset/dataset_wrapper.py @@ -0,0 +1,445 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import bisect +import collections +import copy +import math +from collections import defaultdict +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset +from data.dataset.utils import print_log +import numpy as np + +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + separate_eval (bool): Whether to evaluate the results + separately if it is used as validation dataset. + Defaults to True. + """ + + def __init__(self, datasets, separate_eval=True): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = getattr(datasets[0], 'PALETTE', None) + self.separate_eval = separate_eval + # if not separate_eval: + # if any([isinstance(ds, CocoDataset) for ds in datasets]): + # raise NotImplementedError( + # 'Evaluating concatenated CocoDataset as a whole is not' + # ' supported! Please set "separate_eval=True"') + # elif len(set([type(ds) for ds in datasets])) != 1: + # raise NotImplementedError( + # 'All the datasets should have same types') + + if hasattr(datasets[0], 'flag'): + flags = [] + for i in range(0, len(datasets)): + flags.append(datasets[i].flag) + self.flag = np.concatenate(flags) + + def get_cat_ids(self, idx): + """Get category ids of concatenated dataset by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. 
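+
+        Example (illustrative):
+            concatenating datasets of lengths 3 and 5 gives
+            ``cumulative_sizes == [3, 8]``, so global index 4 resolves to
+            ``self.datasets[1].get_cat_ids(1)``.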
+ """ + + if idx < 0: + if -idx > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx].get_cat_ids(sample_idx) + + def get_ann_info(self, idx): + """Get annotation of concatenated dataset by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + if idx < 0: + if -idx > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx].get_ann_info(sample_idx) + + def evaluate(self, results, logger=None, **kwargs): + """Evaluate the results. + + Args: + results (list[list | tuple]): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str: float]: AP results of the total dataset or each separate + dataset if `self.separate_eval=True`. + """ + assert len(results) == self.cumulative_sizes[-1], \ + ('Dataset and results have different sizes: ' + f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + + # Check whether all the datasets support evaluation + for dataset in self.datasets: + assert hasattr(dataset, 'evaluate'), \ + f'{type(dataset)} does not implement evaluate function' + + if self.separate_eval: + dataset_idx = -1 + total_eval_results = dict() + for size, dataset in zip(self.cumulative_sizes, self.datasets): + start_idx = 0 if dataset_idx == -1 else \ + self.cumulative_sizes[dataset_idx] + end_idx = self.cumulative_sizes[dataset_idx + 1] + + results_per_dataset = results[start_idx:end_idx] + print_log( + f'\nEvaluateing {dataset.ann_file} with ' + f'{len(results_per_dataset)} images now', + logger=logger) + + eval_results_per_dataset = dataset.evaluate( + results_per_dataset, logger=logger, **kwargs) + dataset_idx += 1 + for k, v in eval_results_per_dataset.items(): + total_eval_results.update({f'{dataset_idx}_{k}': v}) + + return total_eval_results + # elif any([isinstance(ds, CocoDataset) for ds in self.datasets]): + # raise NotImplementedError( + # 'Evaluating concatenated CocoDataset as a whole is not' + # ' supported! Please set "separate_eval=True"') + elif len(set([type(ds) for ds in self.datasets])) != 1: + raise NotImplementedError( + 'All the datasets should have same types') + else: + original_data_infos = self.datasets[0].data_infos + self.datasets[0].data_infos = sum( + [dataset.data_infos for dataset in self.datasets], []) + eval_results = self.datasets[0].evaluate( + results, logger=logger, **kwargs) + self.datasets[0].data_infos = original_data_infos + return eval_results + +class RepeatDataset: + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. 
+ """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = getattr(dataset, 'PALETTE', None) + if hasattr(self.dataset, 'flag'): + self.flag = np.tile(self.dataset.flag, times) + + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + return self.dataset[idx % self._ori_len] + + def get_cat_ids(self, idx): + """Get category ids of repeat dataset by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. + """ + + return self.dataset.get_cat_ids(idx % self._ori_len) + + def get_ann_info(self, idx): + """Get annotation of repeat dataset by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.dataset.get_ann_info(idx % self._ori_len) + + def __len__(self): + """Length after repetition.""" + return self.times * self._ori_len + + +class ClassBalancedDataset: + """A wrapper of repeated dataset with repeat factor. + + Suitable for training on class imbalanced datasets like LVIS. Following + the sampling strategy in the `paper `_, + in each epoch, an image may appear multiple times based on its + "repeat factor". + The repeat factor for an image is a function of the frequency the rarest + category labeled in that image. The "frequency of category c" in [0, 1] + is defined by the fraction of images in the training set (without repeats) + in which category c appears. + The dataset needs to instantiate :func:`self.get_cat_ids` to support + ClassBalancedDataset. + + The repeat factor is computed as followed. + + 1. For each category c, compute the fraction # of images + that contain it: :math:`f(c)` + 2. For each category c, compute the category-level repeat factor: + :math:`r(c) = max(1, sqrt(t/f(c)))` + 3. For each image I, compute the image-level repeat factor: + :math:`r(I) = max_{c in I} r(c)` + + Args: + dataset (:obj:`CustomDataset`): The dataset to be repeated. + oversample_thr (float): frequency threshold below which data is + repeated. For categories with ``f_c >= oversample_thr``, there is + no oversampling. For categories with ``f_c < oversample_thr``, the + degree of oversampling following the square-root inverse frequency + heuristic above. + filter_empty_gt (bool, optional): If set true, images without bounding + boxes will not be oversampled. Otherwise, they will be categorized + as the pure background class and involved into the oversampling. + Default: True. + """ + + def __init__(self, dataset, oversample_thr, filter_empty_gt=True): + self.dataset = dataset + self.oversample_thr = oversample_thr + self.filter_empty_gt = filter_empty_gt + self.CLASSES = dataset.CLASSES + self.PALETTE = getattr(dataset, 'PALETTE', None) + + repeat_factors = self._get_repeat_factors(dataset, oversample_thr) + repeat_indices = [] + for dataset_idx, repeat_factor in enumerate(repeat_factors): + repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor)) + self.repeat_indices = repeat_indices + + flags = [] + if hasattr(self.dataset, 'flag'): + for flag, repeat_factor in zip(self.dataset.flag, repeat_factors): + flags.extend([flag] * int(math.ceil(repeat_factor))) + assert len(flags) == len(repeat_indices) + self.flag = np.asarray(flags, dtype=np.uint8) + + def _get_repeat_factors(self, dataset, repeat_thr): + """Get repeat factor for each images in the dataset. + + Args: + dataset (:obj:`CustomDataset`): The dataset + repeat_thr (float): The threshold of frequency. 
If an image + contains the categories whose frequency below the threshold, + it would be repeated. + + Returns: + list[float]: The repeat factors for each images in the dataset. + """ + + # 1. For each category c, compute the fraction # of images + # that contain it: f(c) + category_freq = defaultdict(int) + num_images = len(dataset) + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + if len(cat_ids) == 0 and not self.filter_empty_gt: + cat_ids = set([len(self.CLASSES)]) + for cat_id in cat_ids: + category_freq[cat_id] += 1 + for k, v in category_freq.items(): + category_freq[k] = v / num_images + + # 2. For each category c, compute the category-level repeat factor: + # r(c) = max(1, sqrt(t/f(c))) + category_repeat = { + cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + # 3. For each image I, compute the image-level repeat factor: + # r(I) = max_{c in I} r(c) + repeat_factors = [] + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + if len(cat_ids) == 0 and not self.filter_empty_gt: + cat_ids = set([len(self.CLASSES)]) + repeat_factor = 1 + if len(cat_ids) > 0: + repeat_factor = max( + {category_repeat[cat_id] + for cat_id in cat_ids}) + repeat_factors.append(repeat_factor) + + return repeat_factors + + def __getitem__(self, idx): + ori_index = self.repeat_indices[idx] + return self.dataset[ori_index] + + def get_ann_info(self, idx): + """Get annotation of dataset by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + ori_index = self.repeat_indices[idx] + return self.dataset.get_ann_info(ori_index) + + def __len__(self): + """Length after repetition.""" + return len(self.repeat_indices) + + +class MultiImageMixDataset: + """A wrapper of multiple images mixed dataset. + + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. For the augmentation pipeline of mixed image data, + the `get_indexes` method needs to be provided to obtain the image + indexes, and you can set `skip_flags` to change the pipeline running + process. At the same time, we provide the `dynamic_scale` parameter + to dynamically change the output image size. + + Args: + dataset (:obj:`CustomDataset`): The dataset to be mixed. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + dynamic_scale (tuple[int], optional): The image scale can be changed + dynamically. Default to None. It is deprecated. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + max_refetch (int): The maximum number of retry iterations for getting + valid results from the pipeline. If the number of iterations is + greater than `max_refetch`, but results is still None, then the + iteration is terminated and raise the error. Default: 15. + """ + + def __init__(self, + dataset, + pipeline, + dynamic_scale=None, + skip_type_keys=None, + max_refetch=15): + if dynamic_scale is not None: + raise RuntimeError( + 'dynamic_scale is deprecated. 
Please use Resize pipeline ' + 'to achieve similar functions') + assert isinstance(pipeline, collections.abc.Sequence) + if skip_type_keys is not None: + assert all([ + isinstance(skip_type_key, str) + for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform['type']) + # transform = build_from_cfg(transform, PIPELINES) TODO: need refactor + self.pipeline.append(transform) + else: + raise TypeError('pipeline must be a dict') + + self.dataset = dataset + self.CLASSES = dataset.CLASSES + self.PALETTE = getattr(dataset, 'PALETTE', None) + if hasattr(self.dataset, 'flag'): + self.flag = dataset.flag + self.num_samples = len(dataset) + self.max_refetch = max_refetch + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + results = copy.deepcopy(self.dataset[idx]) + for (transform, transform_type) in zip(self.pipeline, + self.pipeline_types): + if self._skip_type_keys is not None and \ + transform_type in self._skip_type_keys: + continue + + if hasattr(transform, 'get_indexes'): + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. + indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indexes + ] + if None not in mix_results: + results['mix_results'] = mix_results + break + else: + raise RuntimeError( + 'The loading pipeline of the original dataset' + ' always return None. Please check the correctness ' + 'of the dataset and its pipeline.') + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. + updated_results = transform(copy.deepcopy(results)) + if updated_results is not None: + results = updated_results + break + else: + raise RuntimeError( + 'The training pipeline of the dataset wrapper' + ' always return None.Please check the correctness ' + 'of the dataset and its pipeline.') + + if 'mix_results' in results: + results.pop('mix_results') + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. It is called by an external hook. + + Args: + skip_type_keys (list[str], optional): Sequence of type + string to be skip pipeline. + """ + assert all([ + isinstance(skip_type_key, str) for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys \ No newline at end of file diff --git a/data/dataset/datasets_tools/data_container.py b/data/dataset/datasets_tools/data_container.py new file mode 100644 index 0000000..cedb0d3 --- /dev/null +++ b/data/dataset/datasets_tools/data_container.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch + + +def assert_tensor_type(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not isinstance(args[0].data, torch.Tensor): + raise AttributeError( + f'{args[0].__class__.__name__} has no attribute ' + f'{func.__name__} for type {args[0].datatype}') + return func(*args, **kwargs) + + return wrapper + + +class DataContainer: + """A container for any type of objects. + + Typically tensors will be stacked in the collate function and sliced along + some dimension in the scatter function. This behavior has some limitations. + 1. 
All tensors have to be the same size. + 2. Types are limited (numpy array or Tensor). + + We design `DataContainer` and `MMDataParallel` to overcome these + limitations. The behavior can be either of the following. + + - copy to GPU, pad all tensors to the same size and stack them + - copy to GPU without stacking + - leave the objects as is and pass it to the model + - pad_dims specifies the number of last few dimensions to do padding + """ + + def __init__(self, + data, + stack=False, + padding_value=0, + cpu_only=False, + pad_dims=2): + self._data = data + self._cpu_only = cpu_only + self._stack = stack + self._padding_value = padding_value + assert pad_dims in [None, 1, 2, 3] + self._pad_dims = pad_dims + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self.data)})' + + def __len__(self): + return len(self._data) + + @property + def data(self): + return self._data + + @property + def datatype(self): + if isinstance(self.data, torch.Tensor): + return self.data.type() + else: + return type(self.data) + + @property + def cpu_only(self): + return self._cpu_only + + @property + def stack(self): + return self._stack + + @property + def padding_value(self): + return self._padding_value + + @property + def pad_dims(self): + return self._pad_dims + + @assert_tensor_type + def size(self, *args, **kwargs): + return self.data.size(*args, **kwargs) + + @assert_tensor_type + def dim(self): + return self.data.dim() diff --git a/data/dataset/datasets_tools/misc.py b/data/dataset/datasets_tools/misc.py new file mode 100644 index 0000000..2452cf8 --- /dev/null +++ b/data/dataset/datasets_tools/misc.py @@ -0,0 +1,86 @@ +import collections +import inspect +import functools +import warnings + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = collections.abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + +def deprecated_api_warning(name_dict, cls_name=None): + """A decorator to check if some arguments are deprecate and try to replace + deprecate src_arg_name to dst_arg_name. + + Args: + name_dict(dict): + key (str): Deprecate argument names. + val (str): Expected argument names. + + Returns: + func: New function. 
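+
+    Example (illustrative; any old/new kwarg pair works the same way):
+        >>> @deprecated_api_warning(dict(module_class='module'))
+        >>> def register(module=None):
+        >>>     return module
+        >>> register(module_class=int)  # warns, then forwards to `module`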
+ """ + + def api_warning_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # get the arg spec of the decorated method + args_info = inspect.getfullargspec(old_func) + # get name of the function + func_name = old_func.__name__ + if cls_name is not None: + func_name = f'{cls_name}.{func_name}' + if args: + arg_names = args_info.args[:len(args)] + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in arg_names: + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead', DeprecationWarning) + arg_names[arg_names.index(src_arg_name)] = dst_arg_name + if kwargs: + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in kwargs: + + assert dst_arg_name not in kwargs, ( + f'The expected behavior is to replace ' + f'the deprecated key `{src_arg_name}` to ' + f'new key `{dst_arg_name}`, but got them ' + f'in the arguments at the same time, which ' + f'is confusing. `{src_arg_name} will be ' + f'deprecated in the future, please ' + f'use `{dst_arg_name}` instead.') + + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead', DeprecationWarning) + kwargs[dst_arg_name] = kwargs.pop(src_arg_name) + + # apply converted arguments to the decorated method + output = old_func(*args, **kwargs) + return output + + return new_func + + return api_warning_wrapper \ No newline at end of file diff --git a/data/dataset/datasets_tools/parallel.py b/data/dataset/datasets_tools/parallel.py new file mode 100644 index 0000000..ad74919 --- /dev/null +++ b/data/dataset/datasets_tools/parallel.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from torch.utils.data.dataloader import default_collate + +from .data_container import DataContainer + + +def collate(batch, samples_per_gpu=1): + """Puts each data field into a tensor/DataContainer with outer dimension + batch size. + + Extend default_collate to add support for + :type:`~mmcv.parallel.DataContainer`. There are 3 cases. + + 1. cpu_only = True, e.g., meta data + 2. cpu_only = False, stack = True, e.g., images tensors + 3. 
cpu_only = False, stack = False, e.g., gt bboxes + """ + + if not isinstance(batch, Sequence): + raise TypeError(f'{batch.dtype} is not supported.') + + if isinstance(batch[0], DataContainer): + stacked = [] + if batch[0].cpu_only: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer( + stacked, batch[0].stack, batch[0].padding_value, cpu_only=True) + elif batch[0].stack: + for i in range(0, len(batch), samples_per_gpu): + assert isinstance(batch[i].data, torch.Tensor) + + if batch[i].pad_dims is not None: + ndim = batch[i].dim() + assert ndim > batch[i].pad_dims + max_shape = [0 for _ in range(batch[i].pad_dims)] + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = batch[i].size(-dim) + for sample in batch[i:i + samples_per_gpu]: + for dim in range(0, ndim - batch[i].pad_dims): + assert batch[i].size(dim) == sample.size(dim) + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = max(max_shape[dim - 1], + sample.size(-dim)) + padded_samples = [] + for sample in batch[i:i + samples_per_gpu]: + pad = [0 for _ in range(batch[i].pad_dims * 2)] + for dim in range(1, batch[i].pad_dims + 1): + pad[2 * dim - + 1] = max_shape[dim - 1] - sample.size(-dim) + padded_samples.append( + F.pad( + sample.data, pad, value=sample.padding_value)) + stacked.append(default_collate(padded_samples)) + elif batch[i].pad_dims is None: + stacked.append( + default_collate([ + sample.data + for sample in batch[i:i + samples_per_gpu] + ])) + else: + raise ValueError( + 'pad_dims should be either None or integers (1-3)') + + else: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer(stacked, batch[0].stack, batch[0].padding_value) + elif isinstance(batch[0], Sequence): + transposed = zip(*batch) + return [collate(samples, samples_per_gpu) for samples in transposed] + elif isinstance(batch[0], Mapping): + return { + key: collate([d[key] for d in batch], samples_per_gpu) + for key in batch[0] + } + else: + return default_collate(batch) diff --git a/data/dataset/datasets_tools/registry.py b/data/dataset/datasets_tools/registry.py new file mode 100644 index 0000000..f69a4e3 --- /dev/null +++ b/data/dataset/datasets_tools/registry.py @@ -0,0 +1,339 @@ +import warnings +import inspect +import functools +import collections +from typing import Any, Dict, Optional +from data.dataset.datasets_tools.misc import is_seq_of, deprecated_api_warning + + +def build_from_cfg(cfg: Dict, + registry: 'Registry', + default_args: Optional[Dict] = None) -> Any: + """Build a module from config dict when it is a class configuration, or + call a function from config dict when it is a function configuration. + + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS) + >>> # Returns an instantiated object + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) + >>> # Return a result of the calling function + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. 
+ """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an mmcv.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry') + elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') + +class Registry: + """A registry to map strings to classes or functions. + + Registered object could be built from registry. Meanwhile, registered + functions could be called from registry. + + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = MODELS.build(dict(type='resnet50')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + \ + f'(name={self._name}, ' \ + f'items={self._module_dict})' + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. 
+ + The name of the package where registry is defined will be returned. + + Example: + >>> # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + Returns: + str: The inferred scope name. + """ + # We access the caller using inspect.currentframe() instead of + # inspect.stack() for performance reasons. See details in PR #1844 + frame = inspect.currentframe() + # get the frame where `infer_scope()` is called + infer_scope_caller = frame.f_back.f_back + filename = inspect.getmodule(infer_scope_caller).__name__ + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + tuple[str | None, str]: The former element is the first scope of + the key, which can be ``None``. The latter is the remaining key. + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. 
+
+        Example:
+            >>> models = Registry('models')
+            >>> mmdet_models = Registry('models', parent=models)
+            >>> @mmdet_models.register_module()
+            >>> class ResNet:
+            >>>     pass
+            >>> resnet = models.build(dict(type='mmdet.ResNet'))
+        """
+
+        assert isinstance(registry, Registry)
+        assert registry.scope is not None
+        assert registry.scope not in self.children, \
+            f'scope {registry.scope} exists in {self.name} registry'
+        self.children[registry.scope] = registry
+
+    @deprecated_api_warning(name_dict=dict(module_class='module'))
+    def _register_module(self, module, module_name=None, force=False):
+        if not inspect.isclass(module) and not inspect.isfunction(module):
+            raise TypeError('module must be a class or a function, '
+                            f'but got {type(module)}')
+
+        if module_name is None:
+            module_name = module.__name__
+        if isinstance(module_name, str):
+            module_name = [module_name]
+        for name in module_name:
+            if not force and name in self._module_dict:
+                raise KeyError(f'{name} is already registered '
+                               f'in {self.name}')
+            self._module_dict[name] = module
+
+    def deprecated_register_module(self, cls=None, force=False):
+        warnings.warn(
+            'The old API of register_module(module, force=False) '
+            'is deprecated and will be removed, please use the new API '
+            'register_module(name=None, force=False, module=None) instead.',
+            DeprecationWarning)
+        if cls is None:
+            # only `functools` is imported in this module, so the bare name
+            # `partial` would raise a NameError here.
+            return functools.partial(self.deprecated_register_module, force=force)
+        self._register_module(cls, force=force)
+        return cls
+
+    def register_module(self, name=None, force=False, module=None):
+        """Register a module.
+
+        A record will be added to `self._module_dict`, whose key is the class
+        name or the specified name, and value is the class itself.
+        It can be used as a decorator or a normal function.
+
+        Example:
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module()
+            >>> class ResNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module(name='mnet')
+            >>> class MobileNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> class ResNet:
+            >>>     pass
+            >>> backbones.register_module(ResNet)
+
+        Args:
+            name (str | None): The module name to be registered. If not
+                specified, the class name will be used.
+            force (bool, optional): Whether to override an existing class with
+                the same name. Default: False.
+            module (type): Module class or function to be registered.
+        """
+        if not isinstance(force, bool):
+            raise TypeError(f'force must be a boolean, but got {type(force)}')
+        # NOTE: This is a workaround to be compatible with the old api,
+        # while it may introduce unexpected bugs.
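+        # Legacy positional style, e.g. `backbones.register_module(ResNet)`:
+        # the class arrives bound to `name`, so route it to the deprecated path.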
+        if isinstance(name, type):
+            return self.deprecated_register_module(name, force=force)
+
+        # raise the error ahead of time
+        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+            raise TypeError(
+                'name must be either of None, an instance of str or a sequence'
+                f' of str, but got {type(name)}')
+
+        # use it as a normal method: x.register_module(module=SomeClass)
+        if module is not None:
+            self._register_module(module=module, module_name=name, force=force)
+            return module
+
+        # use it as a decorator: @x.register_module()
+        def _register(module):
+            self._register_module(module=module, module_name=name, force=force)
+            return module
+
+        return _register
+
diff --git a/data/dataset/lmdb_dataset.py b/data/dataset/lmdb_dataset.py
new file mode 100644
index 0000000..0590840
--- /dev/null
+++ b/data/dataset/lmdb_dataset.py
@@ -0,0 +1,137 @@
+#!/usr/bin/python
+# encoding: utf-8
+
+import random
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import sampler
+import torchvision.transforms as transforms
+import lmdb
+import six
+import sys
+from PIL import Image
+import numpy as np
+
+__all__ = ["lmdbDataset"]
+
+class lmdbDataset(Dataset):
+
+    def __init__(self, root=None, transform=None, target_transform=None):
+        self.env = lmdb.open(
+            root,
+            max_readers=1,
+            readonly=True,
+            lock=False,
+            readahead=False,
+            meminit=False)
+
+        if not self.env:
+            print('cannot create lmdb from %s' % (root))
+            sys.exit(1)
+
+        with self.env.begin(write=False) as txn:
+            # lmdb keys are bytes in Python 3, so encode before the lookup.
+            nSamples = int(txn.get('num-samples'.encode()))
+            self.nSamples = nSamples
+
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self):
+        return self.nSamples
+
+    def __getitem__(self, index):
+        assert index <= len(self), 'index range error'
+        index += 1
+        with self.env.begin(write=False) as txn:
+            img_key = 'image-%09d' % index
+            imgbuf = txn.get(img_key.encode())
+
+            buf = six.BytesIO()
+            buf.write(imgbuf)
+            buf.seek(0)
+            try:
+                img = Image.open(buf).convert('L')
+            except IOError:
+                print('Corrupted image for %d' % index)
+                # fall back to the next sample, wrapping at the end
+                return self[index % len(self)]
+
+            if self.transform is not None:
+                img = self.transform(img)
+
+            label_key = 'label-%09d' % index
+            # decode the raw bytes instead of str(), which would yield "b'...'"
+            label = txn.get(label_key.encode()).decode('utf-8')
+
+            if self.target_transform is not None:
+                label = self.target_transform(label)
+
+            return (img, label)
+
+
+class resizeNormalize(object):
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+        self.interpolation = interpolation
+        self.toTensor = transforms.ToTensor()
+
+    def __call__(self, img):
+        img = img.resize(self.size, self.interpolation)
+        img = self.toTensor(img)
+        img.sub_(0.5).div_(0.5)
+        return img
+
+
+class randomSequentialSampler(sampler.Sampler):
+
+    def __init__(self, data_source, batch_size):
+        self.num_samples = len(data_source)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        n_batch = len(self) // self.batch_size
+        tail = len(self) % self.batch_size
+        index = torch.LongTensor(len(self)).fill_(0)
+        for i in range(n_batch):
+            random_start = random.randint(0, len(self) - self.batch_size)
+            # torch.range is deprecated and returns floats; arange is exact.
+            batch_index = random_start + torch.arange(0, self.batch_size)
+            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
+        # deal with tail
+        if tail:
+            random_start = random.randint(0, len(self) - self.batch_size)
+            tail_index = random_start + torch.arange(0, tail)
+            # use n_batch rather than the loop variable, which is undefined
+            # when the loop above never runs
+            index[n_batch * self.batch_size:] = tail_index
+
+        return iter(index)
+
+    def __len__(self):
+        return self.num_samples
+
+
+class alignCollate(object):
+
+    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
+        self.imgH = imgH
+        self.imgW = imgW
+        self.keep_ratio = keep_ratio
+        self.min_ratio = min_ratio
+
+    def __call__(self, batch):
+        images, labels = zip(*batch)
+
+        imgH = self.imgH
+        imgW = self.imgW
+        if self.keep_ratio:
+            ratios = []
+            for image in images:
+                w, h = image.size
+                ratios.append(w / float(h))
+            ratios.sort()
+            max_ratio = ratios[-1]
+            imgW = int(np.floor(max_ratio * imgH))
+            imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio
+
+        transform = resizeNormalize((imgW, imgH))
+        images = [transform(image) for image in images]
+        images = torch.cat([t.unsqueeze(0) for t in images], 0)
+
+        return images, labels
\ No newline at end of file
diff --git a/data/dataset/runner/core_utils.py b/data/dataset/runner/core_utils.py
new file mode 100644
index 0000000..053b376
--- /dev/null
+++ b/data/dataset/runner/core_utils.py
@@ -0,0 +1,55 @@
+import torch
+import numpy as np
+import torch.distributed as dist
+from .dist_utils import get_dist_info
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed.
+
+    All workers must call this function, otherwise it will deadlock.
+    This method is generally used in `DistributedSampler`,
+    because the seed should be identical across all processes
+    in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
+
+def is_mlu_available():
+    """Returns a bool indicating if MLU is currently available."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+def get_device():
+    """Returns an available device, cpu, cuda or mlu."""
+    is_device_available = {
+        'cuda': torch.cuda.is_available(),
+        'mlu': is_mlu_available()
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
+    return device_list[0] if len(device_list) == 1 else 'cpu'
\ No newline at end of file
diff --git a/data/dataset/runner/dist_utils.py b/data/dataset/runner/dist_utils.py
new file mode 100644
index 0000000..5308c88
--- /dev/null
+++ b/data/dataset/runner/dist_utils.py
@@ -0,0 +1,12 @@
+from torch import distributed as dist
+from typing import Callable, List, Optional, Tuple
+
+def get_dist_info() -> Tuple[int, int]:
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
diff --git a/data/dataset/runner/version_utils.py b/data/dataset/runner/version_utils.py
new file mode 100644
index 0000000..fec8426
--- /dev/null
+++ b/data/dataset/runner/version_utils.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
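+# Orientation sketch (not in the original file): build_dataloader in
+# dataset_builder.py gates `persistent_workers` on these helpers, roughly
+#   if digit_version(TORCH_VERSION) >= digit_version('1.7.0'):
+#       kwargs['persistent_workers'] = persistent_workers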
+import os +import torch +import subprocess +import warnings + +from packaging.version import parse + +TORCH_VERSION = torch.__version__ + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + assert 'parrots' not in version_str + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) \ No newline at end of file diff --git a/data/dataset/samples/__init__.py b/data/dataset/samples/__init__.py new file mode 100644 index 0000000..a4c7ea1 --- /dev/null +++ b/data/dataset/samples/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .class_aware_sampler import ClassAwareSampler +from .distributed_sampler import DistributedSampler +from .group_sampler import DistributedGroupSampler, GroupSampler +from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler + +__all__ = [ + 'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler', + 'InfiniteGroupBatchSampler', 'InfiniteBatchSampler', 'ClassAwareSampler' +] diff --git a/data/dataset/samples/class_aware_sampler.py b/data/dataset/samples/class_aware_sampler.py new file mode 100644 index 0000000..2f97875 --- /dev/null +++ b/data/dataset/samples/class_aware_sampler.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from data.dataset.runner.dist_utils import get_dist_info +from torch.utils.data import Sampler +from data.dataset.runner.core_utils import sync_random_seed + + +class ClassAwareSampler(Sampler): + r"""Sampler that restricts data loading to the label of the dataset. + + A class-aware sampling strategy to effectively tackle the + non-uniform class distribution. The length of the training data is + consistent with source data. Simple improvements based on `Relay + Backpropagation for Effective Learning of Deep Convolutional + Neural Networks `_ + + The implementation logic is referred to + https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py + + Args: + dataset: Dataset used for sampling. + samples_per_gpu (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU. + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. 
This number should be identical across all + processes in the distributed group. Default: 0. + num_sample_class (int): The number of samples taken from each + per-label list. Default: 1 + """ + + def __init__(self, + dataset, + samples_per_gpu=1, + num_replicas=None, + rank=None, + seed=0, + num_sample_class=1): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + + self.dataset = dataset + self.num_replicas = num_replicas + self.samples_per_gpu = samples_per_gpu + self.rank = rank + self.epoch = 0 + # Must be the same across all workers. If None, will use a + # random seed shared among workers + # (require synchronization among all workers) + self.seed = sync_random_seed(seed) + + # The number of samples taken from each per-label list + assert num_sample_class > 0 and isinstance(num_sample_class, int) + self.num_sample_class = num_sample_class + # Get per-label image list from dataset + assert hasattr(dataset, 'get_cat2imgs'), \ + 'dataset must have `get_cat2imgs` function' + self.cat_dict = dataset.get_cat2imgs() + + self.num_samples = int( + math.ceil( + len(self.dataset) * 1.0 / self.num_replicas / + self.samples_per_gpu)) * self.samples_per_gpu + self.total_size = self.num_samples * self.num_replicas + + # get number of images containing each category + self.num_cat_imgs = [len(x) for x in self.cat_dict.values()] + # filter labels without images + self.valid_cat_inds = [ + i for i, length in enumerate(self.num_cat_imgs) if length != 0 + ] + self.num_classes = len(self.valid_cat_inds) + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + # initialize label list + label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g) + # initialize each per-label image list + data_iter_dict = dict() + for i in self.valid_cat_inds: + data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g) + + def gen_cat_img_inds(cls_list, data_dict, num_sample_cls): + """Traverse the categories and extract `num_sample_cls` image + indexes of the corresponding categories one by one.""" + id_indices = [] + for _ in range(len(cls_list)): + cls_idx = next(cls_list) + for _ in range(num_sample_cls): + id = next(data_dict[cls_idx]) + id_indices.append(id) + return id_indices + + # deterministically shuffle based on epoch + num_bins = int( + math.ceil(self.total_size * 1.0 / self.num_classes / + self.num_sample_class)) + indices = [] + for i in range(num_bins): + indices += gen_cat_img_inds(label_iter_list, data_iter_dict, + self.num_sample_class) + + # fix extra samples to make it evenly divisible + if len(indices) >= self.total_size: + indices = indices[:self.total_size] + else: + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class RandomCycleIter: + """Shuffle the list and do it again after the list have traversed. 
+ + The implementation logic is referred to + https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py + + Example: + >>> label_list = [0, 1, 2, 4, 5] + >>> g = torch.Generator() + >>> g.manual_seed(0) + >>> label_iter_list = RandomCycleIter(label_list, generator=g) + >>> index = next(label_iter_list) + Args: + data (list or ndarray): The data that needs to be shuffled. + generator: An torch.Generator object, which is used in setting the seed + for generating random numbers. + """ # noqa: W605 + + def __init__(self, data, generator=None): + self.data = data + self.length = len(data) + self.index = torch.randperm(self.length, generator=generator).numpy() + self.i = 0 + self.generator = generator + + def __iter__(self): + return self + + def __len__(self): + return len(self.data) + + def __next__(self): + if self.i == self.length: + self.index = torch.randperm( + self.length, generator=self.generator).numpy() + self.i = 0 + idx = self.data[self.index[self.i]] + self.i += 1 + return idx diff --git a/data/dataset/samples/distributed_sampler.py b/data/dataset/samples/distributed_sampler.py new file mode 100644 index 0000000..fb2e39d --- /dev/null +++ b/data/dataset/samples/distributed_sampler.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler + +from data.dataset.runner.core_utils import sync_random_seed, get_device + + +class DistributedSampler(_DistributedSampler): + + def __init__(self, + dataset, + num_replicas=None, + rank=None, + shuffle=True, + seed=0): + super().__init__( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. + device = get_device() + self.seed = sync_random_seed(seed, device) + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + # When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. + # Otherwise, the next iteration of this sampler will + # yield the same ordering. + g.manual_seed(self.epoch + self.seed) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + # in case that indices is shorter than half of total_size + indices = (indices * + math.ceil(self.total_size / len(indices)))[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/data/dataset/samples/group_sampler.py b/data/dataset/samples/group_sampler.py new file mode 100644 index 0000000..6e4b1f5 --- /dev/null +++ b/data/dataset/samples/group_sampler.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
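+# NOTE (illustrative, not part of the upstream file): both samplers below
+# assume the dataset exposes a per-image `flag` array; in mmdetection the
+# convention is flag=1 for images with aspect ratio > 1 and flag=0 otherwise,
+# so every yielded batch holds images of similar shape:
+#
+#     >>> import numpy as np
+#     >>> flag = np.array([0, 1, 1, 0, 0])
+#     >>> np.bincount(flag)       # group sizes used by GroupSampler
+#     array([3, 2])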
+import math + +import numpy as np +import torch +from data.dataset.runner.dist_utils import get_dist_info +from torch.utils.data import Sampler + +class GroupSampler(Sampler): + + def __init__(self, dataset, samples_per_gpu=1): + assert hasattr(dataset, 'flag') + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.flag = dataset.flag.astype(np.int64) + self.group_sizes = np.bincount(self.flag) + self.num_samples = 0 + for i, size in enumerate(self.group_sizes): + self.num_samples += int(np.ceil( + size / self.samples_per_gpu)) * self.samples_per_gpu + + def __iter__(self): + indices = [] + for i, size in enumerate(self.group_sizes): + if size == 0: + continue + indice = np.where(self.flag == i)[0] + assert len(indice) == size + np.random.shuffle(indice) + num_extra = int(np.ceil(size / self.samples_per_gpu) + ) * self.samples_per_gpu - len(indice) + indice = np.concatenate( + [indice, np.random.choice(indice, num_extra)]) + indices.append(indice) + indices = np.concatenate(indices) + indices = [ + indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu] + for i in np.random.permutation( + range(len(indices) // self.samples_per_gpu)) + ] + indices = np.concatenate(indices) + indices = indices.astype(np.int64).tolist() + assert len(indices) == self.num_samples + return iter(indices) + + def __len__(self): + return self.num_samples + + +class DistributedGroupSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. This number should be identical across all + processes in the distributed group. Default: 0. + """ + + def __init__(self, + dataset, + samples_per_gpu=1, + num_replicas=None, + rank=None, + seed=0): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.seed = seed if seed is not None else 0 + + assert hasattr(self.dataset, 'flag') + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + + self.num_samples = 0 + for i, j in enumerate(self.group_sizes): + self.num_samples += int( + math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu / + self.num_replicas)) * self.samples_per_gpu + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + indices = [] + for i, size in enumerate(self.group_sizes): + if size > 0: + indice = np.where(self.flag == i)[0] + assert len(indice) == size + # add .numpy() to avoid bug when selecting indice in parrots. + # TODO: check whether torch.randperm() can be replaced by + # numpy.random.permutation(). 
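+                # Worked example (for clarity): with size=10,
+                # samples_per_gpu=4 and num_replicas=2, the group is padded
+                # to ceil(10 / 4 / 2) * 4 * 2 = 16 indices (extra = 6), so
+                # each of the two ranks later takes 8 of them.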
+ indice = indice[list( + torch.randperm(int(size), generator=g).numpy())].tolist() + extra = int( + math.ceil( + size * 1.0 / self.samples_per_gpu / self.num_replicas) + ) * self.samples_per_gpu * self.num_replicas - len(indice) + # pad indice + tmp = indice.copy() + for _ in range(extra // size): + indice.extend(tmp) + indice.extend(tmp[:extra % size]) + indices.extend(indice) + + assert len(indices) == self.total_size + + indices = [ + indices[j] for i in list( + torch.randperm( + len(indices) // self.samples_per_gpu, generator=g)) + for j in range(i * self.samples_per_gpu, (i + 1) * + self.samples_per_gpu) + ] + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch \ No newline at end of file diff --git a/data/dataset/samples/infinite_sampler.py b/data/dataset/samples/infinite_sampler.py new file mode 100644 index 0000000..9eeb7ce --- /dev/null +++ b/data/dataset/samples/infinite_sampler.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools + +import numpy as np +import torch +from data.dataset.runner.dist_utils import get_dist_info +from data.dataset.runner.core_utils import sync_random_seed +from torch.utils.data.sampler import Sampler + + + + +class InfiniteGroupBatchSampler(Sampler): + """Similar to `BatchSampler` warping a `GroupSampler. It is designed for + iteration-based runners like `IterBasedRunner` and yields a mini-batch + indices each time, all indices in a batch should be in the same group. + + The implementation logic is referred to + https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py + + Args: + dataset (object): The dataset. + batch_size (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU. + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + world_size (int, optional): Number of processes participating in + distributed training. Default: None. + rank (int, optional): Rank of current process. Default: None. + seed (int): Random seed. Default: 0. + shuffle (bool): Whether shuffle the indices of a dummy `epoch`, it + should be noted that `shuffle` can not guarantee that you can + generate sequential indices because it need to ensure + that all indices in a batch is in a group. Default: True. + """ # noqa: W605 + + def __init__(self, + dataset, + batch_size=1, + world_size=None, + rank=None, + seed=0, + shuffle=True): + _rank, _world_size = get_dist_info() + if world_size is None: + world_size = _world_size + if rank is None: + rank = _rank + self.rank = rank + self.world_size = world_size + self.dataset = dataset + self.batch_size = batch_size + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. 
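+        # Example (for clarity): with world_size=2, _indices_of_rank() below
+        # slices the shared infinite stream so rank 0 consumes positions
+        # 0, 2, 4, ... and rank 1 consumes 1, 3, 5, ...; identical seeds make
+        # the underlying stream the same on every rank, hence no overlap.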
+ self.seed = sync_random_seed(seed) + self.shuffle = shuffle + + assert hasattr(self.dataset, 'flag') + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + # buffer used to save indices of each group + self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))} + + self.size = len(dataset) + self.indices = self._indices_of_rank() + + def _infinite_indices(self): + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(self.size, generator=g).tolist() + + else: + yield from torch.arange(self.size).tolist() + + def _indices_of_rank(self): + """Slice the infinite indices by rank.""" + yield from itertools.islice(self._infinite_indices(), self.rank, None, + self.world_size) + + def __iter__(self): + # once batch size is reached, yield the indices + for idx in self.indices: + flag = self.flag[idx] + group_buffer = self.buffer_per_group[flag] + group_buffer.append(idx) + if len(group_buffer) == self.batch_size: + yield group_buffer[:] + del group_buffer[:] + + def __len__(self): + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch): + """Not supported in `IterationBased` runner.""" + raise NotImplementedError + + +class InfiniteBatchSampler(Sampler): + """Similar to `BatchSampler` warping a `DistributedSampler. It is designed + iteration-based runners like `IterBasedRunner` and yields a mini-batch + indices each time. + + The implementation logic is referred to + https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py + + Args: + dataset (object): The dataset. + batch_size (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU, + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + world_size (int, optional): Number of processes participating in + distributed training. Default: None. + rank (int, optional): Rank of current process. Default: None. + seed (int): Random seed. Default: 0. + shuffle (bool): Whether shuffle the dataset or not. Default: True. + """ # noqa: W605 + + def __init__(self, + dataset, + batch_size=1, + world_size=None, + rank=None, + seed=0, + shuffle=True): + _rank, _world_size = get_dist_info() + if world_size is None: + world_size = _world_size + if rank is None: + rank = _rank + self.rank = rank + self.world_size = world_size + self.dataset = dataset + self.batch_size = batch_size + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. 
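+        # Example (for clarity): with batch_size=2 and this rank's stream
+        # (5, 0, 3, 1, ...), __iter__ below yields [5, 0], then [3, 1], ...
+        # indefinitely; the iteration-based runner decides when to stop.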
+ self.seed = sync_random_seed(seed) + self.shuffle = shuffle + self.size = len(dataset) + self.indices = self._indices_of_rank() + + def _infinite_indices(self): + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(self.size, generator=g).tolist() + + else: + yield from torch.arange(self.size).tolist() + + def _indices_of_rank(self): + """Slice the infinite indices by rank.""" + yield from itertools.islice(self._infinite_indices(), self.rank, None, + self.world_size) + + def __iter__(self): + # once batch size is reached, yield the indices + batch_buffer = [] + for idx in self.indices: + batch_buffer.append(idx) + if len(batch_buffer) == self.batch_size: + yield batch_buffer + batch_buffer = [] + + def __len__(self): + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch): + """Not supported in `IterationBased` runner.""" + raise NotImplementedError diff --git a/data/dataset/uniform_concat_dataset.py b/data/dataset/uniform_concat_dataset.py new file mode 100644 index 0000000..0e41e7b --- /dev/null +++ b/data/dataset/uniform_concat_dataset.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import defaultdict + +import numpy as np +from data.dataset.utils import print_log +from data.dataset.dataset_wrapper import ConcatDataset + +from mmocr.utils import is_2dlist, is_type_list + + +class UniformConcatDataset(ConcatDataset): + """A wrapper of ConcatDataset which support dataset pipeline assignment and + replacement. + + Args: + datasets (list[dict] | list[list[dict]]): A list of datasets cfgs. + separate_eval (bool): Whether to evaluate the results + separately if it is used as validation dataset. + Defaults to True. + show_mean_scores (str | bool): Whether to compute the mean evaluation + results, only applicable when ``separate_eval=True``. Options are + [True, False, ``auto``]. If ``True``, mean results will be added to + the result dictionary with keys in the form of + ``mean_{metric_name}``. If 'auto', mean results will be shown only + when more than 1 dataset is wrapped. + pipeline (None | list[dict] | list[list[dict]]): If ``None``, + each dataset in datasets use its own pipeline; + If ``list[dict]``, it will be assigned to the dataset whose + pipeline is None in datasets; + If ``list[list[dict]]``, pipeline of dataset which is None + in datasets will be replaced by the corresponding pipeline + in the list. + force_apply (bool): If True, apply pipeline above to each dataset + even if it have its own pipeline. Default: False. + """ + + def __init__(self, + datasets, + separate_eval=True, + show_mean_scores='auto', + pipeline=None, + force_apply=False, + **kwargs): + new_datasets = [] + if pipeline is not None: + assert isinstance( + pipeline, + list), 'pipeline must be list[dict] or list[list[dict]].' 
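+            # Illustrative shapes (the transform names are examples only):
+            #   list[dict]       -> one pipeline shared by every dataset:
+            #       [dict(type='LoadImageFromFile'), dict(type='Resize')]
+            #   list[list[dict]] -> one pipeline per sub-list of datasets:
+            #       [[dict(type='LoadImageFromFile')],
+            #        [dict(type='LoadImageFromNdarray')]]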
+            if is_type_list(pipeline, dict):
+                self._apply_pipeline(datasets, pipeline, force_apply)
+                new_datasets = datasets
+            elif is_2dlist(pipeline):
+                assert is_2dlist(datasets)
+                assert len(datasets) == len(pipeline)
+                for sub_datasets, tmp_pipeline in zip(datasets, pipeline):
+                    self._apply_pipeline(sub_datasets, tmp_pipeline,
+                                         force_apply)
+                    new_datasets.extend(sub_datasets)
+        else:
+            if is_2dlist(datasets):
+                for sub_datasets in datasets:
+                    new_datasets.extend(sub_datasets)
+            else:
+                new_datasets = datasets
+        # build_dataset lives in dataset_builder; import it locally to avoid
+        # an import cycle between the dataset modules.
+        from data.dataset.dataset_builder import build_dataset
+        datasets = [build_dataset(c, kwargs) for c in new_datasets]
+        super().__init__(datasets, separate_eval)
+
+        if not separate_eval:
+            raise NotImplementedError(
+                'Evaluating datasets as a whole is not'
+                ' supported yet. Please use "separate_eval=True"')
+
+        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+        if show_mean_scores == 'auto':
+            show_mean_scores = len(self.datasets) > 1
+        self.show_mean_scores = show_mean_scores
+        # show_mean_scores has been normalized to a bool above, so a plain
+        # truthiness check suffices here.
+        if show_mean_scores:
+            if len(set([type(ds) for ds in self.datasets])) != 1:
+                raise NotImplementedError(
+                    'To compute mean evaluation scores, all datasets '
+                    'must have the same type')
+
+    @staticmethod
+    def _apply_pipeline(datasets, pipeline, force_apply=False):
+        from_cfg = all(isinstance(x, dict) for x in datasets)
+        assert from_cfg, 'datasets should be config dicts'
+        assert all(isinstance(x, dict) for x in pipeline)
+        for dataset in datasets:
+            if dataset['pipeline'] is None or force_apply:
+                dataset['pipeline'] = copy.deepcopy(pipeline)
+
+    def evaluate(self, results, logger=None, **kwargs):
+        """Evaluate the results.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+
+        Returns:
+            dict[str, float]: Results of each separate
+            dataset if `self.separate_eval=True`.
+        """
+        assert len(results) == self.cumulative_sizes[-1], \
+            ('Dataset and results have different sizes: '
+             f'{self.cumulative_sizes[-1]} vs. {len(results)}')
+
+        # Check whether all the datasets support evaluation
+        for dataset in self.datasets:
+            assert hasattr(dataset, 'evaluate'), \
+                f'{type(dataset)} does not implement evaluate function'
+
+        if self.separate_eval:
+            dataset_idx = -1
+
+            total_eval_results = dict()
+
+            if self.show_mean_scores:
+                mean_eval_results = defaultdict(list)
+
+            for dataset in self.datasets:
+                start_idx = 0 if dataset_idx == -1 else \
+                    self.cumulative_sizes[dataset_idx]
+                end_idx = self.cumulative_sizes[dataset_idx + 1]
+
+                results_per_dataset = results[start_idx:end_idx]
+                print_log(
+                    f'\nEvaluating {dataset.ann_file} with '
+                    f'{len(results_per_dataset)} images now',
+                    logger=logger)
+
+                eval_results_per_dataset = dataset.evaluate(
+                    results_per_dataset, logger=logger, **kwargs)
+                dataset_idx += 1
+                for k, v in eval_results_per_dataset.items():
+                    total_eval_results.update({f'{dataset_idx}_{k}': v})
+                    if self.show_mean_scores:
+                        mean_eval_results[k].append(v)
+
+            if self.show_mean_scores:
+                for k, v in mean_eval_results.items():
+                    total_eval_results[f'mean_{k}'] = np.mean(v)
+
+            return total_eval_results
+        else:
+            raise NotImplementedError(
+                'Evaluating datasets as a whole is not'
+                ' supported yet.
Please use "separate_eval=True"') diff --git a/data/dataset/utils.py b/data/dataset/utils.py new file mode 100644 index 0000000..56221d0 --- /dev/null +++ b/data/dataset/utils.py @@ -0,0 +1,110 @@ +import logging + +import torch.distributed as dist + +logger_initialized: dict = {} + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". 
+    """
+    if logger is None:
+        print(msg)
+    elif isinstance(logger, logging.Logger):
+        logger.log(level, msg)
+    elif logger == 'silent':
+        pass
+    elif isinstance(logger, str):
+        _logger = get_logger(logger)
+        _logger.log(level, msg)
+    else:
+        raise TypeError(
+            'logger should be either a logging.Logger object, str, '
+            f'"silent" or None, but got {type(logger)}')
diff --git a/modeling/backbone/very_deep_vgg.py b/modeling/backbone/very_deep_vgg.py
new file mode 100644
index 0000000..0cf773b
--- /dev/null
+++ b/modeling/backbone/very_deep_vgg.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch.nn as nn
+from mmcv.runner import BaseModule, Sequential
+from .build import BACKBONE_REGISTRY
+
+
+class VeryDeepVgg(BaseModule):
+    """Implement VGG-VeryDeep backbone for text recognition, modified from
+    `VGG-VeryDeep <https://arxiv.org/abs/1409.1556>`_.
+
+    Args:
+        leaky_relu (bool): Use LeakyReLU instead of ReLU activations.
+        input_channels (int): Number of channels of input image tensor.
+    """
+
+    def __init__(self,
+                 leaky_relu=True,
+                 input_channels=3,
+                 init_cfg=[
+                     dict(type='Xavier', layer='Conv2d'),
+                     dict(type='Uniform', layer='BatchNorm2d')
+                 ]):
+        super().__init__(init_cfg=init_cfg)
+
+        ks = [3, 3, 3, 3, 3, 3, 2]
+        ps = [1, 1, 1, 1, 1, 1, 0]
+        ss = [1, 1, 1, 1, 1, 1, 1]
+        nm = [64, 128, 256, 256, 512, 512, 512]
+
+        self.channels = nm
+
+        cnn = Sequential()
+
+        def conv_relu(i, batch_normalization=False):
+            n_in = input_channels if i == 0 else nm[i - 1]
+            n_out = nm[i]
+            cnn.add_module('conv{0}'.format(i),
+                           nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i]))
+            if batch_normalization:
+                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out))
+            if leaky_relu:
+                cnn.add_module('relu{0}'.format(i),
+                               nn.LeakyReLU(0.2, inplace=True))
+            else:
+                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
+
+        conv_relu(0)
+        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
+        conv_relu(1)
+        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
+        conv_relu(2, True)
+        conv_relu(3)
+        cnn.add_module('pooling{0}'.format(2),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
+        conv_relu(4, True)
+        conv_relu(5)
+        cnn.add_module('pooling{0}'.format(3),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
+        conv_relu(6, True)  # 512x1x16
+
+        self.cnn = cnn
+
+    def out_channels(self):
+        return self.channels[-1]
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Images of shape :math:`(N, C, H, W)`.
+
+        Returns:
+            Tensor: The feature Tensor of shape :math:`(N, 512, H/32, (W/4)+1)`.
+        """
+        output = self.cnn(x)
+
+        return output
+
+
+@BACKBONE_REGISTRY.register()
+def build_very_deep_vgg(cfg):
+    leaky_relu = cfg.MODEL.BACKBONE_PARAMS.LEAKY_RELU
+    input_channels = cfg.MODEL.BACKBONE_PARAMS.INPUT_CHANNELS
+    model = VeryDeepVgg(leaky_relu, input_channels)
+
+    # mmcv's BaseModule.init_weights() takes no arguments; pretrained weights
+    # should be routed through init_cfg rather than passed in here.
+    model.init_weights()
+    return model
\ No newline at end of file
diff --git a/modeling/decoders/crnn_decode.py b/modeling/decoders/crnn_decode.py
new file mode 100644
index 0000000..e085686
--- /dev/null
+++ b/modeling/decoders/crnn_decode.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
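+# Shape walk-through (added for clarity): a 32-pixel-high input through the
+# VGG backbone above yields features of shape (N, 512, 1, W'); with the toy
+# config's 100x32 crops, W' = 26. CRNNDecoder then treats width as time:
+#   (N, 512, 1, W') -squeeze(2)-> (N, 512, W') -permute-> (W', N, 512)
+#   -2x BiLSTM-> (W', N, num_classes) -permute-> (N, W', num_classes)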
+import torch.nn as nn +from mmcv.runner import Sequential +from mmcv.runner import BaseModule + +from mmocr.models.builder import DECODERS + +class BidirectionalLSTM(nn.Module): + + def __init__(self, nIn, nHidden, nOut): + super().__init__() + + self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) + self.embedding = nn.Linear(nHidden * 2, nOut) + + def forward(self, input): + recurrent, _ = self.rnn(input) + T, b, h = recurrent.size() + t_rec = recurrent.view(T * b, h) + + output = self.embedding(t_rec) # [T * b, nOut] + output = output.view(T, b, -1) + + return output + + +class CRNNDecoder(BaseModule): + """Decoder for CRNN. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. + rnn_flag (bool): Use RNN or CNN as the decoder. + init_cfg (dict or list[dict], optional): Initialization configs. + """ + + def __init__(self, + in_channels=None, + num_classes=None, + rnn_flag=False, + init_cfg=dict(type='Xavier', layer='Conv2d'), + **kwargs): + super().__init__(init_cfg=init_cfg) + + self.num_classes = num_classes + self.rnn_flag = rnn_flag + + if rnn_flag: + self.decoder = Sequential( + BidirectionalLSTM(in_channels, 256, 256), + BidirectionalLSTM(256, 256, num_classes)) + else: + self.decoder = nn.Conv2d( + in_channels, num_classes, kernel_size=1, stride=1) + + + def forward(self, + feat, + out_enc, + targets_dict=None, + img_metas=None, + train_mode=True): + self.train_mode = train_mode + if train_mode: + return self.forward_train(feat, out_enc, targets_dict, img_metas) + + return self.forward_test(feat, out_enc, img_metas) + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + """ + Args: + feat (Tensor): A Tensor of shape :math:`(N, H, 1, W)`. + + Returns: + Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where + :math:`C` is ``num_classes``. + """ + assert feat.size(2) == 1, 'feature height must be 1' + if self.rnn_flag: + x = feat.squeeze(2) # [N, C, W] + x = x.permute(2, 0, 1) # [W, N, C] + x = self.decoder(x) # [W, N, C] + outputs = x.permute(1, 0, 2).contiguous() + else: + x = self.decoder(feat) + x = x.permute(0, 3, 1, 2).contiguous() + n, w, c, h = x.size() + outputs = x.view(n, w, c * h) + return outputs + + def forward_test(self, feat, out_enc, img_metas): + """ + Args: + feat (Tensor): A Tensor of shape :math:`(N, H, 1, W)`. + + Returns: + Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where + :math:`C` is ``num_classes``. + """ + return self.forward_train(feat, out_enc, None, img_metas) + + + diff --git a/modeling/losses/ctc_loss.py b/modeling/losses/ctc_loss.py new file mode 100644 index 0000000..24c6390 --- /dev/null +++ b/modeling/losses/ctc_loss.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn + +from mmocr.models.builder import LOSSES + + +@LOSSES.register_module() +class CTCLoss(nn.Module): + """Implementation of loss module for CTC-loss based text recognition. + + Args: + flatten (bool): If True, use flattened targets, else padded targets. + blank (int): Blank label. Default 0. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + zero_infinity (bool): Whether to zero infinite losses and + the associated gradients. Default: False. + Infinite losses mainly occur when the inputs + are too short to be aligned to the targets. 
+ """ + + def __init__(self, + flatten=True, + blank=0, + reduction='mean', + zero_infinity=False, + **kwargs): + super().__init__() + assert isinstance(flatten, bool) + assert isinstance(blank, int) + assert isinstance(reduction, str) + assert isinstance(zero_infinity, bool) + + self.flatten = flatten + self.blank = blank + self.ctc_loss = nn.CTCLoss( + blank=blank, reduction=reduction, zero_infinity=zero_infinity) + + def forward(self, outputs, targets_dict, img_metas=None): + """ + Args: + outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`. + targets_dict (dict): A dict with 3 keys ``target_lengths``, + ``flatten_targets`` and ``targets``. + + - | ``target_lengths`` (Tensor): A tensor of shape :math:`(N)`. + Each item is the length of a word. + + - | ``flatten_targets`` (Tensor): Used if ``self.flatten=True`` + (default). A tensor of shape + (sum(targets_dict['target_lengths'])). Each item is the + index of a character. + + - | ``targets`` (Tensor): Used if ``self.flatten=False``. A + tensor of :math:`(N, T)`. Empty slots are padded with + ``self.blank``. + + img_metas (dict): A dict that contains meta information of input + images. Preferably with the key ``valid_ratio``. + + Returns: + dict: The loss dict with key ``loss_ctc``. + """ + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + + outputs = torch.log_softmax(outputs, dim=2) + bsz, seq_len = outputs.size(0), outputs.size(1) + outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C + + if self.flatten: + targets = targets_dict['flatten_targets'] + else: + targets = torch.full( + size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long) + for idx, tensor in enumerate(targets_dict['targets']): + valid_len = min(tensor.size(0), seq_len) + targets[idx, :valid_len] = tensor[:valid_len] + + target_lengths = targets_dict['target_lengths'] + target_lengths = torch.clamp(target_lengths, min=1, max=seq_len).long() + + input_lengths = torch.full( + size=(bsz, ), fill_value=seq_len, dtype=torch.long) + if not self.flatten and valid_ratios is not None: + input_lengths = [ + math.ceil(valid_ratio * seq_len) + for valid_ratio in valid_ratios + ] + input_lengths = torch.Tensor(input_lengths).long() + + loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths, + target_lengths) + + losses = dict(loss_ctc=loss_ctc) + + return losses diff --git a/modeling/recognizers/__init__.py b/modeling/recognizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modeling/recognizers/crnn.py b/modeling/recognizers/crnn.py new file mode 100644 index 0000000..a8d20a6 --- /dev/null +++ b/modeling/recognizers/crnn.py @@ -0,0 +1,100 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.structures import Boxes, ImageList, Instances + + +from ..backbone import build_backbone +from ..decoders import crnn_decode +# from ..losses import ctc_loss +from str_label_converter import strLabelConverter +from warpctc_pytorch import CTCLoss + +__all__ = ["CRNNet"] + + +@META_ARCH_REGISTRY.register() +class CRNNet(nn.Module): + """ + Implement CRNNet + """ + + def __init__(self, cfg): + super().__init__() + + self.device = torch.device(cfg.MODEL.DEVICE) + self.cfg = cfg + + self.backbone = build_backbone(cfg) + + crnn_in_channels = cfg.MODEL.CRNN.IN_CHANNELS + self.alphabet = cfg.MODEL.ALPHABET + self.num_classes = len(self.alphabet) + 1 + 
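+        # e.g. the toy config's 36-char alphabet "0123456789a...z" gives
+        # num_classes = 37: index 0 is the CTC blank, indices 1..36 map to
+        # characters (see strLabelConverter below).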
+        # classic CRNN uses the BiLSTM head on top of the conv features
+        self.crnn_decode = crnn_decode.CRNNDecoder(
+            crnn_in_channels, self.num_classes, rnn_flag=True)
+        self.loss_func = CTCLoss()
+
+        self.converter = strLabelConverter(self.alphabet)
+        self.to(self.device)
+
+    def forward(self, batched_inputs):
+        image, text, length = self.preprocess_image(batched_inputs)
+
+        if not self.training:
+            return self.inference(image)
+
+        features = self.backbone(image)
+
+        # CRNNDecoder.forward expects (feat, out_enc, ...); there is no
+        # separate encoder here, so out_enc is None. The decoder returns
+        # (N, T, C) while warp-ctc wants (T, N, C).
+        preds = self.crnn_decode(features, None)
+        preds = preds.permute(1, 0, 2)
+        batch_size = preds.size(1)
+        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
+
+        loss_ctc = self.loss_func(preds, text, preds_size, length) / batch_size
+        return {"loss_ctc": loss_ctc}
+
+    @torch.no_grad()
+    def inference(self, image):
+        features = self.backbone(image)
+        preds = self.crnn_decode(features, None)
+        preds = preds.permute(1, 0, 2)  # (T, N, C), matching the training path
+        batch_size = preds.size(1)
+        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
+        # greedy decoding: argmax over classes gives a (T, N) index tensor
+        _, preds = preds.max(2)
+        preds = preds.transpose(1, 0).contiguous().view(-1)
+        sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
+        raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)[:self.cfg.TEST.N_TEST_DISP]
+
+        return sim_preds, raw_preds
+
+    def preprocess_image(self, batched_inputs):
+        """
+        Normalize, pad and batch the input images.
+        """
+        batch_size = self.cfg.SOLVER.IMS_PER_BATCH
+
+        # NCHW layout: height before width
+        image = torch.FloatTensor(batch_size, 3, self.cfg.INPUT.IMG_H, self.cfg.INPUT.IMG_W)
+        image = image.to(self.device)
+        text = torch.IntTensor(batch_size * 5)
+        length = torch.IntTensor(batch_size)
+        cpu_images, cpu_texts = batched_inputs
+        batch_size = cpu_images.size(0)
+        self.loadData(image, cpu_images)
+        t, l = self.converter.encode(cpu_texts)
+
+        self.loadData(text, t)
+        self.loadData(length, l)
+
+        return image, text, length
+
+    def loadData(self, v, data):
+        # in-place resize-and-copy, as in the original crnn.pytorch trainer
+        v.data.resize_(data.size()).copy_(data)
\ No newline at end of file
diff --git a/modeling/recognizers/str_label_converter.py b/modeling/recognizers/str_label_converter.py
new file mode 100644
index 0000000..ef36bf5
--- /dev/null
+++ b/modeling/recognizers/str_label_converter.py
@@ -0,0 +1,42 @@
+import torch
+import collections
+
+class strLabelConverter(object):
+    """Convert between str and label.
+    NOTE:
+        Insert `blank` to the alphabet for CTC.
+    Args:
+        alphabet (str): set of the possible characters.
+        ignore_case (bool, default=True): whether to ignore case.
+    """
+
+    def __init__(self, alphabet, ignore_case=True):
+        self._ignore_case = ignore_case
+        if self._ignore_case:
+            alphabet = alphabet.lower()
+        self.alphabet = alphabet + '-'  # for `-1` index
+
+        self.dict = {}
+        for i, char in enumerate(alphabet):
+            # NOTE: 0 is reserved for the 'blank' required by warp-ctc
+            self.dict[char] = i + 1
+
+    def encode(self, text):
+        """Support batch or single str.
+        Args:
+            text (str or list of str): texts to convert.
+        Returns:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+        """
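+        # Worked example (for clarity), with the toy 0-9a-z alphabet:
+        #   encode(['cat', 'do']) ->
+        #   (IntTensor([13, 11, 30, 14, 25]), IntTensor([3, 2]))
+        # since 'a'->11, 'c'->13, 'o'->25, 't'->30 (1-based, after 10 digits).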
+        if isinstance(text, str):
+            text = [
+                self.dict[char.lower() if self._ignore_case else char]
+                for char in text
+            ]
+            length = [len(text)]
+        # collections.Iterable was removed in Python 3.10; collections.abc
+        # has been its correct home since Python 3.3.
+        elif isinstance(text, collections.abc.Iterable):
+            length = [len(s) for s in text]
+            text = ''.join(text)
+            text, _ = self.encode(text)
+        return (torch.IntTensor(text), torch.IntTensor(length))
\ No newline at end of file
diff --git a/yamls/text_recognizer/crnn_text_recognizer_toy.yaml b/yamls/text_recognizer/crnn_text_recognizer_toy.yaml
new file mode 100644
index 0000000..b8954e1
--- /dev/null
+++ b/yamls/text_recognizer/crnn_text_recognizer_toy.yaml
@@ -0,0 +1,20 @@
+MODEL:
+  BACKBONE: "build_very_deep_vgg"
+  BACKBONE_PARAMS:
+    INPUT_CHANNELS: 3
+    LEAKY_RELU: False
+  DEVICE: "cuda"
+  CRNN:
+    IN_CHANNELS: 512
+  ALPHABET: "0123456789abcdefghijklmnopqrstuvwxyz"
+
+
+TEST:
+  N_TEST_DISP: 10
+
+SOLVER:
+  IMS_PER_BATCH: 64
+
+INPUT:
+  IMG_W: 100
+  IMG_H: 32
\ No newline at end of file
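
Note: CRNNet.inference above calls self.converter.decode(...), but the new
str_label_converter.py only defines encode. A minimal decode sketch in the
spirit of the classic crnn.pytorch converter is given below (greedy CTC
post-processing: drop blanks, merge repeated symbols); the method name and
semantics follow that project, so treat this as a suggested addition to
strLabelConverter rather than part of the patch:

    def decode(self, t, length, raw=False):
        """Decode label tensors back into strings (inverse of encode)."""
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            char_list = []
            for i in range(length):
                # CTC collapse: skip blanks (0) and repeated predictions
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.alphabet[t[i] - 1])
            return ''.join(char_list)
        else:
            # batch mode: the flat tensor is split by per-item lengths
            assert t.numel() == length.sum()
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(t[index:index + l],
                                torch.IntTensor([l]),
                                raw=raw))
                index += l
            return texts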